# coding: utf-8
"""
========
Overview
========
.. image:: ../source_docs/kilt2vqa_big_picture.png
All the data should be stored in the `data` folder, at the root of this repo.
The goal is to generate questions suitable for VQA by replacing explicit entity mentions in existing textual QA datasets
by an ambiguous one and illustrate the question with an image (that depicts the entity).
-------
``ner``
-------
.. image:: ../source_docs/kilt2vqa_nlp.png
Slight misnomer, does a bit more than NER, i.e. dependency parsing.
Detected entities with valid type and dependency are replaced by a placeholder along with its syntactic children.
e.g. 'Who wrote *the opera **Carmen***?' → 'Who wrote `{mention}`'
Note that, not only the entity mention ('Carmen') but its syntactic children ('the opera')
are replaced by the placeholder.
-------
``ned``
-------
.. image:: ../source_docs/kilt2vqa_nlp.png
Disambiguate entity mentions using Wikipedia pages provided in KILT.
TriviaQA was originally framed as a reading-comprehension problem so the authors applied off-the-shelf NED and filtered
out pages that didn't contain the answer.
For every entity mention we compute Word Error Rate (WER, i.e. word-level Levenshtein distance) for every wikipedia title
and aliases. We save the minimal match and WER and recommend filtering out WER > 0.5
More data about these entities is gathered in `wiki.py`,
just run `kilt2vqa.py count_entities` first to save a dict with all disambiguated entities (outputs `entities.json`).
---------------------
``generate mentions``
---------------------
.. image:: ../source_docs/kilt2vqa_mentiong_gen.png
Generate ambiguous entity mentions that can be used to replace the placeholder
in the input question (you need to run `wiki.py data` first):
- if the gender is available (not animal sex):
- 'this man' or 'this woman' (respecting transgender)
- 'he/him/his' or 'she/her/hers' w.r.t mention dependency
- if human and occupation is available : 'this `{occupation}`' (respecting gender if relevant, e.g. for 'actress')
- else if non-human:
- if a taxon : 'this `{taxon rank}`' (e.g. 'species')
- else 'this `{class}`' (e.g. 'this tower')
---------------
``generate vq``
---------------
Make the VQA triple by choosing:
- uniformly a mention type and a mention from this mention type (generated in the previous step)
- the image with the best score (according to the heuristics computed in `wiki.py commons heuristics`).
Tries to use a unique image per entity.
---------------
``labelstudio``
---------------
First calls `generate vq` i.e. no need to call both!
The dataset is then converted to the Label Studio JSON format so you can annotate and correct the errors of the automatic pipeline (see [`ANNOTATION.md`](./ANNOTATION.md)).
------------
``download``
------------
Downloads images (set in `meerqat.data.wiki data entities`) from Wikimedia Commons using `meerqat.data.wiki.save_image`.
This might take a while (thus the sharding options), any help/advice is appreciated :)
==============
For ``docopt``
==============
Usage:
kilt2vqa.py ner <subset> [--disable_caching]
kilt2vqa.py ned <subset> [--map_kwargs=<path> --disable_caching]
kilt2vqa.py generate mentions <subset> [--threshold=<threshold> --disable_caching]
kilt2vqa.py generate vq <subset> [--image_width=<n> --map_kwargs=<path> --disable_caching <categories_to_exclude>...]
kilt2vqa.py count_entities <subset> [--threshold=<threshold> --map_kwargs=<path> --disable_caching]
kilt2vqa.py labelstudio <subset> [--image_width=<n> --alternative_images=<n> --disable_caching <categories_to_exclude>...]
kilt2vqa.py download <subset> [--image_width=<n> --map_kwargs=<path> --disable_caching --num_shards=<n> --shard_index=<n>]
Options:
--threshold=<threshold> Threshold for Word Error Rate (WER, i.e. word-level Levenshtein distance)
to consider the entity disambiguated [default: 0.5].
--alternative_images=<n> Number of alternative images to provide [default: 8].
--image_width=<n> Desired thumbnail width in pixels for the image url. Defaults to full-size
--map_kwargs=<path> Path towards a JSON file containing key-words arguments for the dataset.map function (e.g. batch size)
--disable_caching Disables Dataset caching (useless when using save_to_disk), see datasets.set_caching_enabled()
--num_shards=<n> Shard the dataset in n parts when downloading images
--shard_index=<n> Index of the desired shard when downloading images (use along with --num_shards)
=========
Functions
=========
"""
import warnings
import json
import numpy as np
import pandas as pd
import random
import re
import spacy
try:
from spacy.gold import align
except ImportError as e:
warnings.warn(f"Got the following ImportError: {e}.\nTry using spacy==2.2.4")
from spacy.symbols import DATE, TIME, PERCENT, MONEY, QUANTITY, ORDINAL, CARDINAL
from spacy.symbols import dobj, nsubj, pobj, obj, nsubjpass, poss, obl, root
from docopt import docopt
from tqdm import tqdm
from tabulate import tabulate
import requests
from datasets import load_dataset, load_from_disk, set_caching_enabled
from .loading import map_kilt_triviaqa, DATA_ROOT_PATH
from .wiki import HUMAN, RESERVED_IMAGES, special_path_to_file_name, file_name_to_thumbnail, thumbnail_to_file_name, save_image
from .utils import md5
# spacy constants for NER
# entities with one of these labels are skipped by item2placeholder (never replaced by a placeholder)
INVALID_ENTITIES = {DATE, TIME, PERCENT, MONEY, QUANTITY, ORDINAL, CARDINAL}
# dependencies an entity token must have to be replaced by a placeholder (see item2placeholder)
# TODO check root and obj: in practice, it never happened on TriviaQA dev set
VALID_DEP = {dobj, nsubj, pobj, obj, nsubjpass, poss, obl, root}
# spacy constants for pronoun-mention generation
# these sets are compared with the 'dependency' value saved by item2placeholder (token.dep_),
# hence the conversion through spacy.symbols.NAMES (presumably symbol id -> name)
# subject dependencies -> 'he'/'she'
HE_SHE_DEP = {spacy.symbols.NAMES[dep] for dep in [nsubj, nsubjpass]}
# object dependencies -> 'him'/'her'
HIM_HER_DEP = {spacy.symbols.NAMES[dep] for dep in [dobj, obj, obl]}
# possessive dependency -> 'his'/'hers'
HIS_HERS_DEP = {spacy.symbols.NAMES[poss]}
# Wikidata constants for pronoun-mention generation (Wikidata QIDs)
# 'male' 'trans. male'
HE_GENDER = {'Q6581097', 'Q2449503'}
# 'female' 'trans. female'
SHE_GENDER = {'Q6581072', 'Q1052281'}
# 'intersex' 'non-binary' (no gendered mention is generated for these)
NA_GENDER = {'Q1097630', 'Q48270'}
# 'male' 'female' (animal sex: no man/woman nor pronoun mention is generated for these)
ANIMAL_SEX = {'Q44148', 'Q43445'}
# set random seed to get consistent random examples
# (both numpy and the standard library generators are used in this module)
np.random.seed(0)
random.seed(0)
def wer(a, b):
    """Compute Word Error Rate (word-level Levenshtein distance) using spacy

    The distance returned by ``align`` is normalised by the length
    of the longest sequence.

    Parameters
    ----------
    a, b: List[str]
        Tokenized (e.g. whitespace-split) strings to compare

    Returns
    -------
    wer: float
        ``align(a, b)[0] / max(len(a), len(b))``,
        or 0.0 if both sequences are empty
    """
    length = max(len(a), len(b))
    # two empty sequences are identical: avoids a ZeroDivisionError
    if length == 0:
        return 0.0
    return align(a, b)[0] / length
