# Source code for meerqat.train.callbacks
# -*- coding: utf-8 -*-
from typing import Optional
from pytorch_lightning.callbacks import Callback
class TestAfterFit(Callback):
    """
    Calls trainer.test with 'best' ckpt on fit end
    (so make sure you configure ModelCheckpoint to save best model).

    Parameters
    ----------
    data_update: dict, optional
        Attributes of trainer.datamodule to update before running test,
        i.e. differences between your validation and test setups.
        Every key must already exist on the datamodule; all keys are checked
        before any attribute is modified, so one bad key leaves the
        datamodule untouched.
        E.g. for re-ranking, you might want to pass:
        {
            "M": 100,  # re-rank top-100 passages
            "eval_batch_size": 2,  # lower batch size to fit in a GPU
            "run_path": "/path/to/test/run.json",
            "qrels_path": "/path/to/test/qrels.json"
        }
    """

    def __init__(self, data_update: Optional[dict] = None):
        super().__init__()
        self.data_update = data_update

    def on_fit_end(self, trainer, pl_module):
        """
        Apply ``data_update`` to ``trainer.datamodule``, then test the best checkpoint.

        Parameters
        ----------
        trainer:
            The Lightning trainer that just finished fitting.
        pl_module:
            The module that was fitted; passed on to ``trainer.test``.

        Raises
        ------
        AttributeError
            If ``data_update`` is non-empty and the trainer has no datamodule,
            or if ``data_update`` names attributes the datamodule does not have.
            Raised before any attribute is modified.
        """
        datamodule = trainer.datamodule
        if self.data_update:
            if datamodule is None:
                raise AttributeError(
                    "data_update was given but the trainer has no datamodule to update"
                )
            # Validate every key up front so a typo cannot leave the datamodule half-updated.
            missing = [k for k in self.data_update if not hasattr(datamodule, k)]
            if missing:
                names = ", ".join(f"'{k}'" for k in missing)
                raise AttributeError(f"{datamodule.__class__.__name__} has no attribute {names}")
            for k, v in self.data_update.items():
                setattr(datamodule, k, v)
        # N. B. dataloader reloading is done in Trainer._run_evaluate
        # Pass the (updated) datamodule explicitly: without it, recent Lightning versions
        # fall back to the model's own test_dataloader hook instead of the datamodule's.
        trainer.test(pl_module, datamodule=datamodule, ckpt_path="best")