"""Metrics to be used in trainer."""
import warnings
from collections import Counter
import ranx
from ..data.loading import answer_preprocess

def accumulate_batch_metrics(batch_metrics):
    """Sums batch-level metrics and averages them over the effective number of predictions (batch_size minus ignored_predictions)."""
    metrics = Counter()
    for metric in batch_metrics:
        for k, v in metric.items():
            metrics[k] += v
    effective_size = metrics.pop("batch_size") - metrics.pop("ignored_predictions", 0)
    for k, v in metrics.items():
        metrics[k] = v/effective_size
    return metrics
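
# Usage sketch (not part of the original module): accumulating two hypothetical
# batch dicts, e.g. as returned by batch_retrieval below. "batch_size" and
# "ignored_predictions" are consumed to form the effective denominator; every
# other key is summed and then averaged over it.
#
#   >>> accumulate_batch_metrics([
#   ...     {"MRR@N*M": 1.5, "hits@1": 1, "ignored_predictions": 0, "batch_size": 2},
#   ...     {"MRR@N*M": 1.0, "hits@1": 1, "ignored_predictions": 1, "batch_size": 2},
#   ... ])
#   Counter({'MRR@N*M': 0.8333333333333334, 'hits@1': 0.6666666666666666})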
# TODO https://torchmetrics.readthedocs.io/en/stable/retrieval/mrr.html
def batch_retrieval(log_probs, labels, ignore_index=-100):
    """Computes MRR and hits@1 over a single batch. Labels equal to ignore_index are counted as ignored_predictions and skipped."""
    mrr, hits_at_1, ignored_predictions = 0, 0, 0
    batch_size, _ = log_probs.shape
    # use argsort to rank the passages w.r.t. their log-probability (`-` to sort in desc. order)
    rankings = (-log_probs).argsort(axis=1)
    for ranking, label in zip(rankings, labels):
        if label == ignore_index:
            ignored_predictions += 1
            continue
        if ranking[0] == label:
            hits_at_1 += 1
        # +1 to count from 1 instead of 0
        rank = (ranking == label).nonzero()[0].item() + 1
        mrr += 1/rank
    return {"MRR@N*M": mrr, "hits@1": hits_at_1,
            "ignored_predictions": ignored_predictions, "batch_size": batch_size}

def retrieval(eval_outputs, ignore_index=-100, output_key='log_probs'):
    """
    Computes retrieval metrics (MRR and hits@1) over the whole evaluation set, batch by batch.

    Parameters
    ----------
    eval_outputs: List[dict[str, Tensor]]
        Contains log_probs and labels for all batches in the evaluation step (either validation or test)
    ignore_index: int, optional
        Labels with this value are not taken into account when computing metrics.
        Defaults to -100
    output_key: str, optional
        Name of the model output in eval_outputs

    Returns
    -------
    metrics: dict[str, float]
    """
    metrics = {}
    mrr, hits_at_1, ignored_predictions, dataset_size = 0, 0, 0, 0
    for batch in eval_outputs:
        log_probs = batch[output_key].numpy()
        labels = batch['labels'].numpy()
        batch_size, _ = log_probs.shape
        dataset_size += batch_size
        # use argsort to rank the passages w.r.t. their log-probability (`-` to sort in desc. order)
        rankings = (-log_probs).argsort(axis=1)
        for ranking, label in zip(rankings, labels):
            if label == ignore_index:
                ignored_predictions += 1
                continue
            if ranking[0] == label:
                hits_at_1 += 1
            # +1 to count from 1 instead of 0
            rank = (ranking == label).nonzero()[0].item() + 1
            mrr += 1/rank
    metrics["MRR@N*M"] = mrr / (dataset_size-ignored_predictions)
    metrics["hits@1"] = hits_at_1 / (dataset_size-ignored_predictions)
    return metrics
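
# Usage sketch (not part of the original module), assuming CPU torch tensors as
# gathered by the evaluation loop. The second label equals ignore_index (-100),
# so only the first prediction counts towards the denominator.
#
#   >>> import torch
#   >>> eval_outputs = [{'log_probs': torch.log(torch.tensor([[0.7, 0.2, 0.1],
#   ...                                                       [0.3, 0.1, 0.6]])),
#   ...                  'labels': torch.tensor([0, -100])}]
#   >>> retrieval(eval_outputs)
#   {'MRR@N*M': 1.0, 'hits@1': 1.0}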

def get_run(eval_outputs, ir_run):
    """
    Re-ranks the original IR run according to the model logits.

    Parameters
    ----------
    eval_outputs: List[dict[str, Tensor]]
        Contains logits for all batches in the evaluation step (either validation or test)
    ir_run: ranx.Run
        Original IR run being re-ranked.

    Returns
    -------
    run: ranx.Run
        The re-ranked run.
    """
    run = {}
    for batch in eval_outputs:
        logits = batch['logits'].numpy()
        N, M = logits.shape
        question_ids = [batch['ids'][i] for i in range(0, N*M, M)]
        rankings = (-logits).argsort(axis=1)
        for ranking, logit, question_id in zip(rankings, logits, question_ids):
            ir_results = ir_run.run[question_id]
            # nothing to re-rank.
            # this can happen if searching for something unavailable in the query
            # e.g. no face was detected but you are searching for face similarity (see ir.search)
            if not ir_results:
                run[question_id] = ir_results
            else:
                doc_ids = list(ir_results.keys())[: M]
                run[question_id] = {doc_ids[i]: logit[i] for i in ranking}
    return ranx.Run(run)
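
# Usage sketch (not part of the original module), with hypothetical identifiers
# ('q1', 'd1', 'd2'): it assumes batch['ids'] holds one question identifier per
# passage (N*M entries) and that the batch logits score the top-M documents of
# the original run, in the run's order.
#
#   >>> import torch, ranx
#   >>> ir_run = ranx.Run({'q1': {'d1': 10.0, 'd2': 9.0}})
#   >>> eval_outputs = [{'logits': torch.tensor([[0.1, 2.3]]),
#   ...                  'ids': ['q1', 'q1']}]
#   >>> get_run(eval_outputs, ir_run)
#   # -> re-ranked run where 'd2' (score 2.3) now outranks 'd1' (score 0.1)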

def f1_score(prediction, ground_truth):
    """Token-level F1 between a single prediction and a single ground truth, after answer_preprocess."""
    prediction_tokens = answer_preprocess(prediction).split()
    ground_truth_tokens = answer_preprocess(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
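
# Usage sketch (not part of the original module), assuming answer_preprocess is
# a SQuAD-style normaliser (lowercasing, stripping punctuation and articles), so
# these already-normalised strings pass through unchanged: 2 shared tokens give
# precision 2/3 and recall 1.
#
#   >>> f1_score("eiffel tower paris", "eiffel tower")
#   # -> ~0.8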

def exact_match_score(prediction, ground_truth):
    """Whether the prediction matches the ground truth exactly, after answer_preprocess."""
    return answer_preprocess(prediction) == answer_preprocess(ground_truth)

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Returns the best score of metric_fn over all acceptable ground truths."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)

def squad(predictions, references):
    """
    Adapted from datasets.load_metric('squad')

    Parameters
    ----------
    predictions: List[str]
    references: List[List[str]]

    Returns
    -------
    metrics: dict[str, float]
    """
    assert len(predictions) == len(references)
    f1, exact_match = 0, 0
    for prediction, ground_truths in zip(predictions, references):
        exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
    exact_match = exact_match / len(references)
    f1 = f1 / len(references)
    return {"exact_match": exact_match, "f1": f1}

def squad_per_question(predictions, references):
    """
    Returns the scores of the metrics for each question instead of averaging them like squad.
    Kept as a separate implementation because squad should, in principle, be loaded from datasets.
    This should allow for statistical significance testing downstream.

    Parameters
    ----------
    predictions: List[str]
    references: List[List[str]]

    Returns
    -------
    metrics: dict[str, List[float]]
    """
    assert len(predictions) == len(references)
    f1, exact_match = [], []
    for prediction, ground_truths in zip(predictions, references):
        exact_match.append(metric_max_over_ground_truths(exact_match_score, prediction, ground_truths))
        f1.append(metric_max_over_ground_truths(f1_score, prediction, ground_truths))
    return {"exact_match": exact_match, "f1": f1}