def item2placeholder(item, model=None):
    """Make input question suitable for VQA
    by replacing an explicit entity mention and its syntactic children by a placeholder.
    e.g. 'Who wrote the opera Carmen?' -> 'Who wrote {mention}'
    (the text that follows the replaced span, if any, is kept as is).
    Note that, not only the entity mention ('Carmen') but its syntactic children ('the opera')
    are replaced by the placeholder.

    The final goal is to find an image that represents the entity
    and fill the placeholder with an appropriate (ambiguous) mention (e.g. 'this opera', 'it')

    Parameters
    ----------
    item: dict
        original question should be in 'input' key
    model: spacy.lang.en.English
        Full spacy pipeline, we use both NER and dependency parsing.
        Although it defaults to None, it is called on the question so it must be provided.

    Returns
    -------
    item: dict
        same as input with extra keys:
        - "placeholder": List[dict]
          One dict like {"input": str, "entity": dict, "dependency": str}
          per (non-included) replaced span. Empty if nothing could be replaced.
        - "spacy_input": dict
          Original input, POS and NER-tagged with spacy in dict format
          (using Doc.to_json())

    Usage
    -----
    hugging_face_dataset.map(item2placeholder, fn_kwargs={"model": spacy_model})
    """
    item['placeholder'] = []
    question = model(item['input'])
    item['spacy_input'] = question.to_json()
    # filter questions without entities
    if not question.ents:
        return item

    # maps the (start, end) token indices (both inclusive) of the span to replace
    # to the entity and the token with a valid dependency
    potential_questions = {}
    for e in question.ents:
        # filter invalid entities
        if e.label in INVALID_ENTITIES:
            continue
        for token in e:
            # filter invalid dependencies
            if token.dep not in VALID_DEP:
                continue
            # get leftmost and rightmost syntactic children
            # min/max hack is in case the "valid dependency token" is not the head in the entity span
            # e.g. "Who wrote the poem ‘The Lady of the Lake’?", "Lake" is pobj but a leaf
            start = min(token.left_edge.i, e.start)
            end = max(token.right_edge.i, e.end - 1)
            potential_questions[(start, end)] = (e, token)

    # keep only the biggest span for overlapping mentions
    for (start, end), (e, token) in potential_questions.items():
        included = any(
            # included from the left or from the right
            (start >= other_start and end < other_end) or (start > other_start and end <= other_end)
            for other_start, other_end in potential_questions
        )
        if included:
            continue
        # replace entity and its syntactic children by a placeholder
        # the whitespace that follows the placeholder is the one of the last replaced token, i.e. question[end]
        # (token.right_edge may lie inside the replaced span when `end` comes from the entity boundary)
        placeholder = question[:start].text_with_ws + "{mention}" + question[end].whitespace_ + question[end + 1:].text
        item['placeholder'].append({'input': placeholder,
                                    'entity': e.as_doc().to_json(),
                                    'dependency': token.dep_})
    return item
