Source code for meerqat.interact.system
# -*- coding: utf-8 -*-
from jsonargparse import CLI
import ranx
from ..data.loading import IMAGE_PATH, get_pretrained
from ..image.embedding import get_model_and_transform as get_image_model, embed as embed_image
from ..ir.embedding import embed as embed_text
from ..ir.search import Searcher
from ..ir.fuse import Fusion
[docs]class System:
"""Interact with the KVQAE system"""
def __init__(self, searcher_kwargs: dict, image_kwargs: dict, text_kwargs: dict,
tokenizer_kwargs: dict, tokenization_kwargs: dict):
self.searcher = Searcher(**searcher_kwargs)
self.image_model = get_image_model(**image_kwargs)
self.text_model = get_pretrained(**text_kwargs)
self.tokenizer = get_pretrained(**tokenizer_kwargs)
self.tokenization_kwargs = tokenization_kwargs
[docs] def pipeline(self, batch):
# 1. process input
batch = embed_image(batch, **self.image_model)
batch = embed_text(batch, **self.text_model, **self.tokenizer,
tokenization_kwargs=self.tokenization_kwargs)
# 2. IR
batch = self.searcher(batch)
self.searcher.qrels = ranx.Qrels(self.searcher.qrels)
for name, run in self.searcher.runs.items():
self.searcher.runs[name] = ranx.Run(run, name=name)
fuser = Fusion(
qrels=self.searcher.qrels,
runs=list(self.searcher.runs.values()),
**self.searcher.fusion_kwargs
)
run = fuser.test(**self.searcher.fusion_kwargs['subcommand_kwargs'])
# 3. RC: TODO
[docs] def user_loop(self):
image = None
while True:
# TODO download from URL
# 1. image
if image is None:
answer = input(f"Enter the image file name stored in '{IMAGE_PATH}' or enter 'q' to quit.\n")
else:
answer = input(f"Enter the image file name stored in '{IMAGE_PATH}' or press Enter to keep the previous one or enter 'q' to quit.\n")
answer = answer.strip()
if answer.lower() == 'q':
break
elif len(answer) > 0:
image = answer
# else keep previous image
# 2. question
answer = input("Ask your question in English or enter 'q' to quit.\n")
answer = answer.strip()
if answer.lower() == 'q':
break
question = answer
# 3. answer
batch = {'image': [image], 'input': [question], 'id': ['FAKE'], 'output': ['FAKE']}
self.pipeline(batch)
print(f"> {answer}\n")
if __name__ == '__main__':
CLI(System)