meerqat.train.metrics module

Metrics to be used in trainer.

meerqat.train.metrics.accumulate_batch_metrics(batch_metrics)
meerqat.train.metrics.batch_retrieval(log_probs, labels, ignore_index=-100)
meerqat.train.metrics.retrieval(eval_outputs, ignore_index=-100, output_key='log_probs')

Computes metrics for retrieval training (at the batch level).

Parameters:
  • eval_outputs (List[dict[str, Tensor]]) – Contains log_probs and labels for all batches in the evaluation step (either validation or test)

  • ignore_index (int, optional) – Labels with this value are not taken into account when computing metrics. Defaults to -100

  • output_key (str, optional) – Name of the model output in eval_outputs. Defaults to 'log_probs'
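
As an illustration of how the parameters above fit together, here is a minimal sketch in plain Python lists rather than tensors. It is not the module's implementation: the choice of hits@1 and MRR as the reported metrics, the result key names, and the `sketch_` function names are all assumptions made for this example.

```python
from typing import Dict, List, Sequence


def sketch_batch_retrieval(
    log_probs: Sequence[Sequence[float]],
    labels: Sequence[int],
    ignore_index: int = -100,
) -> Dict[str, float]:
    """Sum hits@1 and reciprocal ranks over one batch (hypothetical helper)."""
    hits_at_1, reciprocal_ranks, total = 0, 0.0, 0
    for scores, label in zip(log_probs, labels):
        if label == ignore_index:
            # Ignored labels do not contribute to any metric.
            continue
        # Candidate indices sorted by descending score.
        ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        rank = ranking.index(label) + 1  # ranks count from 1
        hits_at_1 += int(rank == 1)
        reciprocal_ranks += 1.0 / rank
        total += 1
    return {"hits@1": hits_at_1, "MRR": reciprocal_ranks, "total": total}


def sketch_retrieval(
    eval_outputs: List[dict],
    ignore_index: int = -100,
    output_key: str = "log_probs",
) -> Dict[str, float]:
    """Accumulate per-batch sums, then average over the counted labels."""
    sums = {"hits@1": 0.0, "MRR": 0.0, "total": 0}
    for batch in eval_outputs:
        batch_metrics = sketch_batch_retrieval(
            batch[output_key], batch["labels"], ignore_index
        )
        for name, value in batch_metrics.items():
            sums[name] += value
    n = max(sums["total"], 1)  # avoid dividing by zero when every label is ignored
    return {"hits@1": sums["hits@1"] / n, "MRR": sums["MRR"] / n}


# Two toy batches; the first label of the second batch is ignored.
eval_outputs = [
    {"log_probs": [[-0.1, -2.0, -3.0], [-1.5, -0.3, -2.0]], "labels": [0, 2]},
    {"log_probs": [[-0.5, -0.9, -3.0], [-2.0, -0.2, -1.0]], "labels": [-100, 2]},
]
metrics = sketch_retrieval(eval_outputs)
```

Passing a different `output_key` would let the same loop read another model output from each batch dictionary, which appears to be the purpose of that parameter.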

meerqat.train.metrics.get_run(eval_outputs, ir_run)

Parameters:
  • eval_outputs (List[dict[str, Tensor]]) – Contains logits for all batches in the evaluation step (either validation or test)

  • ir_run (ranx.Run) – Original IR run being re-ranked.

meerqat.train.metrics.f1_score(prediction, ground_truth)
meerqat.train.metrics.exact_match_score(prediction, ground_truth)
meerqat.train.metrics.metric_max_over_ground_truths(metric_fn, prediction, ground_truths)
meerqat.train.metrics.squad(predictions, references)

Adapted from datasets.load_metric('squad').

Parameters:
  • predictions (List[str]) –

  • references (List[List[str]]) –

Returns:

metrics

Return type:

dict[str, float]
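
For readers unfamiliar with SQuAD-style scoring, the following sketch follows the widely used SQuAD evaluation recipe: answers are normalized, exact match and token-level F1 are computed, and the best score over all reference answers is kept for each question. It is a standalone approximation, not this module's code. The `sketch_` names, the result keys, and the 0 to 1 scale are assumptions.

```python
import re
import string
from collections import Counter
from typing import Callable, Dict, List

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def sketch_exact_match(prediction: str, ground_truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def sketch_f1(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection counts each shared token at most min(count) times.
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def sketch_max_over_ground_truths(
    metric_fn: Callable[[str, str], float], prediction: str, ground_truths: List[str]
) -> float:
    """Score the prediction against every reference and keep the best."""
    return max(metric_fn(prediction, truth) for truth in ground_truths)


def sketch_squad(predictions: List[str], references: List[List[str]]) -> Dict[str, float]:
    """Average exact match and F1 over all questions (scale 0 to 1 assumed)."""
    em = [sketch_max_over_ground_truths(sketch_exact_match, p, r)
          for p, r in zip(predictions, references)]
    f1 = [sketch_max_over_ground_truths(sketch_f1, p, r)
          for p, r in zip(predictions, references)]
    return {"exact_match": sum(em) / len(em), "f1": sum(f1) / len(f1)}


# One prediction per question, one list of acceptable answers per question.
predictions = ["The Eiffel Tower", "in Paris"]
references = [["Eiffel Tower", "the tower"], ["Paris, France"]]
scores = sketch_squad(predictions, references)
```

In this toy input the first prediction matches a reference exactly after normalization, while the second shares one of two tokens with its only reference, so it contributes a partial F1 and no exact match.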

meerqat.train.metrics.squad_per_question(predictions, references)

Returns the score of each metric for every question instead of averaging like squad. The implementation is kept separate because squad should in principle be loaded from datasets. This should allow for statistical significance testing downstream.

Parameters:
  • predictions (List[str]) –

  • references (List[List[str]]) –

Returns:

metrics

Return type:

dict[str, List[float]]
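
To show why per-question scores are useful downstream, here is one possible significance test: a paired sign-flip permutation test comparing two systems on the same questions. The test itself, the function name, the `"f1"` key, and the score values are all illustrative assumptions; nothing here is part of the module.

```python
import random
from typing import List


def paired_permutation_test(
    scores_a: List[float],
    scores_b: List[float],
    n_resamples: int = 10_000,
    seed: int = 0,
) -> float:
    """Two-sided paired sign-flip permutation test on the mean score difference."""
    if len(scores_a) != len(scores_b):
        raise ValueError("both systems must be scored on the same questions")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    observed = abs(sum(diffs)) / len(diffs)
    rng = random.Random(seed)
    extreme = 0
    for _ in range(n_resamples):
        # Under the null hypothesis the sign of each paired difference is arbitrary.
        resampled = sum(d if rng.random() < 0.5 else -d for d in diffs)
        if abs(resampled) / len(diffs) >= observed:
            extreme += 1
    # Add-one correction keeps the estimated p-value strictly positive.
    return (extreme + 1) / (n_resamples + 1)


# Hypothetical per-question F1 scores, shaped like dict[str, List[float]].
system_a = {"f1": [1.0, 0.5, 0.0, 1.0, 0.8, 0.6, 1.0, 0.4]}
system_b = {"f1": [0.8, 0.5, 0.0, 0.6, 0.7, 0.2, 1.0, 0.0]}
p_value = paired_permutation_test(system_a["f1"], system_b["f1"])
```

A test like this needs one score per question from each system, which an averaged metric cannot provide.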