def stats(kilt_subset):
    """Count placeholders, visual questions and dependencies of a dataset.

    Parameters
    ----------
    kilt_subset: iterable of dict (must support len)
        Every item should have a "placeholder" key and, optionally, a "vq" key

    Returns
    -------
    table: str
        One-row table in LaTeX format with the following columns:
        "placeholders", "originals" (number of items), "distinct source"
        (number of items with at least one placeholder), "vqs"
        and one column per dependency found in the placeholders
    """
    counts = {
        "placeholders": 0,
        "originals": len(kilt_subset),
        "distinct source": 0,
        "vqs": 0
    }
    for example in kilt_subset:
        n_placeholders = len(example["placeholder"])
        counts["placeholders"] += n_placeholders
        # an item counts once, no matter how many placeholders it has
        if n_placeholders > 0:
            counts["distinct source"] += 1
        counts["vqs"] += len(example.get("vq", []))
        for placeholder in example['placeholder']:
            dependency = placeholder['dependency']
            counts[dependency] = counts.get(dependency, 0) + 1
    return tabulate([counts], headers="keys", tablefmt='latex')
def stringify(kilt_subset, field="placeholder", include_answer=True, include_provenance=True, include_dep=False):
    """Format generated questions (and pruned ones) in a human-readable way.

    Parameters
    ----------
    kilt_subset: iterable of dict
    field: str, optional
        Key of the item that holds the generated questions. Defaults to "placeholder"
    include_answer: bool, optional
        Whether to print the first answer of the item. Defaults to True
    include_provenance: bool, optional
        Whether to print the title of the first provenance of the item (if any). Defaults to True
    include_dep: bool, optional
        Whether to print the dependency next to each generated question. Defaults to False

    Returns
    -------
    results: str
        Items with at least one generated question, separated by two blank lines
    invalid: str
        Text of the items without any generated question, one per line
    """
    results = []
    invalid = []
    for item in kilt_subset:
        if not item[field]:
            spacy_input = item['spacy_input']
            # item2placeholder stores Doc.to_json() i.e. a dict: keep only the text of the question
            # (joining dicts would raise a TypeError)
            if isinstance(spacy_input, dict):
                spacy_input = spacy_input.get('text', spacy_input)
            invalid.append(str(spacy_input))
            continue
        result = [f"Q: {item['input']}"]
        for vq in item[field]:
            result.append(f"-> {vq['input']} {vq['dependency'] if include_dep else ''}")
        if include_answer:
            result.append(f"A: {item['output']['answer'][0]}")
        if include_provenance and item['output']['provenance']:
            result.append(f"\t{item['output']['provenance'][0]['title'][0]}")
        results.append("\n".join(result))
    return "\n\n\n".join(results), "\n".join(invalid)
def ner(subset):
    """
    1st step: Named Entity Recognition (NER):
    Goes through the kilt subset and apply 'item2placeholder' function (see its docstring)
    Save the resulting dataset to f"{DATA_ROOT_PATH}/meerqat_{subset}"
    Finally prints the questions generated (and pruned) out of a few random examples.

    Parameters
    ----------
    subset: str
        Name of the subset to process, used to index the output of map_kilt_triviaqa
    """
    # spacy pipeline first, then the data
    nlp = spacy.load("en_core_web_lg")
    kilt_subset = map_kilt_triviaqa()[subset]

    # replace explicit mentions by placeholders in every question
    kilt_subset = kilt_subset.map(item2placeholder, fn_kwargs={"model": nlp})
    print(stats(kilt_subset))

    # save data
    output_path = DATA_ROOT_PATH / f"meerqat_{subset}"
    kilt_subset.save_to_disk(output_path)
    print(f"Successfully saved output to '{output_path}'")

    # show N random examples
    N = 100
    shuffled = np.arange(kilt_subset.shape[0])
    np.random.shuffle(shuffled)
    randoms = []
    for index in shuffled[:N]:
        # .item() converts the numpy integer to a python int before indexing the dataset
        randoms.append(kilt_subset[index.item()])
    results, invalid = stringify(randoms)
    print(f"\nGenerated questions out of {N} random examples:\n")
    print(results)
    print(f"\nPruned questions out of {N} random examples:\n")
    print(invalid)
