meerqat.train.trainee module

A Trainee is a pl.LightningModule that computes its own loss, which makes it compatible with the Trainer.

class meerqat.train.trainee.Trainee(*args, freeze_regex=None, gradient_checkpointing=False, warmup_steps=0, lr=2e-05, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, output_cpu=False, **kwargs)

Bases: LightningModule

Base class for all Trainee models (to be trained by a Trainer)

  • *args (additional arguments are passed to pl.LightningModule) –

  • **kwargs (additional arguments are passed to pl.LightningModule) –

  • freeze_regex (str, optional) – A regex used to select the model parameters to freeze (i.e. set requires_grad = False). Defaults to None (the model stays fully trainable).

  • gradient_checkpointing (bool, optional) –

  • lr (float, optional) –

  • eps (float, optional) –

  • weight_decay (float, optional) –

  • betas (Tuple[float], optional) –

  • warmup_steps (int, optional) – Defaults to no warm-up

  • output_cpu (bool, optional) –

step(batch, batch_idx)
eval_step(batch, batch_idx)
log(name, value, **kwargs)

Ignores None values.
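A minimal sketch of a log override that drops None values, assuming the parent log accepts (name, value, **kwargs) as LightningModule.log does. FakeModule and NoneSafeModule are hypothetical stand-ins so the snippet runs without Lightning installed.

```python
class FakeModule:
    """Hypothetical stand-in for pl.LightningModule that just records logged values."""

    def __init__(self):
        self.logged = {}

    def log(self, name, value, **kwargs):
        self.logged[name] = value


class NoneSafeModule(FakeModule):
    def log(self, name, value, **kwargs):
        # Skip metrics that were not computed for this step (value is None)
        if value is None:
            return
        super().log(name, value, **kwargs)


m = NoneSafeModule()
m.log("train/loss", 0.42)
m.log("train/mrr", None)  # silently ignored
print(m.logged)  # {'train/loss': 0.42}
```

This lets a step return optional metrics without each caller checking for None before logging.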

training_step(batch, batch_idx)

Runs a training step and logs the training metrics.

validation_step(batch, batch_idx)

Runs a validation step and logs the validation metrics.

test_step(batch, batch_idx)

Runs a test step and logs the test metrics.

validation_epoch_end(*args, **kwargs)

Calls eval_epoch_end and logs the results.

test_epoch_end(*args, **kwargs)

Calls eval_epoch_end and logs the results.


Overrides freeze to freeze only the parameters that match freeze_regex. Caveat: it does not call .eval(), so Dropout stays enabled.
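The regex-based freezing described above can be sketched as follows. Param is a hypothetical stand-in for torch.nn.Parameter, the parameter names are illustrative BERT-style names, and using re.match (anchored at the start of the name) rather than re.search is an assumption, not necessarily what Trainee does.

```python
import re
from dataclasses import dataclass


@dataclass
class Param:
    """Hypothetical stand-in for torch.nn.Parameter; only requires_grad matters here."""
    requires_grad: bool = True


def freeze_matching(named_parameters, freeze_regex):
    """Set requires_grad=False on each parameter whose name matches freeze_regex.

    named_parameters: iterable of (name, parameter) pairs, like Module.named_parameters().
    Returns the names of the frozen parameters.
    """
    pattern = re.compile(freeze_regex)
    frozen = []
    for name, param in named_parameters:
        # Assumption: re.match anchors at the start of the name; re.search would match anywhere
        if pattern.match(name):
            param.requires_grad = False
            frozen.append(name)
    # Nothing here switches the model to eval mode, so Dropout stays active
    return frozen


params = {
    "embeddings.word_embeddings.weight": Param(),
    "encoder.layer.0.attention.self.query.weight": Param(),
    "encoder.layer.11.output.dense.weight": Param(),
}
frozen = freeze_matching(params.items(), r"embeddings|encoder\.layer\.0\.")
print(frozen)
# ['embeddings.word_embeddings.weight', 'encoder.layer.0.attention.self.query.weight']
```

Note the escaped dots in the pattern: without the trailing "\.", "encoder.layer.0" would not match layer 11, but "encoder.layer.1" would also match layers 10 to 19.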


Prepares the optimizer and the learning-rate schedule (linear warm-up and decay).
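A sketch of a linear warm-up and decay multiplier, following the formula used by transformers' get_linear_schedule_with_warmup. Whether Trainee relies on that exact helper, and how total_steps is obtained, are assumptions; linear_warmup_decay is a hypothetical name.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """LR multiplier: ramps 0 -> 1 over warmup_steps, then decays linearly to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 2e-5  # Trainee's default lr
for step in (0, 50, 100, 550, 1000):
    # Effective learning rate at this optimizer step
    print(step, base_lr * linear_warmup_decay(step, warmup_steps=100, total_steps=1000))
```

With the default warmup_steps=0, the multiplier starts at 1.0 and only the linear decay applies.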


Activates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.


Deactivates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

property is_gradient_checkpointing: bool

Whether gradient checkpointing is activated for this model or not. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

class meerqat.train.trainee.CrossModal(*args, model_kwargs: dict, **kwargs)

Bases: Trainee

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Returns: Your model’s output

step(batch, batch_idx)
eval_step(inputs, batch_idx)
save_pretrained(ckpt_path, bert=False)
class meerqat.train.trainee.JointMonoAndCrossModal(*args, model_kwargs: dict, image_weight=0.5, cm_weight=0.5, learn_weights=False, mm_weights_lr=None, **kwargs)

Bases: Trainee

forward(input_ids: Optional[LongTensor] = None, pixel_values: Optional[FloatTensor] = None, paired_pixel_values: Optional[FloatTensor] = None, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None)

Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Returns: Your model’s output

step(inputs, _)
eval_step(inputs, batch_idx)
save_pretrained(ckpt_path, bert=False)
class meerqat.train.trainee.BiEncoder(*args, question_class, question_model_name_or_path, context_class=None, context_model_name_or_path=None, question_kwargs={}, context_kwargs=None, superclass=False, **kwargs)

Bases: Trainee

The training objective is to minimize the negative log-likelihood of the relevant passage, where passages are scored by the similarity (dot product) between the question and passage embeddings, as described in [3].
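A pure-Python sketch of this objective, assuming each question has exactly one relevant passage whose index is given by positives, and that every other passage in the batch serves as a negative. in_batch_nll and dot are hypothetical helpers; the real step works on batched tensors.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def in_batch_nll(questions, passages, positives):
    """Mean negative log-likelihood of each question's relevant passage.

    questions: N question embeddings; passages: P passage embeddings shared by the batch;
    positives[i]: index in `passages` of the relevant passage for question i.
    Every other passage in the batch acts as a negative for question i.
    """
    total = 0.0
    for q, pos in zip(questions, positives):
        scores = [dot(q, p) for p in passages]  # one row of the similarity matrix
        # log-softmax over the row, shifted by the max for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[pos]
    return total / len(questions)


questions = [[1.0, 0.0], [0.0, 1.0]]
passages = [[1.0, 0.0], [0.0, 1.0]]
loss = in_batch_nll(questions, passages, positives=[0, 1])
print(round(loss, 4))  # 0.3133
```

The loss shrinks as each question's dot product with its relevant passage grows relative to its dot products with the other passages in the batch.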


  • *args (additional arguments are passed to Trainee) –

  • **kwargs (additional arguments are passed to Trainee) –

  • question_class (str) – Name of the class used for question_model. See get_class_from_name.

  • question_model_name_or_path (str) – Passed to from_pretrained. See transformers.PreTrainedModel.from_pretrained

  • context_class (str, optional) – Analogous to question_class, for context_model. Defaults to question_class. If ‘shared’, the same model encodes both questions and passages, and shared_encoders is set to True.

  • context_model_name_or_path (str, optional) – Analogous to question_model_name_or_path, for context_model. Defaults to question_model_name_or_path.

  • question_kwargs (dict, optional) –

  • context_kwargs (dict, optional) –

  • superclass (bool, optional) – Set to True when BiEncoder is instantiated from a subclass; this disables post_init. Defaults to False.

forward(question_inputs, context_inputs)
  • question_inputs (dict) – passed to the respective encoder

  • context_inputs (dict) – passed to the respective encoder

step(inputs, _)

Computes the in-batch negatives loss. Supports DDP by exchanging the representations across all the nodes.

Adapted from facebookresearch/DPR and Lightning-AI/lightning#14390


Because the representations from all nodes are gathered, the whole representations of questions and contexts, and their similarity matrix, must fit on a single GPU.
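A single-process simulation of why the full similarity matrix must fit on one GPU: after gathering, every rank holds all questions and contexts and scores them against each other. simulated_all_gather is a hypothetical stand-in for a real all_gather (e.g. torch.distributed.all_gather); it ignores gradient flow through the gather, and the DPR-style pairing (context i is the positive of question i) is an assumption.

```python
import math


def simulated_all_gather(per_rank):
    """Stand-in for an all_gather: every rank receives all ranks' local items, in rank order."""
    everything = [item for local in per_rank for item in local]
    return [list(everything) for _ in per_rank]


def global_in_batch_nll(q_global, c_global):
    """NLL where context i is the positive of question i and all other gathered contexts are negatives."""
    total = 0.0
    for i, q in enumerate(q_global):
        scores = [sum(a * b for a, b in zip(q, c)) for c in c_global]
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[i]
    return total / len(q_global)


# Two hypothetical ranks, each holding a local batch of 2 (question, context) embedding pairs
local_q = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [-1.0, 0.0]]]
local_c = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [-1.0, 0.0]]]

q_per_rank = simulated_all_gather(local_q)
c_per_rank = simulated_all_gather(local_c)

losses = []
for rank in range(len(local_q)):
    n = len(q_per_rank[rank])
    # Each rank materializes the full n x n similarity matrix (n = world_size * local batch size)
    losses.append(global_in_batch_nll(q_per_rank[rank], c_per_rank[rank]))
    print(f"rank {rank}: {n}x{n} similarity matrix, loss={losses[-1]:.4f}")
```

Gathering gives each question more negatives than its local batch alone, at the cost of a similarity matrix that grows quadratically with the number of ranks.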

eval_step(inputs, batch_idx)
save_pretrained(ckpt_path, bert=False)
class meerqat.train.trainee.JointBiEncoderAndClip(*args, clip, question_weight=0.3333333333333333, image_weight=0.3333333333333333, cm_weight=0.3333333333333333, learn_weights=False, clip_lr=None, mm_weights_lr=None, **kwargs)

Bases: BiEncoder

forward(question_inputs, context_inputs)
  • question_inputs (dict) – passed to the respective encoder

  • context_inputs (dict) – passed to the respective encoder

step(inputs, _)

Computes the in-batch negatives loss. Supports DDP by exchanging the representations across all the nodes.

Adapted from facebookresearch/DPR and Lightning-AI/lightning#14390


Because the representations from all nodes are gathered, the whole representations of questions and contexts, and their similarity matrix, must fit on a single GPU.

eval_step(inputs, batch_idx)
save_pretrained(ckpt_path, bert=False)
class meerqat.train.trainee.ReRanker(*args, model_kwargs, metric_kwargs={}, **kwargs)

Bases: Trainee

  • model_kwargs (dict[str, str]) – Passed to get_pretrained

  • metric_kwargs (dict[str, str], optional) – Passed to ranx.evaluate to compute metrics during evaluation

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Returns: Your model’s output

step(inputs, _)
test_epoch_end(*args, **kwargs)

Calls eval_epoch_end, logs the results, and saves the run.

save_pretrained(ckpt_path, bert=False)
class meerqat.train.trainee.Reader(*args, model_kwargs, tune_M=False, **kwargs)

Bases: Trainee

  • model_kwargs (dict[str, str]) – Passed to get_pretrained

  • tune_M (bool, optional) – Instead of extracting answers from the top-M input passages only, try every value in {2^i, s.t. 2^i <= M}. Defaults to False (use only self.trainer.datamodule.M).
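The set of candidate values {2^i, s.t. 2^i <= M} can be enumerated as below, assuming i starts at 0. candidate_Ms is a hypothetical helper; note that M itself is only tried when it is a power of two.

```python
def candidate_Ms(M):
    """All powers of two 2**i (i >= 0) with 2**i <= M, in increasing order."""
    values = []
    k = 1
    while k <= M:
        values.append(k)
        k *= 2
    return values


print(candidate_Ms(24))  # [1, 2, 4, 8, 16]
```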

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Returns: Your model’s output

step(inputs, _)
eval_step(inputs, batch_idx)
log_probs_to_answers(start_log_probs, end_log_probs, input_ids, **kwargs)

1. Get the span start and end positions from the log-probabilities.
2. Extract the actual answer tokens from input_ids.
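A sketch of the two steps for a single passage, assuming the span maximizes the sum of start and end log-probabilities subject to start <= end. The max_answer_len limit and both helper names are hypothetical; the actual method works on batches and may apply other constraints (e.g. restricting spans to passage tokens).

```python
def best_span(start_log_probs, end_log_probs, max_answer_len=30):
    """Return the (start, end) pair maximizing start_log_probs[start] + end_log_probs[end],
    subject to start <= end < start + max_answer_len."""
    best, best_score = (0, 0), float("-inf")
    for s, s_lp in enumerate(start_log_probs):
        for e in range(s, min(s + max_answer_len, len(end_log_probs))):
            score = s_lp + end_log_probs[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


def extract_answer(input_ids, start_log_probs, end_log_probs, max_answer_len=30):
    """Step 1: pick the best span; step 2: slice the answer tokens out of input_ids."""
    s, e = best_span(start_log_probs, end_log_probs, max_answer_len)
    return input_ids[s : e + 1]  # end position is inclusive


# The end argmax (position 1) precedes the start argmax (position 2), so the
# best valid span is (2, 3) rather than the two independent argmaxes
input_ids = [101, 7592, 2088, 999, 102]
start_log_probs = [-9.0, -9.0, -0.5, -9.0, -9.0]
end_log_probs = [-9.0, -0.1, -9.0, -1.0, -9.0]
print(extract_answer(input_ids, start_log_probs, end_log_probs))  # [2088, 999]
```

Searching over pairs instead of taking two independent argmaxes guarantees a well-formed span; the quadratic loop is bounded by max_answer_len.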

test_epoch_end(*args, **kwargs)

Calls eval_epoch_end and logs the results.

M_tuning(all_start_log_probs, all_end_log_probs, all_input_ids, all_answer_strings, all_passage_scores=None)
save_pretrained(ckpt_path, bert=False)