Source code for meerqat.viz.wtl_filter

# coding: utf-8
"""
usage: wtl_filter.py [-h] --runA RUNA [--runB RUNB]
                     [--filter {win,loose,intersection,union,nunion,nintersection}] [--kA KA]
                     [--kB KB]
                     <dataset> <output> <qrels>

Filter a dataset for questions according to runs A and B and filter option. For example, find
questions where A wins over B. By default (kA=kB=1), winning means that A has relevant result in
top-1, while B top-1 is irrelevant. However, you can change the definition, e.g. kA=1, kB=100,
means that A wins if it has a relevant result in top-1 and B has only irrelevant in its top-100.

positional arguments:
  <dataset>             Path of the dataset to filter.
  <output>              Path to save the filtered dataset.
  <qrels>               Path to the qrels.

options:
  -h, --help            show this help message and exit
  --runA RUNA           Path to run A. (default: None)
  --runB RUNB           Path to run B. Defaults: run A always win when its right. (default: None)
  --filter {win,loose,intersection,union,nunion,nintersection}
                        Save when A wins over B (default), looses over B, both wins (intersection),
                        either win (union), both loose (nunion), or one looses (nintersection)
                        (default: win)
  --kA KA               Judge A on its top-K. (default: 1)
  --kB KB               Judge B on its top-K. (default: 1)
"""

import argparse
import ranx
from datasets import load_from_disk

[docs]def get_wins(run, topk=1): return {k for k, v in run.scores[f'hit_rate@{topk}'].items() if v > 0}
[docs]def main(dataset, output, qrels, runA, runB=None, filter='win', kA=1, kB=1): d = load_from_disk(dataset) qrels = ranx.Qrels.from_file(qrels) runA = ranx.Run.from_file(runA) evaluateA = ranx.evaluate(qrels=qrels, run=runA, metrics=f'hit_rate@{kA}') runA_wins = get_wins(runA, topk=kA) runA.name = "run A" if runA.name is None else runA.name print(runA.name, f'hit_rate@{kA}:', evaluateA, f"wins: {len(runA_wins)}") if runB is not None: runB = ranx.Run.from_file(runB) evaluateB = ranx.evaluate(qrels=qrels, run=runB, metrics=f'hit_rate@{kB}') runB_wins = get_wins(runB, topk=kB) runB.name = "run B" if runB.name is None else runB.name print(runB.name, f'hit_rate@{kB}:', evaluateB, f"wins: {len(runB_wins)}") intersection, union = runA_wins&runB_wins, runA_wins|runB_wins print(f"intersection: {len(intersection)}, union: {len(union)}") print(f"A over B: {len(runA_wins-runB_wins)}, B over A: {len(runB_wins-runA_wins)}") switch_filter = { 'win': lambda id_: id_ in runA_wins-runB_wins, 'loose': lambda id_: id_ not in runB_wins-runA_wins, 'intersection': lambda id_: id_ in intersection, 'union': lambda id_: id_ in union, 'nunion': lambda id_: id_ not in union, 'nintersection': lambda id_: id_ not in intersection } filtered = d.filter(switch_filter[filter], input_columns='id') else: switch_filter = { 'win': lambda id_: id_ in runA_wins, 'loose': lambda id_: id_ not in runA_wins } filtered = d.filter(switch_filter[filter], input_columns='id') print(filtered) filtered.save_to_disk(output)
if __name__ =='__main__': description = """ Filter a dataset for questions according to runs A and B and filter option. For example, find questions where A wins over B. By default (kA=kB=1), winning means that A has relevant result in top-1, while B top-1 is irrelevant. However, you can change the definition, e.g. kA=1, kB=100, means that A wins if it has a relevant result in top-1 and B has only irrelevant in its top-100. """ parser = argparse.ArgumentParser( description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument('dataset', metavar='<dataset>', type=str, help='Path of the dataset to filter.') parser.add_argument('output', metavar='<output>', type=str, help='Path to save the filtered dataset.') parser.add_argument('qrels', metavar='<qrels>', type=str, help='Path to the qrels.') parser.add_argument('--runA', type=str, help='Path to run A.', required=True) parser.add_argument('--runB', type=str, help='Path to run B. Defaults: run A always win when its right.', required=False, default=None) parser.add_argument('--filter', type=str, choices=['win', 'loose', 'intersection', 'union', 'nunion', 'nintersection'], required=False, default='win', help="Save when A wins over B (default), looses over B, both wins (intersection), either win (union), both loose (nunion), or one looses (nintersection)") parser.add_argument('--kA', type=int, help='Judge A on its top-K.', required=False, default=1) parser.add_argument('--kB', type=int, help='Judge B on its top-K.', required=False, default=1) args = parser.parse_args() main(**vars(args))