def disambiguate(item, wikipedia, wikipedia_ids, pedia_index):
    """Go through candidate pages from TriviaQA and compute WER between entity mention and Wikipedia title/aliases
    One should filter entities with a minimal WER of 0.5 (see 'wer' key)

    Parameters
    ----------
    item: dict
        Should have a "placeholder" key (see item2placeholder)
        and candidate pages in item['output']['provenance']
    wikipedia: Dataset
        Indexed with the positions stored in pedia_index,
        items should have 'wikidata_info' and 'wikipedia_id' keys
    wikipedia_ids: np.ndarray
        Wikipedia ids, aligned with wikipedia
    pedia_index: dict
        Cache that maps a Wikipedia id to its position in wikipedia, updated in-place

    Returns
    -------
    item: dict
        Same as input, where the "entity" of every placeholder has extra keys
        'wikidata_info', 'wikipedia_id' (both from the best-matching page) and 'wer'
    """
    # nothing to disambiguate: do not query wikipedia at all
    if not item["placeholder"]:
        return item

    # process each wikipedia article only once (answer might come from different paragraphs but it's irrelevant for this)
    # the parenthesized part of the title is removed, e.g. "(opera)"
    provenances = {provenance['wikipedia_id'][0]: re.sub(r"\(.+\)", "", provenance['title'][0].lower()).strip() for
                   provenance in item['output']['provenance']}

    # title and aliases do not depend on the mention: gather (and tokenize) them once per article
    tokenized_aliases = {}
    for wid, title in provenances.items():
        # explicit membership test: dict.setdefault would run the costly np.where even if wid is already cached
        if wid not in pedia_index:
            pedia_index[wid] = np.where(wikipedia_ids == wid)[0].item()
        wiki_item = wikipedia[pedia_index[wid]]
        aliases = {title}
        # get aliases from wikipedia
        aliases.update({alias.lower().strip() for alias in wiki_item['wikidata_info']['aliases']['alias']})
        tokenized_aliases[wid] = [alias.split() for alias in aliases]

    for vq in item["placeholder"]:
        ent = vq["entity"]['text'].lower().strip().split()
        # compute WER and keep minimal for all aliases
        wers = {wid: min(wer(ent, alias) for alias in aliases) for wid, aliases in tokenized_aliases.items()}
        # keep minimal WER for all candidate articles
        best_provenance = min(wers, key=wers.get)
        best_wer = wers[best_provenance]
        wiki_item = wikipedia[pedia_index[best_provenance]]
        vq["entity"]['wikidata_info'] = wiki_item['wikidata_info']
        vq["entity"]['wikipedia_id'] = wiki_item['wikipedia_id']
        vq["entity"]["wer"] = best_wer
    return item
def ned(subset, **map_kwargs):
    """
    2nd step: Named Entity Disambiguation (NED) using TriviaQA provided list
    Assumes that you already ran NER and loads dataset from f"{DATA_ROOT_PATH}/meerqat_{subset}"
    and wikipedia from DATA_ROOT_PATH
    The dataset is saved back to the path it was loaded from.

    Parameters
    ----------
    subset: str
        Name of the subset to process
    **map_kwargs:
        Passed to the map method of the dataset
    """
    # the dataset is read from and written back to the same location
    dataset_path = DATA_ROOT_PATH / f"meerqat_{subset}"
    dataset = load_from_disk(dataset_path)
    wikipedia = load_dataset('kilt_wikipedia', cache_dir=DATA_ROOT_PATH)['full']

    # pedia_index starts empty and is filled by disambiguate
    fn_kwargs = dict(
        wikipedia=wikipedia,
        wikipedia_ids=np.array(wikipedia["wikipedia_id"]),
        pedia_index={},
    )
    # go through dataset
    dataset = dataset.map(disambiguate, fn_kwargs=fn_kwargs, **map_kwargs)

    # save data
    dataset.save_to_disk(dataset_path)
    print(f"Successfully saved output to '{dataset_path}'")
