meerqat.train.trainee module
Each Trainee is a pl.LightningModule that computes its own loss, which makes it compatible with Trainer.
- class meerqat.train.trainee.Trainee(*args, freeze_regex=None, gradient_checkpointing=False, warmup_steps=0, lr=2e-05, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, output_cpu=False, **kwargs)
Bases: LightningModule
Base class for all Trainee models (to be trained by a Trainer).
- Parameters:
*args (additional arguments are passed to pl.LightningModule) –
**kwargs (additional arguments are passed to pl.LightningModule) –
freeze_regex (str, optional) – Regex used to match the model parameters to freeze (i.e. set requires_grad = False). Defaults to None (keep the model fully trainable).
gradient_checkpointing (bool, optional) –
lr (float, optional) –
eps (float, optional) –
weight_decay (float, optional) –
betas (Tuple[float], optional) –
warmup_steps (int, optional) – Number of warm-up steps. Defaults to 0 (no warm-up).
output_cpu (bool, optional) –
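The warmup_steps parameter can be illustrated with a minimal sketch of a linear warm-up. The schedule shape (a linear ramp from 0 to lr, then constant) and the helper name warmup_lr are assumptions; this page does not document which scheduler Trainee actually uses.

```python
def warmup_lr(step: int, lr: float = 2e-05, warmup_steps: int = 0) -> float:
    """Hypothetical linear warm-up: ramp the learning rate from 0 up to `lr`
    over `warmup_steps` optimizer steps, then keep it constant (assumption)."""
    if warmup_steps <= 0:
        # warmup_steps=0 (the documented default) means no warm-up at all
        return lr
    # Linear ramp; min() caps the factor at 1 once warm-up is over
    return lr * min(1.0, step / warmup_steps)
```

For example, with lr=1.0 and warmup_steps=10, step 5 yields 0.5 and any step from 10 onward yields 1.0.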
- freeze(regex)
Overrides LightningModule.freeze to freeze only the parameters whose names match regex. Caveat: it does not call .eval(), so Dropout is not disabled.
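Regex-based freezing can be sketched as follows. This is a minimal illustration, not the module's code: the helper name freeze_matching is hypothetical, and using re.search (rather than re.match or re.fullmatch) is an assumption. It accepts any iterable of (name, parameter) pairs, such as the output of torch.nn.Module.named_parameters(); the usage below uses SimpleNamespace stand-ins so it runs without PyTorch.

```python
import re
from types import SimpleNamespace
from typing import Any, Iterable, List, Tuple


def freeze_matching(named_parameters: Iterable[Tuple[str, Any]], regex: str) -> List[str]:
    """Set requires_grad = False on every parameter whose name matches `regex`.
    Returns the names of the frozen parameters."""
    pattern = re.compile(regex)
    frozen = []
    for name, param in named_parameters:
        # Assumption: search anywhere in the name; the original may anchor differently
        if pattern.search(name) is not None:
            param.requires_grad = False
            frozen.append(name)
    # Like the documented caveat: nothing here switches the model to eval mode
    return frozen


# Toy stand-ins for parameters (a real model would pass model.named_parameters())
params = {
    "embeddings.word_embeddings.weight": SimpleNamespace(requires_grad=True),
    "encoder.layer.0.attention.weight": SimpleNamespace(requires_grad=True),
}
frozen = freeze_matching(params.items(), r"^embeddings\.")
```

Only the embedding weight is frozen here; the encoder weight stays trainable.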
- gradient_checkpointing_enable()
Activates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.
- gradient_checkpointing_disable()
Deactivates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.
- property is_gradient_checkpointing: bool
Whether gradient checkpointing is activated for this model or not. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.
- class meerqat.train.trainee.CrossModal(*args, model_kwargs: dict, **kwargs)
Bases: Trainee
- class meerqat.train.trainee.JointMonoAndCrossModal(*args, model_kwargs: dict, image_weight=0.5, cm_weight=0.5, learn_weights=False, mm_weights_lr=None, **kwargs)
Bases: Trainee
- forward(input_ids: Optional[LongTensor] = None, pixel_values: Optional[FloatTensor] = None, paired_pixel_values: Optional[FloatTensor] = None, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- class meerqat.train.trainee.BiEncoder(*args, question_class, question_model_name_or_path, context_class=None, context_model_name_or_path=None, question_kwargs={}, context_kwargs=None, superclass=False, **kwargs)
Bases: Trainee
The training objective is to minimize the negative log-likelihood of the positive passage, computed from a softmax over the similarities (dot products) between the question and passage embeddings, as described in [3].
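The objective can be made concrete with a minimal pure-Python sketch, assuming exactly one positive passage per question (passages[i] belongs to questions[i]) and no extra hard negatives, so that every other passage in the batch acts as a negative. The function name in_batch_nll is hypothetical; the actual implementation works on batched tensors.

```python
import math
from typing import List


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def in_batch_nll(questions: List[List[float]], passages: List[List[float]]) -> float:
    """In-batch negatives loss (sketch): mean of -log softmax(scores)[i]
    over the questions, where passages[i] is the positive of questions[i]."""
    losses = []
    for i, q in enumerate(questions):
        scores = [dot(q, p) for p in passages]  # row i of the similarity matrix
        # log-sum-exp with the max subtracted for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_z - scores[i])  # -log p(positive | question)
    return sum(losses) / len(losses)  # mean over the batch
```

When each question scores its own passage far above the others the loss approaches 0; when all scores tie, it equals log(batch size).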
References
- Parameters:
*args (additional arguments are passed to Trainee) –
**kwargs (additional arguments are passed to Trainee) –
question_class (str) – Name of the class used for question_model. See get_class_from_name.
question_model_name_or_path (str) – Passed to from_pretrained. See transformers.PreTrainedModel.from_pretrained.
context_class (str, optional) – Analogous to question_class, for context_model. Defaults to question_class. If ‘shared’, the same model is used to encode both questions and passages, and shared_encoders is set to True.
context_model_name_or_path (str, optional) – Analogous to question_model_name_or_path, for context_model. Defaults to question_model_name_or_path.
question_kwargs (dict, optional) –
context_kwargs (dict, optional) –
superclass (bool, optional) – Indicates that BiEncoder is being instantiated from a subclass, in which case post_init is not called. Defaults to False.
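The defaulting rules for context_class and context_model_name_or_path can be summarized in a small sketch. The helper resolve_context_config is hypothetical and only mirrors the documented behaviour; returning None for both context fields when encoders are shared is an illustrative choice, not necessarily what BiEncoder stores.

```python
from typing import Optional, Tuple


def resolve_context_config(
    question_class: str,
    question_model_name_or_path: str,
    context_class: Optional[str] = None,
    context_model_name_or_path: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str], bool]:
    """Return (context_class, context_model_name_or_path, shared_encoders)."""
    if context_class == "shared":
        # One model encodes both questions and passages: no separate context model
        return None, None, True
    if context_class is None:
        context_class = question_class
    if context_model_name_or_path is None:
        context_model_name_or_path = question_model_name_or_path
    return context_class, context_model_name_or_path, False
```

For instance, passing only question_class="BertModel" and question_model_name_or_path="bert-base-uncased" makes the context encoder a separate copy of the same architecture and checkpoint.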
- forward(question_inputs, context_inputs)
- Parameters:
question_inputs (dict) – passed to the respective encoder
context_inputs (dict) – passed to the respective encoder
- step(inputs, _)
Computes the in-batch negatives loss, and supports running it in DDP mode by exchanging the representations across all nodes.
Adapted from facebookresearch/DPR and Lightning-AI/lightning#14390
Notes
This means that the full question and context representations, as well as their similarity matrix, must fit on a single GPU.
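The effect of exchanging representations in DDP mode can be sketched by simulating ranks in pure Python. In a real run the gathering step would be a distributed all-gather collective; here simulate_gather is a hypothetical stand-in that concatenates the local batches in rank order, and equal local batch sizes with one positive passage per question are assumed.

```python
import math
from typing import List

Vector = List[float]


def simulate_gather(per_rank: List[List[Vector]]) -> List[Vector]:
    """Stand-in for an all-gather: concatenate every rank's local batch in
    rank order, so each rank ends up with the same global batch."""
    return [vec for local_batch in per_rank for vec in local_batch]


def global_in_batch_nll(q_per_rank: List[List[Vector]], p_per_rank: List[List[Vector]]) -> float:
    questions = simulate_gather(q_per_rank)
    passages = simulate_gather(p_per_rank)
    total = 0.0
    # The full len(questions) x len(passages) similarity matrix is built here,
    # which is why it must fit on a single device.
    for i, q in enumerate(questions):
        scores = [sum(a * b for a, b in zip(q, p)) for p in passages]
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        # After gathering, the positive of question i is passage i, i.e. its
        # global index rank * local_batch_size + local_index
        total += log_z - scores[i]
    return total / len(questions)
```

With one question per rank and two ranks, each question is scored against two passages instead of one, so gathering doubles the number of negatives available per question.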
- class meerqat.train.trainee.JointBiEncoderAndClip(*args, clip, question_weight=0.3333333333333333, image_weight=0.3333333333333333, cm_weight=0.3333333333333333, learn_weights=False, clip_lr=None, mm_weights_lr=None, **kwargs)
Bases: BiEncoder
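The parameter names suggest, although this page does not say so, that question_weight, image_weight and cm_weight weight three separate losses that are summed into the training loss. The sketch below assumes fixed weights (learn_weights=False) and a plain weighted sum; combine_losses and the loss values are hypothetical.

```python
from typing import Dict


def combine_losses(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Hypothetical weighted sum of per-objective losses (fixed weights)."""
    mismatch = set(losses) ^ set(weights)
    if mismatch:
        raise ValueError(f"losses and weights disagree on keys: {sorted(mismatch)}")
    return sum(weights[name] * losses[name] for name in losses)


# Default weights taken from the signature (1/3 each)
weights = {"question": 1 / 3, "image": 1 / 3, "cm": 1 / 3}
total = combine_losses({"question": 0.9, "image": 1.2, "cm": 0.6}, weights)
```

With the default equal weights, the combined loss is simply the mean of the three losses.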
- forward(question_inputs, context_inputs)
- Parameters:
question_inputs (dict) – passed to the respective encoder
context_inputs (dict) – passed to the respective encoder
- step(inputs, _)
Computes the in-batch negatives loss, and supports running it in DDP mode by exchanging the representations across all nodes.
Adapted from facebookresearch/DPR and Lightning-AI/lightning#14390
Notes
This means that the full question and context representations, as well as their similarity matrix, must fit on a single GPU.
- class meerqat.train.trainee.ReRanker(*args, model_kwargs, metric_kwargs={}, **kwargs)
Bases: Trainee
- Parameters:
model_kwargs (dict[str, str]) – Passed to get_pretrained
metric_kwargs (dict[str, str], optional) – Passed to ranx.evaluate to compute metrics during evaluation.
- class meerqat.train.trainee.Reader(*args, model_kwargs, tune_M=False, **kwargs)
Bases: Trainee
- Parameters:
model_kwargs (dict[str, str]) – Passed to get_pretrained
tune_M (bool, optional) – Instead of extracting answers only from the top-M input passages, try every value in {2^i : 2^i ≤ M}. Defaults to False (use only self.trainer.datamodule.M).
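The candidate values tried when tune_M=True can be enumerated with a short sketch. Whether i starts at 0 (so that M=1 is always included) is an assumption, and candidate_Ms is a hypothetical name.

```python
from typing import List


def candidate_Ms(M: int) -> List[int]:
    """All powers of two 2**i (i >= 0) such that 2**i <= M."""
    values = []
    value = 1  # 2**0
    while value <= M:
        values.append(value)
        value *= 2
    return values
```

For M=24 this yields [1, 2, 4, 8, 16]; M itself is only included when it is a power of two.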
- forward(*args, **kwargs)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- log_probs_to_answers(start_log_probs, end_log_probs, input_ids, **kwargs)
1. Get the span start and end positions from the log-probabilities.
2. Extract the actual tokens (the answer) from input_ids.
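The two steps can be sketched in pure Python for a single sequence. This sketch assumes the usual extractive-QA decoding rule: it maximizes start_log_probs[start] + end_log_probs[end] subject to start ≤ end and a maximum answer length. The names best_span, extract_answer and max_answer_len are hypothetical, and unlike a real implementation the sketch does not mask question or padding tokens.

```python
from typing import List, Tuple


def best_span(start_log_probs: List[float], end_log_probs: List[float],
              max_answer_len: int = 30) -> Tuple[int, int, float]:
    """Step 1: return (start, end, score) maximizing the summed log-probabilities,
    subject to start <= end < start + max_answer_len."""
    best = (0, 0, float("-inf"))
    for s, s_lp in enumerate(start_log_probs):
        for e in range(s, min(s + max_answer_len, len(end_log_probs))):
            score = s_lp + end_log_probs[e]
            if score > best[2]:
                best = (s, e, score)
    return best


def extract_answer(input_ids: List[int], start: int, end: int) -> List[int]:
    """Step 2: slice the answer tokens; the end position is inclusive."""
    return input_ids[start:end + 1]
```

In practice the extracted ids would then be decoded back to text with the model's tokenizer.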