meerqat.train.trainee module#

A Trainee is a pl.LightningModule that computes the loss, so it is compatible with Trainer.

meerqat.train.trainee.batched_cpu(batch)[source]#
class meerqat.train.trainee.Trainee(*args, freeze_regex=None, gradient_checkpointing=False, warmup_steps=0, lr=2e-05, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, output_cpu=False, **kwargs)[source]#

Bases: LightningModule

Base class for all Trainee models (to be trained by a Trainer)

Parameters:
  • *args (additional arguments are passed to pl.LightningModule) –

  • **kwargs (additional arguments are passed to pl.LightningModule) –

  • freeze_regex (str, optional) – Regex used to match the model parameters to freeze (i.e. set requires_grad = False). Defaults to None (keep the model fully trainable)

  • gradient_checkpointing (bool, optional) – Whether to enable gradient checkpointing (see gradient_checkpointing_enable). Defaults to False

  • lr (float, optional) – Learning rate. Defaults to 2e-05

  • eps (float, optional) – Optimizer epsilon. Defaults to 1e-08

  • weight_decay (float, optional) – Weight decay coefficient. Defaults to 0.0

  • betas (Tuple[float], optional) – Optimizer betas. Defaults to (0.9, 0.999)

  • warmup_steps (int, optional) – Number of linear warm-up steps for the learning-rate schedule. Defaults to 0 (no warm-up)

  • output_cpu (bool, optional) – Whether to move outputs to CPU. Defaults to False
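For orientation, here is a minimal sketch of a hypothetical subclass. ToyTrainee, its model, and the assumption that step() returns a dict of metrics are illustrative only; the constructor arguments are those documented above.

```python
import torch
from torch import nn
from meerqat.train.trainee import Trainee


class ToyTrainee(Trainee):
    """Hypothetical subclass, only to illustrate the constructor arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A real subclass would build its model here (e.g. via transformers).
        self.model = nn.Linear(768, 2)

    def step(self, batch, batch_idx):
        logits = self.model(batch["features"])
        loss = nn.functional.cross_entropy(logits, batch["labels"])
        # Assumption: step() returns a dict of metrics that
        # training_step/validation_step then log.
        return {"loss": loss}


trainee = ToyTrainee(
    freeze_regex=None,   # keep the model fully trainable
    warmup_steps=1000,   # linear warm-up over the first 1000 optimizer steps
    lr=2e-5,
    weight_decay=0.01,
)
```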

post_init()[source]#
step(batch, batch_idx)[source]#
eval_step(batch, batch_idx)[source]#
log(name, value, **kwargs)[source]#

Ignores None values.

training_step(batch, batch_idx)[source]#

Step and log training metrics

validation_step(batch, batch_idx)[source]#

Step and log validation metrics

test_step(batch, batch_idx)[source]#

Step and log test metrics

eval_epoch_end(eval_outputs)[source]#
validation_epoch_end(*args, **kwargs)[source]#

eval_epoch_end and log

test_epoch_end(*args, **kwargs)[source]#

eval_epoch_end and log

freeze(regex)[source]#

Overrides freeze to freeze only the parameters that match the regex. Caveat: does not call .eval(), so it does not disable Dropout.
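The freezing logic amounts to matching parameter names against the regex; a minimal sketch of this pattern (not the library's exact code):

```python
import re

def freeze_matching(module, regex):
    """Sketch of regex-based freezing: not the library's exact implementation."""
    for name, param in module.named_parameters():
        if re.search(regex, name):
            # matched parameters are excluded from gradient computation
            param.requires_grad = False

# e.g. freeze every parameter of a (hypothetical) question encoder's embeddings:
# freeze_matching(trainee, r"question_model\.embeddings")
```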

configure_optimizers()[source]#

Prepare optimizer and schedule (linear warmup and decay)
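The documented hyperparameters (lr, betas, eps, weight_decay, warmup_steps) match an Adam-style optimizer with a linear schedule. The sketch below shows one common way to wire this up with torch's AdamW and transformers' get_linear_schedule_with_warmup; treating these as the concrete implementation is an assumption.

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

def configure_optimizers_sketch(module, num_training_steps):
    # Only parameters left trainable after freeze() are optimized.
    params = [p for p in module.parameters() if p.requires_grad]
    optimizer = AdamW(params, lr=2e-5, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=0.0)   # the documented defaults
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,                          # warmup_steps
        num_training_steps=num_training_steps,
    )
    # Lightning convention: step the scheduler every optimizer step
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```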

gradient_checkpointing_enable()[source]#

Activates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

gradient_checkpointing_disable()[source]#

Deactivates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

property is_gradient_checkpointing: bool#

Whether gradient checkpointing is activated for this model or not. Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

class meerqat.train.trainee.CrossModal(*args, model_kwargs: dict, **kwargs)[source]#

Bases: Trainee

forward(*args, **kwargs)[source]#

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

step(batch, batch_idx)[source]#
eval_step(inputs, batch_idx)[source]#
eval_epoch_end(eval_outputs)[source]#
save_pretrained(ckpt_path, bert=False)[source]#
class meerqat.train.trainee.JointMonoAndCrossModal(*args, model_kwargs: dict, image_weight=0.5, cm_weight=0.5, learn_weights=False, mm_weights_lr=None, **kwargs)[source]#

Bases: Trainee

forward(input_ids: Optional[LongTensor] = None, pixel_values: Optional[FloatTensor] = None, paired_pixel_values: Optional[FloatTensor] = None, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None)[source]#

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

step(inputs, _)[source]#
eval_step(inputs, batch_idx)[source]#
eval_epoch_end(eval_outputs)[source]#
save_pretrained(ckpt_path, bert=False)[source]#
class meerqat.train.trainee.BiEncoder(*args, question_class, question_model_name_or_path, context_class=None, context_model_name_or_path=None, question_kwargs={}, context_kwargs=None, superclass=False, **kwargs)[source]#

Bases: Trainee

The training objective is to minimize the negative log-likelihood of the similarities (dot product) between the question and passage embeddings, as described in [3].

References

Parameters:
  • *args (additional arguments are passed to Trainee) –

  • **kwargs (additional arguments are passed to Trainee) –

  • question_class (str) – Name of the class used for question_model. See get_class_from_name.

  • question_model_name_or_path (str) – Passed to from_pretrained. See transformers.PreTrainedModel.from_pretrained

  • context_class (str, optional) – Analogous to question_class for context_model. Defaults to question_class. If ‘shared’, then use the same model to encode questions and passages. Will set shared_encoders=True

  • context_model_name_or_path (str) – Analogous to question_model_name_or_path for context_model. Defaults to question_model_name_or_path.

  • question_kwargs (dict, optional) –

  • context_kwargs (dict, optional) –

  • superclass (bool, optional) – Means that BiEncoder is instantiated from a subclass. Disables post_init. Defaults to False.
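For illustration only, a BiEncoder could be configured with DPR-style encoders from transformers; whether these exact class names are resolvable by get_class_from_name, and whether these checkpoints suit your setup, are assumptions.

```python
from meerqat.train.trainee import BiEncoder

# Separate DPR-style question and context encoders (illustrative names only):
biencoder = BiEncoder(
    question_class="DPRQuestionEncoder",
    question_model_name_or_path="facebook/dpr-question_encoder-single-nq-base",
    context_class="DPRContextEncoder",
    context_model_name_or_path="facebook/dpr-ctx_encoder-single-nq-base",
)

# Single shared encoder for questions and passages:
shared_biencoder = BiEncoder(
    question_class="DPRQuestionEncoder",
    question_model_name_or_path="facebook/dpr-question_encoder-single-nq-base",
    context_class="shared",
)
```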

forward(question_inputs, context_inputs)[source]#
Parameters:
  • question_inputs (dict) – passed to the respective encoder

  • context_inputs (dict) – passed to the respective encoder

step(inputs, _)[source]#

Calculates the in-batch negatives loss and supports running in DDP mode by exchanging the representations across all nodes.

Adapted from facebookresearch/DPR and Lightning-AI/lightning#14390

Notes

This means that all of the question and context representations, and their similarity matrix, must fit on a single GPU.
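A minimal single-GPU sketch of this objective (no cross-node gathering, one positive passage per question):

```python
import torch
import torch.nn.functional as F

def in_batch_negatives_loss(question_embeddings, context_embeddings):
    """Sketch of the in-batch negatives objective.

    question_embeddings: (N, d); context_embeddings: (N, d), where the i-th
    context is the positive passage for the i-th question.
    """
    similarities = question_embeddings @ context_embeddings.T  # (N, N) dot products
    labels = torch.arange(similarities.size(0), device=similarities.device)
    # negative log-likelihood of each question's positive passage
    # against the other passages in the batch (in-batch negatives)
    return F.cross_entropy(similarities, labels)
```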

eval_step(inputs, batch_idx)[source]#
eval_epoch_end(eval_outputs)[source]#
save_pretrained(ckpt_path, bert=False)[source]#
class meerqat.train.trainee.JointBiEncoderAndClip(*args, clip, question_weight=0.3333333333333333, image_weight=0.3333333333333333, cm_weight=0.3333333333333333, learn_weights=False, clip_lr=None, mm_weights_lr=None, **kwargs)[source]#

Bases: BiEncoder

forward(question_inputs, context_inputs)[source]#
Parameters:
  • question_inputs (dict) – passed to the respective encoder

  • context_inputs (dict) – passed to the respective encoder

step(inputs, _)[source]#

Calculates the in-batch negatives loss and supports running in DDP mode by exchanging the representations across all nodes.

Adapted from facebookresearch/DPR and Lightning-AI/lightning#14390

Notes

This means that all of the question and context representations, and their similarity matrix, must fit on a single GPU.

eval_step(inputs, batch_idx)[source]#
eval_epoch_end(eval_outputs)[source]#
save_pretrained(ckpt_path, bert=False)[source]#
class meerqat.train.trainee.ReRanker(*args, model_kwargs, metric_kwargs={}, **kwargs)[source]#

Bases: Trainee

Parameters:
  • model_kwargs (dict[str, str]) – Passed to get_pretrained

  • metric_kwargs (dict[str, str], optional) – Passed to ranx.evaluate to compute metrics during evaluation
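As an illustration, metric_kwargs might look like the following; the metric names are examples and the exact keys expected here are an assumption.

```python
# Hypothetical configuration; ranx metric names support cutoffs such as "@10".
metric_kwargs = {"metrics": ["mrr@10", "precision@1", "recall@100"]}
```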

forward(*args, **kwargs)[source]#

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

step(inputs, _)[source]#
eval_epoch_end(eval_outputs)[source]#
test_epoch_end(*args, **kwargs)[source]#

eval_epoch_end, log and save run

save_pretrained(ckpt_path, bert=False)[source]#
meerqat.train.trainee.power_range(maximum)[source]#
class meerqat.train.trainee.Reader(*args, model_kwargs, tune_M=False, **kwargs)[source]#

Bases: Trainee

Parameters:
  • model_kwargs (dict[str, str]) – Passed to get_pretrained

  • tune_M (bool, optional) – Instead of extracting answers from the top-M input passages, try every value in {2^i : 2^i <= M}. Defaults to False (use only self.trainer.datamodule.M)
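A sketch of the assumed behaviour of power_range and of the values tried when tuning M:

```python
def power_range_sketch(maximum):
    """Assumed behaviour of power_range: yield 2**i for every 2**i <= maximum."""
    i = 1
    while i <= maximum:
        yield i
        i *= 2

# With M = 24 passages, tune_M would try extracting answers from the
# top-1, top-2, top-4, top-8 and top-16 passages:
print(list(power_range_sketch(24)))  # [1, 2, 4, 8, 16]
```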

forward(*args, **kwargs)[source]#

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

step(inputs, _)[source]#
eval_step(inputs, batch_idx)[source]#
log_probs_to_answers(start_log_probs, end_log_probs, input_ids, **kwargs)[source]#

1. Get span start and end positions from the log-probabilities.

2. Extract the actual answer tokens from input_ids.
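A simplified sketch of these two steps, assuming a tokenizer is available for decoding; the actual method may search for the best valid span rather than taking plain argmaxes.

```python
import torch

def extract_answers(start_log_probs, end_log_probs, input_ids, tokenizer):
    """Simplified sketch: take the argmax positions, then decode the span."""
    # 1. get span start and end positions from the log-probabilities
    starts = start_log_probs.argmax(dim=-1)   # (batch_size,)
    ends = end_log_probs.argmax(dim=-1)       # (batch_size,)
    # 2. extract the actual answer tokens from input_ids
    answers = []
    for i in range(input_ids.size(0)):
        s, e = int(starts[i]), int(ends[i])
        span = input_ids[i, s: e + 1]
        answers.append(tokenizer.decode(span, skip_special_tokens=True))
    return answers
```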

eval_epoch_end(eval_outputs)[source]#
test_epoch_end(*args, **kwargs)[source]#

eval_epoch_end and log

M_tuning(all_start_log_probs, all_end_log_probs, all_input_ids, all_answer_strings, all_passage_scores=None)[source]#
save_pretrained(ckpt_path, bert=False)[source]#