def count_entities(subset, wer_threshold=0.5):
    """Save a dict with all disambiguated entities to "entities.json" (in the dataset directory).

    Parameters
    ----------
    subset: str
        Name of the subset to process (the dataset is loaded from f"{DATA_ROOT_PATH}/meerqat_{subset}")
    wer_threshold: float, optional
        Entities with a WER above this threshold are not considered disambiguated.
        Defaults to 0.5
    """
    path = DATA_ROOT_PATH / f"meerqat_{subset}"
    dataset = load_from_disk(path)

    # wikidata id -> {"wikipedia_id": ..., "n_questions": int}
    entities = {}
    total = 0
    disambiguated = 0
    for item in tqdm(dataset):
        for vq in item['placeholder']:
            total += 1
            entity = vq['entity']
            if entity['wer'] > wer_threshold:
                continue
            disambiguated += 1
            record = entities.setdefault(entity['wikidata_info']['wikidata_id'], {})
            record["wikipedia_id"] = entity["wikipedia_id"]
            record["n_questions"] = record.get("n_questions", 0) + 1

    output_path = path / "entities.json"
    with open(output_path, 'w') as file:
        json.dump(entities, file)
    print(f"\nSuccessfully saved output to {output_path}")
    print(f"Disambiguated {disambiguated} questions ({len(entities)} unique entities) "
          f"out of {total} questions with a threshold of {wer_threshold}")
    # distribution of the number of questions per entity
    questions_per_entity = [entity["n_questions"] for entity in entities.values()]
    print(pd.DataFrame(questions_per_entity).describe())
def _gendered_label(node, gender, feminine_labels):
    """Returns the feminine label of node (occupation or class) if any and if gender is feminine,
    else its default label (most names in English don't have genders)"""
    feminine_label = feminine_labels.get(node['value'])
    if feminine_label and gender in SHE_GENDER:
        return feminine_label
    return node['label']['value']


def generate_mention(item, entities, wer_threshold=0.5, feminine_labels=None):
    """Generate ambiguous mentions for the entity of every placeholder of the item:
    - if the gender is available (not animal sex):
        - 'this man' or 'this woman'
        - a pronoun w.r.t. the dependency of the mention
    - if human and occupation is available: 'this {occupation}' (respecting gender if relevant)
    - else if non-human:
        - if a taxon: 'this {taxon rank}'
        - else 'this {class}' (aka instanceof)

    Parameters
    ----------
    item: dict
        Should have a "placeholder" key, where every placeholder has
        'entity' (disambiguated, see disambiguate) and 'dependency' keys
    entities: dict
        Maps Wikidata ids to the attributes of the entity (see wiki.py)
    wer_threshold: float, optional
        No mention is generated for entities with a WER above this threshold.
        Defaults to 0.5
    feminine_labels: dict, optional
        Maps the 'value' of an occupation or class to its feminine label.
        Defaults to no feminine labels.

    Returns
    -------
    item: dict
        Same as input where every placeholder has an extra 'ambiguous_mentions' key:
        dict with keys "pronouns", "man_woman", "occupation" and "instanceof" (List[str], maybe empty)
    """
    # None instead of a mutable default argument
    if feminine_labels is None:
        feminine_labels = {}

    for vq in item["placeholder"]:
        entity = vq['entity']
        ambiguous_mentions = {
            "pronouns": [],
            "man_woman": [],
            "occupation": [],
            "instanceof": []
        }

        # filter ambiguous entities and skip filtered entities
        qid = entity['wikidata_info']['wikidata_id']
        entity_data = entities.get(qid)
        if entity['wer'] > wer_threshold or not entity_data:
            vq['ambiguous_mentions'] = ambiguous_mentions
            continue

        dependency = vq['dependency']
        gender = entity_data.get('gender', {}).get('value')
        # keep only what follows the last "/" (the constants of this module are bare Wikidata ids)
        gender = gender.split("/")[-1] if gender else gender
        human = HUMAN in entity_data.get('instanceof', {})
        taxon_rankLabel = entity_data.get('taxon_rankLabel', {}).get('value')

        # man_woman and pronouns
        if gender not in ANIMAL_SEX:
            # man_woman
            if gender in HE_GENDER:
                ambiguous_mentions["man_woman"].append("this man")
            elif gender in SHE_GENDER:
                ambiguous_mentions["man_woman"].append("this woman")
            elif gender in NA_GENDER or not gender:
                pass
            else:
                warnings.warn(f"No case were set for this gender: '{gender}'")

            # pronouns
            if dependency in HE_SHE_DEP:
                if gender in HE_GENDER:
                    ambiguous_mentions["pronouns"].append("he")
                elif gender in SHE_GENDER:
                    ambiguous_mentions["pronouns"].append("she")
            elif dependency in HIM_HER_DEP:
                if gender in HE_GENDER:
                    ambiguous_mentions["pronouns"].append("him")
                elif gender in SHE_GENDER:
                    ambiguous_mentions["pronouns"].append("her")
            elif dependency in HIS_HERS_DEP:
                if gender in HE_GENDER:
                    ambiguous_mentions["pronouns"].append("his")
                elif gender in SHE_GENDER:
                    # NOTE(review): 'hers' is kept as is but a possessive determiner would be 'her' - confirm
                    ambiguous_mentions["pronouns"].append("hers")
            else:
                warnings.warn(f"No case were set for this dependency: '{dependency}'")

        # occupation
        if entity_data.get('occupation') and human:
            for occupation in entity_data['occupation'].values():
                occupation_label = _gendered_label(occupation, gender, feminine_labels)
                ambiguous_mentions['occupation'].append(f"this {occupation_label}")
        # taxon rank (e.g. "species") or class (aka instanceof)
        elif not human:
            # taxon rank
            if taxon_rankLabel:
                ambiguous_mentions['instanceof'].append(f"this {taxon_rankLabel}")
            # class (instanceof)
            else:
                for instanceof in entity_data.get('instanceof', {}).values():
                    instanceof_label = _gendered_label(instanceof, gender, feminine_labels)
                    ambiguous_mentions['instanceof'].append(f"this {instanceof_label}")

        vq['ambiguous_mentions'] = ambiguous_mentions
    return item
