# Source code for meerqat.train.save_ptm
# -*- coding: utf-8 -*-
"""
usage: save_ptm.py [-h] [--config <config>] [--bert] <ckpt>
Save the PreTrainedModel(s) wrapped inside the Trainee (LightningModule).
positional arguments:
<ckpt> Path to the lightning checkpoint.
options:
-h, --help show this help message and exit
--config <config> Path to the lightning config file (YAML).
--bert For DPR-based BiEncoder, save BertModel instead of DPR*Encoder
"""
import argparse
import yaml
from pathlib import Path
from . import trainee
def main(ckpt, config=None, **kwargs):
    """Load a Trainee from a Lightning checkpoint and save its PreTrainedModel(s).

    The Trainee class is looked up by name in the ``trainee`` module, using the
    last dotted component of ``config['model']['class_path']``. It is then
    restored from ``ckpt`` with ``config['model']['init_args']``.

    Parameters
    ----------
    ckpt : str or Path
        Path to the Lightning checkpoint. The model(s) are saved to the same
        path with its suffix removed.
    config : str or Path, optional
        Path to the Lightning config file (YAML). Defaults to ``config.yaml``
        located two directories above ``ckpt``.
    **kwargs
        Forwarded verbatim to ``model.save_pretrained`` (e.g. ``bert=True``
        from the command line).

    Raises
    ------
    KeyError
        If the config has no ``model`` section or no ``model.class_path``.
    AttributeError
        If ``trainee`` defines no class with the configured name.
    """
    # Accept plain strings as well: the code below relies on Path methods.
    ckpt = Path(ckpt)
    if config is None:
        config = ckpt.parent.parent / 'config.yaml'
    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so only
    # run this on trusted config files. It is kept (rather than safe_load) for
    # compatibility with configs that may rely on non-safe tags.
    with open(config, 'rt', encoding='utf-8') as file:
        config = yaml.load(file, yaml.Loader)
    model_config = config['model']
    class_name = model_config['class_path'].split('.')[-1]
    trainee_class = getattr(trainee, class_name)
    # `init_args` may be absent or null in the YAML; treat both as "no args".
    init_args = model_config.get('init_args') or {}
    model = trainee_class.load_from_checkpoint(ckpt, **init_args)
    ckpt_path = ckpt.with_suffix('')
    model.save_pretrained(ckpt_path, **kwargs)
if __name__ == '__main__':
    # Command-line entry point; every parsed option is forwarded to main().
    # `--bert` ends up in main's **kwargs and is passed on to save_pretrained.
    parser = argparse.ArgumentParser(
        description='Save the PreTrainedModel(s) wrapped inside the Trainee (LightningModule).'
    )
    parser.add_argument(
        'ckpt', metavar='<ckpt>', type=Path,
        help='Path to the lightning checkpoint.',
    )
    # Parsed as a Path for consistency with `ckpt`; open() accepts either type.
    parser.add_argument(
        '--config', metavar='<config>', default=None, type=Path,
        help='Path to the lightning config file (YAML).',
    )
    parser.add_argument(
        '--bert', action='store_true',
        help='For DPR-based BiEncoder, save BertModel instead of DPR*Encoder',
    )
    args = parser.parse_args()
    main(**vars(args))