Source code for meerqat.data.infoseek

# -*- coding: utf-8 -*-
# mostly taken from https://github.com/edchengg/infoseek_eval/blob/main/infoseek_eval.py
# slightly refactored using meerqat.train.metrics and enum

import re
import json
from typing import Any, Dict, Generator, List, Tuple, Union
import enum
from jsonargparse import CLI

import pandas as pd

from datasets import load_from_disk, Dataset
from ..train.metrics import exact_match_score, metric_max_over_ground_truths


class QuestionType(enum.Enum):
    String = 0
    Numerical = 1
    Time = 2
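
# Reference records store the question type by name, so lookups below go
# through QuestionType[...] (illustrative note, not part of the upstream script):
#   QuestionType["Numerical"] is QuestionType.Numerical  # -> True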
def in_range(number: float, range_list: Tuple[float, float]) -> bool:
    """Check if a number is within the specified range (inclusive)."""
    min_num, max_num = range_list
    return min_num <= number <= max_num
def safe_division(x: float, y: float) -> float:
    """Divide x by y, returning 0 if y is 0."""
    return x / y if y != 0 else 0
def metric_numerical_range(
    pred: Union[float, Tuple[float, float], List[float]],
    answer: Union[float, Tuple[float, float], List[float]],
    tolerance: float = 0.1,
) -> int:
    """Scores numerical questions based on ranges and tolerances.

    1) First, convert a single-number answer to a range with +/- tolerance.
    2) If the prediction is a single number, return 1 if it is in the answer
       range, 0 otherwise.
    3) If the prediction is a range, return 1 if the range lies within the
       answer range or if the IOU (overlap between prediction and answer
       range) >= 0.5 (within a small epsilon), 0 otherwise.

    Args:
        pred: A list/tuple of 2 numbers or a single number.
        answer: A list/tuple of 2 numbers or a single number.
        tolerance: A float value for the tolerance range (default: 0.1).

    Returns:
        int: 1 if the conditions are met, 0 otherwise.
    """
    answer = list(answer) if isinstance(answer, tuple) else answer
    pred = list(pred) if isinstance(pred, tuple) else pred
    if not isinstance(answer, list):
        answer = [answer * (1 - tolerance), answer * (1 + tolerance)]

    # Prediction is a single number
    if not isinstance(pred, list):
        return 1 if in_range(pred, answer) else 0

    # Prediction is a range
    if answer[0] <= pred[0] <= answer[1] and answer[0] <= pred[1] <= answer[1]:
        return 1
    else:
        iou = range_intersection_over_union(pred, answer)
        return 1 if iou >= 0.5 - 1e-12 else 0
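
# Illustrative sanity checks for metric_numerical_range (not part of the
# upstream script). With the default 10% tolerance, a single-number answer of
# 100 becomes the range [90.0, 110.0]:
#   metric_numerical_range(95, 100)         # -> 1 (95 falls in [90, 110])
#   metric_numerical_range(120, 100)        # -> 0 (outside the range)
#   metric_numerical_range([95, 105], 100)  # -> 1 (fully inside the range)
#   metric_numerical_range([80, 100], 100)  # -> 0 (IOU = 10/30 < 0.5)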
def find_numbers(string_number: str) -> Tuple[List[float], List[str]]:
    """Extract numbers from a string, returning parsed floats and the raw matches."""
    # Clean string
    string_number = clean_str_range(string_number)
    numerical_numbers_tmp = re.findall(
        r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', string_number
    )
    numerical_numbers = []
    for n in numerical_numbers_tmp:
        n = n.replace(',', '').strip('.')
        # Keep only the integer part of malformed numbers such as '1.2.3'
        if n.count('.') > 1:
            n = n.split('.')[0]
        numerical_numbers.append(float(n))
    return numerical_numbers, numerical_numbers_tmp
def process_numerical_answer(string_number: str) -> Union[float, List[float]]:
    """Parses a numerical answer string into numbers (a single number or a range).

    1) Clean the string and extract numbers;
    2) if there are 2 numbers in ascending order, return them as a range
       [minimum value, maximum value];
       else if there is 1 number (or the 2 numbers are out of order, in which
       case only the first one is kept), return a single number;
       else return [0, 0].

    Args:
        string_number: A string representing a numerical answer.

    Returns:
        A single number or a list of 2 numbers.
    """
    numerical_numbers, _ = find_numbers(string_number)

    # Use the first 2 numbers
    if len(numerical_numbers) > 2:
        numerical_numbers = numerical_numbers[:2]

    if len(numerical_numbers) == 2:
        first_val, second_val = numerical_numbers
        return [first_val, second_val] if first_val <= second_val else first_val
    elif len(numerical_numbers) == 1:
        return numerical_numbers[0]
    else:
        return [0, 0]
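
# Illustrative examples for process_numerical_answer (not part of the upstream
# script):
#   process_numerical_answer("9-10")            # -> [9.0, 10.0] (range)
#   process_numerical_answer("about 5,000 km")  # -> 5000.0 (single number)
#   process_numerical_answer("unknown")         # -> [0, 0] (no number found)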
def find_all(s: str, c: str) -> Generator[int, None, None]:
    """Find all occurrences of a character in a string and yield their indices.

    Args:
        s: The input string to search.
        c: The character to search for.

    Yields:
        int: The index of the next occurrence of the character.
    """
    idx = s.find(c)
    while idx != -1:
        yield idx
        idx = s.find(c, idx + 1)
def clean_str_range(text: str) -> str:
    """Clean a range expression in a string (e.g., '9-10' --> '9 - 10').

    Args:
        text: The input string containing the range expression.

    Returns:
        str: The cleaned string with proper spacing around the hyphen.
    """
    idx_list = list(find_all(text, '-'))
    # Only split hyphens that directly follow a digit
    idx_replace = [
        idx for idx in idx_list if idx >= 1 and text[idx - 1].isdigit()
    ]
    new_str = ''.join(
        ' - ' if idx in idx_replace else s for idx, s in enumerate(text)
    )
    return new_str
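
# Only hyphens directly preceded by a digit are split, so textual compounds
# are preserved (illustrative, not part of the upstream script):
#   clean_str_range("1984-1992")  # -> "1984 - 1992"
#   clean_str_range("pre-war")    # -> "pre-war" (unchanged)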
def range_intersection_over_union(
    x_list: List[float], y_list: List[float]
) -> float:
    """Calculate the intersection over union (IOU) of two ranges."""
    min_1, max_1 = min(x_list), max(x_list)
    min_2, max_2 = min(y_list), max(y_list)
    overlap = max(0.0, min(max_1, max_2) - max(min_1, min_2))
    length_x = (max_1 - min_1) + 1e-12
    length_y = (max_2 - min_2) + 1e-12
    iou = safe_division(overlap, length_x + length_y - overlap)
    return iou
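
# Worked example (illustrative): for [0, 10] and [5, 15], the overlap is
# min(10, 15) - max(0, 5) = 5 and the union is 10 + 10 - 5 = 15, so the IOU
# is 5 / 15 ~= 0.33, below the 0.5 threshold used in metric_numerical_range:
#   range_intersection_over_union([0, 10], [5, 15])  # -> ~0.333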
def evaluate_quantity(
    quantity_pred: List[Union[float, List[float]]],
    quantity_answer: List[List[float]],
) -> List[int]:
    """Evaluate numerical predictions against numerical answers."""
    return [
        metric_numerical_range(pred, ans)
        for pred, ans in zip(quantity_pred, quantity_answer)
    ]
def evaluate_entity(
    entity_pred: List[str], entity_answer: List[List[str]]
) -> List[int]:
    """Evaluate entity predictions against entity answers.

    Criteria: maximum score of exact match over the entity answers.

    Args:
        entity_pred: string predictions
        entity_answer: lists of reference answer strings

    Returns:
        List[int]: 0 or 1 per prediction
    """
    return [
        metric_max_over_ground_truths(exact_match_score, pred, ans)
        for pred, ans in zip(entity_pred, entity_answer)
    ]
def evaluate_time(
    time_pred: List[str], time_answer: List[List[str]]
) -> List[int]:
    """Evaluate time predictions against time answers.

    Criteria:
    1) +/- one year --> correct
    2) if asking for a date but the year is correct --> correct

    (Implemented as a maximum exact match over the reference variants, so the
    reference list is expected to already enumerate the acceptable years/dates.)

    Args:
        time_pred: time predictions
        time_answer: lists of reference time strings

    Returns:
        List[int]: 0 or 1 per prediction
    """
    return evaluate_entity(time_pred, time_answer)
def evaluation(
    predictions: List[Dict[str, Any]], qid2example: Dict[str, Dict[str, Any]]
) -> Tuple[List[int], List[int], List[int]]:
    """Evaluate predictions against ground truth answers.

    Separates questions into time, numerical, and string categories.

    Args:
        predictions: A list of predictions.
        qid2example: A mapping from question ID to ground truth examples.

    Returns:
        Tuple[List[int], List[int], List[int]]: Lists of scores for time,
        quantity, and entity predictions.
    """
    time_pred, quantity_pred, entity_pred = [], [], []
    time_answer, quantity_answer, entity_answer = [], [], []
    for p in predictions:
        quid = p['data_id']
        if quid not in qid2example:
            continue
        example = qid2example[quid]
        pred = p['prediction']
        answer = example['answer_eval']
        question_type = QuestionType[example['question_type']]
        if question_type == QuestionType.Time:
            time_pred.append(pred)
            time_answer.append(answer)
        elif question_type == QuestionType.Numerical:
            pred_range = process_numerical_answer(pred)
            answer_range = [float(a) for a in answer]
            quantity_pred.append(pred_range)
            quantity_answer.append(answer_range)
        else:
            entity_pred.append(pred)
            entity_answer.append(answer)

    score_time = evaluate_time(time_pred, time_answer)
    score_quantity = evaluate_quantity(quantity_pred, quantity_answer)
    score_entity = evaluate_entity(entity_pred, entity_answer)
    return score_time, score_quantity, score_entity
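
# Expected record shapes (illustrative values; the field names are the ones
# used above):
#   predictions: [{"data_id": "q1", "prediction": "1,204 m"}, ...]
#   qid2example: {"q1": {"question_type": "Numerical",
#                        "answer_eval": ["1100", "1300"],
#                        "data_split": "val_unseen_question", ...}, ...}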
def get_results(
    predictions: List[Dict[str, Any]], qid2example: Dict[str, Dict[str, Any]]
) -> Tuple[float, float, float, float]:
    """Get evaluation scores for predictions.

    Args:
        predictions: A list of predictions.
        qid2example: A mapping from question ID to ground truth examples.

    Returns:
        Tuple[float, float, float, float]: Final scores for overall, time,
        quantity, and entity predictions.
    """
    score_time, score_quantity, score_entity = evaluation(
        predictions, qid2example
    )
    final_score_time = safe_division(sum(score_time), len(score_time))
    final_score_quantity = safe_division(sum(score_quantity), len(score_quantity))
    final_score_entity = safe_division(sum(score_entity), len(score_entity))
    all_scores = score_time + score_quantity + score_entity
    final_score = safe_division(sum(all_scores), len(all_scores))
    return final_score, final_score_time, final_score_quantity, final_score_entity
def harmonic_mean(*args: float) -> float:
    """Calculate the harmonic mean of the input arguments (zeros are clamped to 1e-12)."""
    args_safe = [a if a != 0 else 1e-12 for a in args]
    hmean = len(args_safe) / sum((1.0 / val) for val in args_safe)
    return hmean
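
# The harmonic mean penalizes imbalance between splits more than the
# arithmetic mean (illustrative, not part of the upstream script):
#   harmonic_mean(20.0, 30.0)  # -> 24.0 (the arithmetic mean would be 25.0)
#   harmonic_mean(20.0, 0.0)   # -> ~0.0 (a zero score dominates)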
def evaluate_infoseek(
    predictions: List[Dict[str, Any]], qid2example: Dict[str, Dict[str, Any]]
) -> Dict[str, float]:
    """Evaluate predictions against references.

    Args:
        predictions: A list of predictions.
        qid2example: A dictionary of references keyed by question ID.

    Returns:
        Dict[str, float]: A dictionary containing the final scores for time,
        quantity, entity, and overall predictions, as percentages.
    """
    final_score, score_time, score_num, score_string = get_results(
        predictions, qid2example
    )
    return {
        'score': round(final_score * 100, 2),
        'score_time': round(score_time * 100, 2),
        'score_num': round(score_num * 100, 2),
        'score_string': round(score_string * 100, 2),
    }
def evaluate_infoseek_full(
    predictions: Dict[str, List[Dict[str, Any]]],
    qid2example: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Evaluate each split, print a LaTeX summary table, and combine the
    split scores with a harmonic mean."""
    infoseek_score = {}
    for split, pred in predictions.items():
        split_score = evaluate_infoseek(pred, qid2example)
        split_score['split'] = split
        infoseek_score[split] = split_score
    print(pd.DataFrame(infoseek_score.values()).to_latex(float_format="%.2f"))
    split_scores = [score['score'] for score in infoseek_score.values()]
    return {
        'final_score': round(harmonic_mean(*split_scores), 2),
        'unseen_question_score': infoseek_score['unseen_question'],
        'unseen_entity_score': infoseek_score['unseen_entity'],
    }
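
# Illustrative return value (hypothetical numbers; split scores are
# percentages and the final score is their harmonic mean):
#   {"final_score": 22.22,
#    "unseen_question_score": {"score": 20.0, ..., "split": "unseen_question"},
#    "unseen_entity_score": {"score": 25.0, ..., "split": "unseen_entity"}}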
def fix_space(string: str) -> str:
    """Remove a spurious space after a thousands/decimal separator
    (e.g., '1, 000' --> '1,000')."""
    return re.sub(r'(\d+[\.,]) (\d+)', r'\1\2', string)
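
# fix_space removes one spurious space per separator per pass, which is why
# `evaluate` applies it three times (illustrative, not part of the upstream
# script):
#   fix_space("1, 000, 000")             # -> "1,000, 000" (one space fixed)
#   fix_space(fix_space("1, 000, 000"))  # -> "1,000,000"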
def evaluate(
    prediction_path: Union[str, List[str]],
    reference_path: Union[str, Dataset],
    do_fix_space: bool = False
) -> Dict[str, Any]:
    """Evaluate predictions against references.

    Args:
        prediction_path: Path to the prediction file (JSON or JSONL), or a
            list of predictions.
        reference_path: Path to the reference file (JSONL or a dataset saved
            with ``save_to_disk``), or a Dataset.
        do_fix_space: Whether to remove spurious spaces after
            thousands/decimal separators in the predictions (see ``fix_space``).

    Returns:
        Dict[str, Any]: A dictionary containing the final scores for time,
        quantity, entity, and overall predictions.
    """
    if isinstance(reference_path, Dataset) or not reference_path.endswith('jsonl'):
        if isinstance(reference_path, Dataset):
            reference = reference_path
        else:
            reference = load_from_disk(reference_path)
        reference = reference.remove_columns(
            [c for c in reference.column_names
             if c not in {"id", "output", "data_split", "question_type"}]
        )
        qid2example = {}
        for item in reference:
            item['answer_eval'] = item['output']['answer']
            qid2example[item['id']] = item
    else:
        reference = load_jsonl(reference_path)
        qid2example = prepare_qid2example(reference)

    if not isinstance(prediction_path, List) and prediction_path.endswith('jsonl'):
        predictions = load_jsonl(prediction_path)
    else:
        if isinstance(prediction_path, List):
            predictions = prediction_path
        else:
            with open(prediction_path, 'rt') as file:
                predictions = json.load(file)
        # Plain JSON/list predictions are assumed to be aligned with the reference rows
        predictions = [
            {"data_id": q_id, "prediction": p}
            for q_id, p in zip(reference['id'], predictions)
        ]

    # Split predictions into two splits: unseen_question and unseen_entity
    splits = dict(unseen_question=[], unseen_entity=[])
    for pred in predictions:
        if do_fix_space:
            # Apply three times to handle repeated spaces (e.g. '1, 000, 000')
            pred['prediction'] = fix_space(fix_space(fix_space(pred['prediction'])))
        data_id = pred['data_id']
        if data_id not in qid2example:
            continue
        if qid2example[data_id]['data_split'].endswith('unseen_question'):
            splits['unseen_question'].append(pred)
        else:
            splits['unseen_entity'].append(pred)
    return evaluate_infoseek_full(splits, qid2example)
def prepare_qid2example(
    reference: List[Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
    """Convert a reference list to a qid2example dictionary."""
    qid2example = dict()
    for r in reference:
        qid = r['data_id']
        q_type = QuestionType[r['question_type']]
        if q_type == QuestionType.Numerical:
            # Process numerical answer:
            # "answer_eval": [{"wikidata": 1.0, "range": [0.9, 1.1]}]
            # --> "answer_eval": ["0.9", "1.1"]
            if isinstance(r['answer_eval'], list):
                ans_eval = r['answer_eval'][0]['range']
            else:
                ans_eval = r['answer_eval']['range']
            r['answer_eval'] = [str(ans) for ans in ans_eval][:2]
        qid2example[qid] = r
    return qid2example
def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Load a JSONL file into a list of dictionaries."""
    data = []
    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line))
    return data
def main(prediction_path: str, reference_path: str, do_fix_space: bool = False):
    """Command-line entry point: evaluate the predictions and print the scores."""
    result = evaluate(prediction_path, reference_path, do_fix_space=do_fix_space)
    final_score = result["final_score"]
    unseen_question_score = result["unseen_question_score"]["score"]
    unseen_entity_score = result["unseen_entity_score"]["score"]
    print(f"final score: {final_score}")
    print(f"unseen question score: {unseen_question_score}")
    print(f"unseen entity score: {unseen_entity_score}")
if __name__ == "__main__":
    CLI(main)
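
# Example invocation (illustrative; jsonargparse's CLI maps the parameters of
# `main` to positional arguments and options, and the file names here are
# placeholders):
#   python -m meerqat.data.infoseek preds.jsonl refs.jsonl --do_fix_space=true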