def generate_mentions(subset, wer_threshold=0.5, **map_kwargs):
    """3rd step: generate ambiguous mentions given entities attributes (run `wiki.py data` first)

    Loads the dataset from f"{DATA_ROOT_PATH}/meerqat_{subset}" and saves it back there.
    Entities are loaded from "entities.json" and, if the file exists,
    feminine labels from "feminine_labels.json" (both in the dataset directory).

    Parameters
    ----------
    subset: str
        Name of the subset to process
    wer_threshold: float, optional
        See generate_mention. Defaults to 0.5
    **map_kwargs:
        Passed to the map method of the dataset
    """
    # load data
    dataset_path = DATA_ROOT_PATH / f"meerqat_{subset}"
    dataset = load_from_disk(dataset_path)
    with open(dataset_path / "entities.json", 'r') as file:
        entities = json.load(file)
    feminine_labels_path = dataset_path / "feminine_labels.json"
    if feminine_labels_path.exists():
        with open(feminine_labels_path, "r") as file:
            feminine_labels = json.load(file)
    else:
        feminine_labels = {}

    fn_kwargs = {
        "entities": entities,
        "wer_threshold": wer_threshold,
        "feminine_labels": feminine_labels
    }

    # go through dataset
    dataset = dataset.map(generate_mention, fn_kwargs=fn_kwargs, **map_kwargs)

    # save data
    dataset.save_to_disk(dataset_path)
    print(f"Successfully saved output to '{dataset_path}'")

    total, with_mention = 0, 0
    for item in dataset:
        for vq in item["placeholder"]:
            total += 1
            # look at the mentions (values of the dict), not at the names of the mention types (keys),
            # which are never empty
            if any(vq['ambiguous_mentions'].values()):
                with_mention += 1
    # avoid dividing by zero when there is not a single placeholder
    if total:
        print(f"{with_mention*100/total:.2f}% of the visual questions have at least one ambiguous mention")
    else:
        print("No visual question found: could not compute the proportion with an ambiguous mention")
def generate_vq(item, entities, image_width=512):
    """
    Generate a image (url), question, answer triple by choosing:
        - uniformly a mention type and a mention from this mention type
        - the image with the best score (with its title sorted last in "titles").
          Tries to use a unique image per entity.

    Parameters
    ----------
    item: Dataset item
    entities: dict (see wiki.py)
        The "titles" of the entities are modified in-place:
        the title of the chosen image is removed unless it is the last one.
    image_width: int, optional
        desired thumbnail width in pixels for the image url
        Defaults to 512

    Returns
    -------
    item: Dataset item
        with a new 'vq' key (List[dict])
        Placeholders without any ambiguous mention or without image for their entity are skipped.
    """
    item['vq'] = []
    kilt_id = item['id']
    for placeholder in item['placeholder']:
        # keep only the mention types with at least one mention
        mention_types = [mention_type for mention_type in placeholder.get('ambiguous_mentions', {}).values() if mention_type]
        if not mention_types:
            continue
        wikidata_info = placeholder['entity']['wikidata_info']
        qid = wikidata_info['wikidata_id']
        description = wikidata_info['description']
        # entity might have been filtered before-hand -> get qid instead of "[qid]"
        titles = entities.get(qid, {}).get("titles")
        if not titles:
            continue
        # try to use unique images per entity -> pop titles
        # note we assume that the images are sorted in ascending order w.r.t. their score
        # the last title is never popped so that the entity always has an image
        title = titles.pop() if len(titles) > 1 else titles[0]
        url = file_name_to_thumbnail(title[len("File:"):], image_width=image_width)

        # choose mention type (e.g. pronoun or occupation) uniformly from all types (that are not empty)
        mention_type = random.choice(mention_types)
        # choose mention uniformly from all mentions in this type (e.g. Barack Obama is a politician and a statesperson)
        mention = random.choice(mention_type)

        # literal replacement instead of str.format: the question is not escaped
        # so any other curly bracket in it would make str.format fail
        inp = placeholder['input'].replace("{mention}", mention)
        meerqat_id = md5("".join((kilt_id, qid, inp, url)))
        vq = {'input': inp,
              'url': url,
              'wikidata_id': qid,
              'meerqat_id': meerqat_id,
              'mentions': [mention for mention_type in mention_types for mention in mention_type],
              'description': description
              }
        item['vq'].append(vq)
    return item
