meerqat.train.metrics module
Metrics to be used in the trainer.
- meerqat.train.metrics.retrieval(eval_outputs, ignore_index=-100, output_key='log_probs')
Computes metrics for retrieval training (at the batch level).
- Parameters:
eval_outputs (List[dict[str, Tensor]]) – Contains log_probs and labels for all batches in the evaluation step (either validation or test)
ignore_index (int, optional) – Labels with this value are not taken into account when computing metrics. Defaults to -100.
output_key (str, optional) – Key of the model output in each dict of eval_outputs. Defaults to 'log_probs'.
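A minimal sketch of what batch-level retrieval metrics with ignore_index masking can look like, assuming each row of log_probs scores the candidate passages of one query and each label is the index of the relevant candidate. The function name retrieval_metrics_sketch and the reported metrics (accuracy and MRR) are hypothetical: the metrics actually returned by meerqat.train.metrics.retrieval are not documented here. Plain Python lists stand in for Tensors so the sketch runs without PyTorch.

```python
import math
from typing import Dict, List


def retrieval_metrics_sketch(
    log_probs: List[List[float]],
    labels: List[int],
    ignore_index: int = -100,
) -> Dict[str, float]:
    """Hypothetical sketch, not meerqat's implementation.

    log_probs[i][j] scores candidate j for query i; labels[i] is the index of
    the relevant candidate, or ignore_index if query i must be skipped.
    """
    hits, reciprocal_ranks = [], []
    for scores, label in zip(log_probs, labels):
        if label == ignore_index:
            continue  # masked out: does not count towards any metric
        # rank = 1 + number of candidates scored strictly higher than the gold one
        # (ties are therefore resolved in favour of the gold candidate)
        rank = 1 + sum(s > scores[label] for s in scores)
        hits.append(1.0 if rank == 1 else 0.0)
        reciprocal_ranks.append(1.0 / rank)
    n = len(hits)
    if n == 0:
        # every label was ignored: metrics are undefined
        return {"accuracy": math.nan, "MRR": math.nan}
    return {"accuracy": sum(hits) / n, "MRR": sum(reciprocal_ranks) / n}


batch = [[-0.1, -2.3, -3.0], [-1.5, -0.2, -2.0], [-0.5, -0.9, -4.0]]
print(retrieval_metrics_sketch(batch, [0, 0, -100]))  # third query is ignored
```

Here the first query ranks its gold passage first and the second ranks it second, so accuracy is 0.5 and MRR is (1 + 1/2) / 2 = 0.75; the third query does not contribute because its label equals ignore_index.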
- meerqat.train.metrics.get_run(eval_outputs, ir_run)
- Parameters:
eval_outputs (List[dict[str, Tensor]]) – Contains logits for all batches in the evaluation step (either validation or test)
ir_run (ranx.Run) – Original IR run being re-ranked.
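A sketch of re-ranking an IR run with model scores, assuming eval_outputs has already been flattened into one row of logits per query, aligned with the order of queries and of candidate documents in the original run. The helper rerank_run_sketch is hypothetical, and plain nested dicts ({query_id: {doc_id: score}}, the layout a ranx.Run is built from) stand in for ranx.Run so the example runs without ranx.

```python
from typing import Dict, List

# query id -> {document id -> score}
RunDict = Dict[str, Dict[str, float]]


def rerank_run_sketch(logits: List[List[float]], ir_run: RunDict) -> RunDict:
    """Hypothetical sketch: replace first-stage scores with re-ranker logits.

    Assumes logits[i] holds one score per candidate of the i-th query of
    ir_run, in the run's own iteration order.
    """
    if len(logits) != len(ir_run):
        raise ValueError("expected one row of logits per query in ir_run")
    reranked: RunDict = {}
    for (q_id, candidates), row in zip(ir_run.items(), logits):
        if len(row) != len(candidates):
            raise ValueError(f"query {q_id}: {len(row)} logits for {len(candidates)} documents")
        # same candidate documents, new scores from the re-ranking model
        reranked[q_id] = {doc_id: float(s) for doc_id, s in zip(candidates, row)}
    return reranked


first_stage = {"q1": {"d1": 10.0, "d2": 9.0}, "q2": {"d3": 5.0}}
print(rerank_run_sketch([[0.1, 2.0], [-1.0]], first_stage))
```

Keeping the candidate set fixed and only swapping the scores means that ranking metrics computed on the result measure the re-ranker alone, since first-stage recall is unchanged.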
- meerqat.train.metrics.squad(predictions, references)
Adapted from datasets.load_metric('squad').
- Parameters:
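predictions (List[str]) – Predicted answer for each question
references (List[List[str]]) – Reference answers for each question (a question may have several)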
- Returns:
metrics
- Return type:
dict[str, float]
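For reference, a sketch of the standard SQuAD v1.1 scoring that datasets.load_metric('squad') is based on: answers are normalized, exact match and token-level F1 are taken as the maximum over all reference answers of a question, and both are averaged and scaled to the 0 to 100 range. The function names are hypothetical, and meerqat's adaptation may differ in its details (for instance in the keys of the returned dict).

```python
import re
import string
from collections import Counter
from typing import Callable, Dict, List

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(reference))


def f1(prediction: str, reference: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    # multiset intersection counts each shared token at most min(count) times
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_references(metric: Callable[[str, str], float],
                         prediction: str, references: List[str]) -> float:
    # a question may have several gold answers: keep the best match
    return max(metric(prediction, ref) for ref in references)


def squad_sketch(predictions: List[str], references: List[List[str]]) -> Dict[str, float]:
    if not predictions:
        raise ValueError("no predictions to score")
    pairs = list(zip(predictions, references))
    em = sum(best_over_references(exact_match, p, r) for p, r in pairs)
    f = sum(best_over_references(f1, p, r) for p, r in pairs)
    # averaged over questions and scaled to 0-100, as in the official script
    return {"exact_match": 100.0 * em / len(pairs), "f1": 100.0 * f / len(pairs)}


print(squad_sketch(["The Eiffel Tower", "Paris, France"], [["Eiffel Tower"], ["Paris"]]))
```

The first prediction matches exactly after the article is dropped; the second only partially overlaps its reference (precision 1/2, recall 1), so it scores 0 exact match and 2/3 F1.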
- meerqat.train.metrics.squad_per_question(predictions, references)
Returns the metric scores for each question instead of averaging them as squad does. This is kept as a separate implementation because squad should, in principle, be loaded from datasets. Per-question scores allow statistical significance testing downstream.
- Parameters:
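predictions (List[str]) – Predicted answer for each question
references (List[List[str]]) – Reference answers for each question (a question may have several)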
- Returns:
metrics
- Return type:
dict[str, List[float]]
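Per-question scores such as those returned by squad_per_question can feed a paired significance test. Below is a minimal sketch of a two-sided paired approximate randomization test, one common choice for comparing two QA systems on the same questions; the function paired_randomization_test and its parameters are hypothetical and not part of meerqat, and the two score lists are assumed to be aligned question by question.

```python
import random
from typing import List


def paired_randomization_test(scores_a: List[float], scores_b: List[float],
                              n_resamples: int = 10_000, seed: int = 0) -> float:
    """Hypothetical sketch: estimated p-value that systems A and B perform equally on average."""
    if len(scores_a) != len(scores_b) or not scores_a:
        raise ValueError("need two non-empty score lists of equal length")
    rng = random.Random(seed)
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    observed = abs(sum(diffs))
    at_least_as_extreme = 0
    for _ in range(n_resamples):
        # under the null hypothesis, swapping A's and B's outputs on a question
        # is equally likely, which flips the sign of that question's difference
        permuted = sum(d if rng.random() < 0.5 else -d for d in diffs)
        if abs(permuted) >= observed:
            at_least_as_extreme += 1
    # add-one smoothing keeps the estimated p-value strictly positive
    return (at_least_as_extreme + 1) / (n_resamples + 1)


system_a = [1.0] * 20
system_b = [0.0] * 20
print(paired_randomization_test(system_a, system_b))  # very small p-value
```

For instance, passing the per-question F1 lists of two systems yields an estimated p-value; a small value suggests that the difference in mean F1 is unlikely to arise from randomly swapping the two systems' outputs.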