# coding: utf-8
"""
WIT for MICT
============
Generates the WIT subset for the Multimodal Inverse Cloze Task (MICT) as described in the ECIR-2023 paper:

- English-only subset
- images paired with the sections
- filtering out images with unsupported formats (e.g. .svg) or that could not be downloaded (e.g. HTTP 404)
- splitting into train/validation/test without overlap between the articles
- splitting sections into sentences (``meerqat.data.loading sentences``)
- removing sections with a single sentence (do this yourself afterwards, see the sketch below)
- resizing images to a maximum height or width of 512 pixels using ``meerqat.image.resize`` (do this yourself afterwards)

You should end up with:

- 877,635 pairs in train
- 48,271 pairs in validation
- 48,815 pairs in test
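
For the "remove single-sentence sections" step, here is a minimal sketch. It assumes the
sentence-splitting step stores the list of sentences of each pair in a column hypothetically
named ``sentences`` (check the actual column name produced by ``meerqat.data.loading sentences``)::

    from datasets import load_from_disk

    # load the DatasetDict saved by this script (after sentence splitting)
    dataset = load_from_disk('path/to/wit_for_mict')
    # keep only the pairs whose section contains at least two sentences
    dataset = dataset.filter(lambda item: len(item['sentences']) > 1)
    dataset.save_to_disk('path/to/wit_for_mict_multi_sentence')
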
What you should have first
==========================
The WIT dataset, downloaded from https://github.com/google-research-datasets/wit
(if you happen to have access to Jean Zay, it is already available at ``$DSDIR/WIT`` in the right format)::

    $ tree WIT
    WIT/
    ├── train
    │   ├── 00
    │   │   ├── 000004379cfea6d71f7c47180c2163ee40887b7b23798535435d9b2c0065cea5.png
    │   │   ├── 000004528fa952ab9e2212ff7c749dfb1f28eb0fae2f45bec768e3ba72265420.jpg
    │   │   ├── ...
    │   │   └── 00ffff77789c938b5c2ce004d09246d1d54ef5d325d831adf3611413794d757f.jpg
    │   ├── 01
    │   ├── ...
    │   └── ff
    ├── train_images.tsv
    ├── wit_v1.train.all-00000-of-00010.tsv
    ├── wit_v1.train.all-00001-of-00010.tsv
    ├── wit_v1.train.all-00002-of-00010.tsv
    ├── wit_v1.train.all-00003-of-00010.tsv
    ├── wit_v1.train.all-00004-of-00010.tsv
    ├── wit_v1.train.all-00005-of-00010.tsv
    ├── wit_v1.train.all-00006-of-00010.tsv
    ├── wit_v1.train.all-00007-of-00010.tsv
    ├── wit_v1.train.all-00008-of-00010.tsv
    └── wit_v1.train.all-00009-of-00010.tsv

Instructions for train_images.tsv
---------------------------------
The images from WIT are stored in the "train" directory with the following naming convention:
"train/<xy>/<hash>.<ext>" where

- <hash> is the SHA256 hash of the image's URL
- <xy> are the first two characters of the hash (which means there are 256 subfolders named "00" to "ff")
- <ext> is the extension of the image.

The file "train_images.tsv" lists all the image URLs with their download status
("True" if the image could be downloaded, "False" otherwise) and the corresponding path.
Once you have built this mapping, you should add it to the dataset yourself.

Sample from "train_images.tsv"::

    url downloaded path
    http://upload.wikimedia.org/wikipedia/ca/d/d4/Trobadores.jpeg True train/95/953feec3651efda25c166841ec8c0cd8d2064bf59f668c8dcb62dc823963a385.jpg
    http://upload.wikimedia.org/wikipedia/commons/0/00/%2703-%2705_Pontiac_Montana_Taxi.jpg True train/35/35bcbf0f09424126932707a702b152fac7ebd9c932a877a3f2515d9fe67bb44d.jpg
    http://upload.wikimedia.org/wikipedia/commons/0/00/%2755_Singer_4ADT_Roadster_%28Hudson%29.JPG True train/dd/dd10ea054385d8fac82a7bca15202434b7ce0facb01519021980ba07c5e6f626.jpg
    http://upload.wikimedia.org/wikipedia/commons/0/00/%2768_Chevrolet_Biscayne_Coupe_%28Centropolis_Laval_%2710%29.jpg True train/44/44a11a487b09c8118e1066491880ad7045513379b5c16cdc9460321db113ad2d.jpg
    http://upload.wikimedia.org/wikipedia/commons/0/00/%2783_Buick_Century_Sedan.JPG False HTTP Error 404: Not Found
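
For reference, a minimal sketch of how an image path could be recomputed from its URL. It assumes
the hash is the hex SHA256 digest of the URL string exactly as it appears in the TSV; note that the
stored extension may differ from the URL's (e.g. the .jpeg above was saved as .jpg), so it is passed
as a parameter here. In practice you should simply use the "path" column of "train_images.tsv",
which is what this script does::

    import hashlib

    def url_to_path(url, ext='jpg'):
        # hypothetical helper, not part of this module
        digest = hashlib.sha256(url.encode('utf-8')).hexdigest()
        return f"train/{digest[:2]}/{digest}.{ext}"
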
Docopt
======
Usage:
    wit.py ict <root_path> <output_path> [--split]
    wit.py caption <root_path> <output_path> [--split --dedup]

Options:
    --split     Whether to split into train/validation/test sets
    --dedup     Whether to de-duplicate identical caption-image pairs
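
For example, to build the MICT subset with the train/validation/test split
(paths are placeholders for wherever you stored WIT and want the output)::

    $ python wit.py ict /path/to/WIT /path/to/output/wit_for_mict --split
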
"""
import json
from tqdm import tqdm
from docopt import docopt
import random
import pandas as pd
from datasets import Dataset, DatasetDict, concatenate_datasets
from pathlib import Path
random.seed(0)
VALID_ENCODING = {'jpeg', 'jpg', 'png'}


def check_encoding(url):
    """Return True if the URL points to a supported image format (by extension)."""
    return url.split('.')[-1].lower() in VALID_ENCODING


def fill_wit_for_mict(wit, wit_for_mict, downloaded_images):
    """Fill wit_for_mict (keyed by article title) with the main image and the image-section pairs of each article."""
    for _, row in tqdm(wit.iterrows(), total=len(wit)):
        wit_for_mict.setdefault(row.page_title, {})
        image_path = downloaded_images[row.image_url]
        if row.is_main_image:
            wit_for_mict[row.page_title]['context_image_url'] = row.image_url
            wit_for_mict[row.page_title]['context_image_path'] = image_path
            if not isinstance(row.context_section_description, str):
                continue
            # not used in ICT
            wit_for_mict[row.page_title]['context_text'] = row.context_section_description
        else:
            if not isinstance(row.context_section_description, str):
                continue
            wit_for_mict[row.page_title].setdefault('sections', {})
            key = str(hash((row.context_section_description, row.image_url)))
            # images paired with the sections
            wit_for_mict[row.page_title]['sections'][key] = {
                "text": row.context_section_description,
                "image_url": row.image_url,
                "image_path": image_path
            }
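

# For reference, an article entry built by fill_wit_for_mict roughly looks like this
# (illustrative values; field names match the assignments above):
# {
#     "context_image_url": "http://upload.wikimedia.org/wikipedia/...",
#     "context_image_path": "train/95/953feec3...jpg",
#     "context_text": "description of the section holding the main image (not used in ICT)",
#     "sections": {
#         "<hash of (text, image_url)>": {
#             "text": "text of a section illustrated by an image",
#             "image_url": "http://upload.wikimedia.org/wikipedia/...",
#             "image_path": "train/35/35bcbf0f...jpg"
#         }
#     }
# }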


def dict_to_dataset(d):
    """Turn the nested dict of articles into a flat Dataset of (section, context image) pairs."""
    table = []
    for title, article in tqdm(d.items()):
        # skip articles without a main image
        if 'context_image_path' not in article:
            continue
        for section in article.get('sections', {}).values():
            section['title'] = title
            section['context_image_url'] = article['context_image_url']
            section['context_image_path'] = article['context_image_path']
            table.append(section)
    df = pd.DataFrame(table)
    dataset = Dataset.from_pandas(df)
    return dataset


def common_filter(wit, downloaded_images):
    """Keep only English rows whose image was downloaded with a valid encoding."""
    # english-only subset
    wit = wit[wit.language == 'en']
    # downloaded and valid encoding
    wit = wit[wit.image_url.isin(downloaded_images)]
    wit = wit[wit.image_url.map(check_encoding)]
    return wit


def mict(paths, downloaded_images, output, split=False):
    """Build the WIT subset for MICT and save it to disk."""
    unique_wit_for_mict = {}
    for path in tqdm(paths):
        wit = pd.read_csv(path, delimiter='\t')
        wit = common_filter(wit, downloaded_images)
        fill_wit_for_mict(wit, unique_wit_for_mict, downloaded_images)
    with open(output/'english_wikipedia.json', 'wt') as file:
        json.dump(unique_wit_for_mict, file)
    # split in test/validation/train without overlap between the articles
    if split:
        titles = list(unique_wit_for_mict)
        random.shuffle(titles)
        # 5% in test and validation, rest in train
        n_in_test = round(len(titles)*0.05)
        superset = {'test': {}, 'validation': {}, 'train': {}}
        for title in titles[:n_in_test]:
            superset['test'][title] = unique_wit_for_mict.pop(title)
        for title in titles[n_in_test: n_in_test*2]:
            superset['validation'][title] = unique_wit_for_mict.pop(title)
        for title in titles[n_in_test*2:]:
            superset['train'][title] = unique_wit_for_mict.pop(title)
        dataset_dict = DatasetDict()
        for name, subset in superset.items():
            dataset_dict[name] = dict_to_dataset(subset)
    else:
        dataset_dict = dict_to_dataset(unique_wit_for_mict)
    print(dataset_dict)
    dataset_dict.save_to_disk(output)


def is_unique(item, unique_pairs):
    """Return True the first time a (caption, image) pair is seen, False afterwards."""
    pair = (item['input'], item['image'])
    if pair in unique_pairs:
        return False
    unique_pairs.add(pair)
    return True


def caption(paths, downloaded_images, output, split=False, dedup=False):
    """Build an image-captioning dataset from WIT and save it to disk."""
    dataset_list = []
    for path in tqdm(paths):
        wit = pd.read_csv(path, delimiter='\t')
        wit = common_filter(wit, downloaded_images)
        wit['image'] = [downloaded_images[url] for url in wit.image_url]
        # use both caption_reference_description and caption_attribution_description as captions
        ref = wit[wit.caption_reference_description.notna()]
        ref = ref.rename(columns={'caption_reference_description': 'input'})
        dataset_list.append(Dataset.from_pandas(ref))
        attr = wit[wit.caption_attribution_description.notna()]
        attr = attr.rename(columns={'caption_attribution_description': 'input'})
        dataset_list.append(Dataset.from_pandas(attr))
    dataset = concatenate_datasets(dataset_list)
    print(dataset)
    if dedup:
        before = len(dataset)
        unique_pairs = set()
        # slower than numpy but saves 3.72 TiB of RAM
        dataset = dataset.filter(is_unique, fn_kwargs=dict(unique_pairs=unique_pairs), batched=False)
        print(f"De-duplication done. Removed {before-len(dataset)} image-caption pairs. {len(dataset)} remaining.")
    if split:
        raise NotImplementedError()
    dataset.save_to_disk(output)


if __name__ == '__main__':
    args = docopt(__doc__)
    root = Path(args['<root_path>'])
    output = Path(args['<output_path>'])
    output.mkdir(exist_ok=True)
    paths = sorted(root.glob('wit_v1.train.all-00*'))
    # TODO make this optional
    train_images = pd.read_csv(
        root/'train_images.tsv',
        sep='\t',
        # FIXME this is a hack for Jean Zay's version, whose last rows are literally "url downloaded path"
        nrows=11419528
    )
    downloaded_images = train_images[train_images.downloaded]
    downloaded_images = dict(zip(downloaded_images.url, downloaded_images.path))
    print(f"You have downloaded {len(downloaded_images)} out of {len(train_images)} images.")
    if args['ict']:
        mict(paths, downloaded_images=downloaded_images, output=output, split=args['--split'])
    elif args['caption']:
        caption(paths, downloaded_images=downloaded_images, output=output,
                split=args['--split'], dedup=args['--dedup'])