def generate_vqs(subset, exclude_categories=set(), image_width=512, **map_kwargs):
    """
    Filter and sort the images of every entity then generate the visual questions (see generate_vq).
    The dataset is loaded from f"{DATA_ROOT_PATH}/meerqat_{subset}" and saved back there.

    Parameters
    ----------
    subset: str
        Name of the subset to load (e.g. validation_triviaqa)
    exclude_categories: set, optional
        Exclude image where these keywords are included in one of its categories
        e.g. {'cosplay'} might save you some trouble with GDPR
        Defaults to empty set (i.e. keep all)
    image_width: int, optional
        See generate_vq. Defaults to 512
    **map_kwargs:
        Passed to the map method of the dataset

    Returns
    -------
    dataset: Dataset
        With the 'vq' key added by generate_vq
    entities: dict
        Loaded from "entities.json", with sorted "titles" added to the entities that have images
    """
    # load data
    print("loading data...")
    dataset_path = DATA_ROOT_PATH / f"meerqat_{subset}"
    dataset = load_from_disk(dataset_path)
    with open(dataset_path / "entities.json", 'r') as file:
        entities = json.load(file)

    # sort images and remove forbidden ones (move to wiki.py if it's too slow?)
    for entity in tqdm(entities.values(), desc="Processing images"):
        images = entity.get("images")
        if not images:
            continue

        # remove reserved images (e.g. illustrative_image) from the candidates
        for reserved_image_key in RESERVED_IMAGES:
            reserved_paths = entity.get(reserved_image_key, {})
            for reserved_image in map(special_path_to_file_name, reserved_paths):
                images.pop(reserved_image, None)

        # Exclude image where these keywords are included in one of its categories
        if exclude_categories:
            # gather titles first: images should not be modified while iterating over it
            forbidden_titles = []
            for title, image in images.items():
                image_categories = image.get("categories")
                if image_categories is None:
                    continue
                lowered_categories = (category.lower() for category in image_categories)
                if any(keyword in category for category in lowered_categories for keyword in exclude_categories):
                    forbidden_titles.append(title)
            for title in forbidden_titles:
                images.pop(title)

        # sort images w.r.t. their score in ASCENDING order (allows simpler use of pop)
        # the score of an image is the size of its 'heuristics'
        def score(title):
            return len(images[title]['heuristics'])
        entity["titles"] = sorted(images, key=score)

    # go through dataset
    fn_kwargs = {"entities": entities, "image_width": image_width}
    dataset = dataset.map(generate_vq, fn_kwargs=fn_kwargs, **map_kwargs)

    # save data
    dataset.save_to_disk(dataset_path)
    print(f"Successfully saved output to '{dataset_path}'")

    print(stats(dataset))
    return dataset, entities
def labelstudio(*args, image_width=512, alternative_images=8, **kwargs):
    """run generate_vqs and convert dataset to the Label Studio JSON format

    The output is saved to "labelstudio.json", in the directory of the dataset.

    Parameters
    ----------
    *args, **kwargs:
        Passed to generate_vqs, the name of the subset should come first (or with the `subset` key-word)
    image_width: int, optional
        Desired thumbnail width in pixels for the image urls. Defaults to 512
    alternative_images: int, optional
        Number of alternative images to provide, missing ones are filled with empty strings.
        Defaults to 8
    """
    print("Generating visual questions...")
    dataset, entities = generate_vqs(*args, image_width=image_width, **kwargs)
    # name of the subset, as passed to generate_vqs
    # (do not rely on the `subset` global, which is only defined when this module is run as a script)
    subset = args[0] if args else kwargs["subset"]

    # convert dataset to the Label Studio JSON format
    ls = {}
    i = 0
    for item in tqdm(dataset, desc="Converting to Label Studio"):
        for vq in item["vq"]:
            # make some names more explicit and copy some stuff from original QA
            vq["image"] = vq.pop('url')
            title = thumbnail_to_file_name(vq["image"]).replace('_', ' ')
            # remove extension
            caption = re.match(r"(.+)\.\w+", title)
            caption = caption.group(1) if caption is not None else title
            vq["image_caption"] = caption
            vq['question'] = item['input']
            vq["vq"] = vq.pop('input')
            vq['answer'] = item['output']['answer'][0]
            vq['mentions'] = ", ".join(vq['mentions'])
            qid = vq['wikidata_id']
            entity = entities[qid]
            vq['entityLabel'] = entity.get("entityLabel", {}).get("value", "")
            vq['entity_image'] = entity.get('reference_image', '')

            # gather alternative images to vq["image"]
            # remember images are sorted in ASC order wrt their score, thus the [::-1] to reverse the list
            # reversing before truncating also works with alternative_images=0 (unlike [-0:], which keeps everything)
            alternatives = entity["titles"][::-1][:alternative_images]
            for j, title in enumerate(alternatives):
                # remove "File:" prefix and extension
                caption = re.match(r"File:(.+)\.\w+", title)
                caption = caption.group(1) if caption is not None else title
                # title to url
                url = file_name_to_thumbnail(title[len("File:"):], image_width=image_width)
                vq[f"altimage{j}"] = url
                vq[f"altimage{j}caption"] = caption
            # no missing values: use empty string instead
            for j in range(len(alternatives), alternative_images):
                vq[f"altimage{j}"] = ""
                vq[f"altimage{j}caption"] = ""
            ls[str(i)] = {"data": vq}
            i += 1

    # save output
    out_path = DATA_ROOT_PATH / f"meerqat_{subset}" / "labelstudio.json"
    with open(out_path, 'w') as file:
        json.dump(ls, file)
    print(f"Successfully saved output to '{out_path}'")
def download_image(item, session, image_width=512):
    """Download the image of the item using save_image.

    Parameters
    ----------
    item: dict
        Should have a 'url' key
    session:
        Passed to save_image (a requests.Session in download_images)
    image_width: int, optional
        Width of the thumbnail to download. Defaults to 512

    Returns
    -------
    item: dict
        Same as input with an extra 'image' key:
        name of the path returned by save_image, or None if save_image returned None
    """
    # rebuild the url of the thumbnail so that it has the desired width
    thumbnail = file_name_to_thumbnail(thumbnail_to_file_name(item['url']), image_width=image_width)
    file_path = save_image(thumbnail, session)
    item['image'] = None if file_path is None else file_path.name
    return item
def download_images(subset, fn_kwargs, num_shards=None, shard_index=None, **map_kwargs):
    """Download the images of the dataset (see download_image), eventually of a single shard of it.

    The whole dataset is saved back to f"{DATA_ROOT_PATH}/meerqat_{subset}",
    a shard is saved in the "shard_{shard_index}_of_{num_shards}" sub-directory.

    Parameters
    ----------
    subset: str
        Name of the subset to process
    fn_kwargs: dict
        Key-word arguments of download_image (e.g. image_width), left untouched.
        The session is added by this function.
    num_shards: int, optional
        Shard the dataset in n parts. Defaults to process the whole dataset.
    shard_index: int, optional
        Index of the desired shard (use along with num_shards)
    **map_kwargs:
        Passed to the map method of the dataset
    """
    print("loading data...")
    dataset_path = DATA_ROOT_PATH / f"meerqat_{subset}"
    dataset = load_from_disk(dataset_path)
    if num_shards is not None:
        dataset = dataset.shard(num_shards, shard_index)

    # the session is closed once all images are downloaded
    with requests.Session() as session:
        # copy fn_kwargs instead of updating the dict of the caller
        dataset = dataset.map(download_image, fn_kwargs={**fn_kwargs, "session": session}, **map_kwargs)

    if num_shards is None:
        dataset.save_to_disk(dataset_path)
    else:
        dataset.save_to_disk(dataset_path/f"shard_{shard_index}_of_{num_shards}")
if __name__ == '__main__':
    def _optional_int(value):
        """Cast the value of an option to int, None (option not provided) is kept as is"""
        return int(value) if value is not None else None

    # parse arguments
    args = docopt(__doc__)
    subset = args['<subset>']

    # key-word arguments for the map method of the dataset, if any
    map_kwargs = {}
    map_kwargs_path = args['--map_kwargs']
    if map_kwargs_path:
        with open(map_kwargs_path, 'r') as file:
            map_kwargs = json.load(file)

    set_caching_enabled(not args['--disable_caching'])

    # run the command
    if args['ner']:
        ner(subset)
    elif args['ned']:
        ned(subset, **map_kwargs)
    elif args['count_entities']:
        count_entities(subset, wer_threshold=float(args['--threshold']))
    elif args['generate']:
        if args['mentions']:
            generate_mentions(subset, wer_threshold=float(args['--threshold']), **map_kwargs)
        elif args['vq']:
            exclude_categories = set(args['<categories_to_exclude>'])
            image_width = _optional_int(args['--image_width'])
            generate_vqs(subset, exclude_categories, image_width=image_width, **map_kwargs)
    elif args['labelstudio']:
        exclude_categories = set(args['<categories_to_exclude>'])
        alternative_images = int(args['--alternative_images'])
        image_width = _optional_int(args['--image_width'])
        labelstudio(subset,
                    exclude_categories=exclude_categories,
                    alternative_images=alternative_images,
                    image_width=image_width)
    elif args['download']:
        image_width = _optional_int(args['--image_width'])
        num_shards = _optional_int(args['--num_shards'])
        shard_index = _optional_int(args['--shard_index'])
        download_images(subset,
                        fn_kwargs={"image_width": image_width},
                        num_shards=num_shards,
                        shard_index=shard_index,
                        **map_kwargs)