meerqat.train.data module
Classes to format data into proper batches to train models. The module also holds example-generation methods, such as the Multimodal Inverse Cloze Task (ICT) and dynamic examples based on passages retrieved from a knowledge base (KB).
- class meerqat.train.data.DataModule(tokenizer_class, tokenizer_name_or_path, dataset_path=None, train_path=None, validation_path=None, test_path=None, batch_size=8, train_batch_size=None, eval_batch_size=None, M=24, n_relevant_passages=1, keep_dataset_columns=None, tokenization_kwargs=None, image_kwargs={}, loader_kwargs={}, dataset_format=None, input_key='input')
Bases: LightningDataModule
Base class for all data modules. It has a tokenizer and handles dataset loading with train/validation/test subsets. For multimodal models, it can also handle image features or pixels using ImageFormatter.
- Parameters:
tokenizer_class (str) – Name of a transformers.PreTrainedTokenizer subclass
tokenizer_name_or_path (str) – see transformers.PreTrainedTokenizer.from_pretrained
dataset_path (str, optional) – Path to a DatasetDict that should have ‘train’, ‘validation’ and ‘test’ subsets. Alternatively, you can specify each subset with train_path, validation_path and test_path.
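train_path (str, optional) – Path to the ‘train’ subset, as an alternative to dataset_path.
validation_path (str, optional) – Path to the ‘validation’ subset, as an alternative to dataset_path.
test_path (str, optional) – Path to the ‘test’ subset, as an alternative to dataset_path.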
batch_size (int, optional) – Default batch size for both training and evaluation. It is needed to tune the batch size automatically using auto_scale_batch_size in Trainer. It is overridden by train_batch_size and eval_batch_size (if you want to use different batch sizes for training and evaluation). Defaults to 8.
train_batch_size (int, optional) – Batch size for training. Overrides batch_size if given.
eval_batch_size (int, optional) – Batch size for evaluation. Overrides batch_size if given.
M (int, optional) – Number of passages (relevant or irrelevant) per question in a batch. Defaults to 24.
n_relevant_passages (int, optional) – Number of relevant passages per question in a batch (the other M - n_relevant_passages are irrelevant). Defaults to 1.
keep_dataset_columns (list, optional) – Keep only these features in the dataset (by default, everything is kept).
tokenization_kwargs (dict, optional) – Passed to self.tokenizer.
image_kwargs (dict, optional) – Passed to ImageFormatter. Optional for text-only models.
loader_kwargs (dict, optional) – Passed to the data loaders (e.g. self.train_dataloader()).
dataset_format (dict, optional) – See Dataset.set_format. Overrides keep_dataset_columns.
input_key (str, optional) – Name of the dataset column that holds the input text (e.g. question, caption). Defaults to ‘input’.
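The precedence between batch_size, train_batch_size and eval_batch_size can be illustrated with a minimal sketch. It assumes that a specific size left to None falls back to batch_size; resolve_batch_sizes is a hypothetical helper, not part of meerqat.

```python
from typing import Optional, Tuple


def resolve_batch_sizes(
    batch_size: int = 8,
    train_batch_size: Optional[int] = None,
    eval_batch_size: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: return the (train, eval) batch sizes.

    Assumption: a specific size, when given, overrides the shared batch_size.
    """
    train = train_batch_size if train_batch_size is not None else batch_size
    evaluation = eval_batch_size if eval_batch_size is not None else batch_size
    return train, evaluation


# Only batch_size given: both loaders use it.
print(resolve_batch_sizes(batch_size=16))  # (16, 16)
# A different evaluation batch size overrides batch_size for evaluation only.
print(resolve_batch_sizes(batch_size=16, eval_batch_size=64))  # (16, 64)
```

Keeping a single batch_size is what lets Trainer's auto_scale_batch_size tune it; the specific sizes are only needed when training and evaluation should differ.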
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
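Example:
```python
import pytorch_lightning as pl
import torch.nn as nn


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # download_data and tokenize are placeholders for your own functions
        download_data()
        tokenize()

        # don't do this: prepare_data runs on a single process,
        # so state assigned here is not available on the other processes
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)  # placeholder for your own loading function
        self.l1 = nn.Linear(28, data.num_classes)
```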
- train_dataloader()
Implement one or more PyTorch DataLoaders for training.
- Returns:
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, see the section on multiple training dataloaders in the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
For data processing, use the following pattern: download in prepare_data(), then process and split in setup(). However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
See also: fit(), prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
```python
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST


# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
    return loader


# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)  # dataset arguments elided
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True)
    cifar_loader = torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]


# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)  # dataset arguments elided
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True)
    cifar_loader = torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
```
- val_dataloader()
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
See also: fit(), validate(), prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
    return loader


# can also return multiple dataloaders (loader_a ... loader_n are placeholders)
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- test_dataloader()
Implement one or multiple PyTorch DataLoaders for testing.
For data processing, use the following pattern: download in prepare_data(), then process and split in setup(). However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
See also: test(), prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
    return loader


# can also return multiple dataloaders (loader_a ... loader_n are placeholders)
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
- class meerqat.train.data.ImageFormatter(*args, precomputed=None, **kwargs)
Bases: object
Helper to format image features (pre-computed or pixels) into regular, fixed-shape tensors, as expected by multimodal (mm) models.
- class meerqat.train.data.PreComputedImageFeatures(config_class, config_path, **kwargs)
Bases: object
Helper to format pre-computed image features into regular, fixed-shape tensors, as expected by multimodal (mm) models.
- Parameters:
config_class (str) – Name of a subclass of MMConfig
config_path (str) –
- get_face_inputs(items)
Formats pre-computed face features into regular, fixed-shape tensors.
The extra dimension 1 stands for the number of images (images are processed one by one and are concatenated in ImageFormatter).
- Returns:
face_inputs – {
face: Tensor(batch_size, 1, self.n_faces, self.face_dim)
bbox: Tensor(batch_size, 1, self.n_faces, self.bbox_dim)
attention_mask: Tensor(batch_size, 1, self.n_faces)
}
- Return type:
dict[str, Tensor]
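To illustrate the shapes above, here is a minimal pure-Python sketch of how a variable number of detected faces per item could be padded (or truncated) to n_faces, with an attention mask marking real faces. The helper pad_faces, zero padding and truncation are assumptions for illustration; the actual method returns torch tensors and may differ.

```python
from typing import List, Tuple

Vector = List[float]


def pad_faces(
    faces_per_item: List[List[Vector]], n_faces: int, face_dim: int
) -> Tuple[List[List[List[Vector]]], List[List[List[int]]]]:
    """Hypothetical helper mimicking the shapes of get_face_inputs with nested lists.

    Returns (face, attention_mask) with shapes
    (batch_size, 1, n_faces, face_dim) and (batch_size, 1, n_faces).
    """
    face_batch, mask_batch = [], []
    for faces in faces_per_item:
        faces = faces[:n_faces]  # assumption: extra faces are truncated
        n_pad = n_faces - len(faces)
        padded = faces + [[0.0] * face_dim for _ in range(n_pad)]
        mask = [1] * len(faces) + [0] * n_pad  # 1 = real face, 0 = padding
        # The extra dimension 1 stands for "one image per item".
        face_batch.append([padded])
        mask_batch.append([mask])
    return face_batch, mask_batch


# Two items: the first has one face, the second has none.
face, attention_mask = pad_faces([[[0.1, 0.2]], []], n_faces=3, face_dim=2)
print(attention_mask)  # [[[1, 0, 0]], [[0, 0, 0]]]
```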
- get_image_inputs(items)
Formats pre-computed full-image features into regular, fixed-shape tensors.
The extra dimension 1 stands for the number of images (images are processed one by one and are concatenated in ImageFormatter).
- Returns:
image_inputs – one key per image feature, each mapping to {
input: Tensor(batch_size, 1, ?)
attention_mask: Tensor(batch_size, 1)
}
- Return type:
dict[str, dict[str,Tensor]]
- class meerqat.train.data.CrossModalDataModule(*args, paired_image=None, deduplicate=False, **kwargs)
Bases: DataModule
Used for cross-modal retrieval (text-to-image or image-to-text) and optionally for joint cross-modal and image-image retrieval.
- Parameters:
*args – additional arguments are passed to DataModule
**kwargs – additional arguments are passed to DataModule
paired_image (str, optional) – Key for the path to an image paired with ‘image’, so that a joint image-image contrastive loss may be applied in CrossModal(Trainee). Defaults to None (no paired image).
deduplicate (bool, optional) – Will remove text (and paired_image) duplicates. Defaults to False (assumes there are no duplicates).
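A minimal sketch of text deduplication on a list of examples, assuming the first occurrence of each text is kept. The helper deduplicate_by_key and the ‘text’/‘image’ keys are hypothetical, and deduplication of paired_image is not shown. One plausible motivation is that, with in-batch negatives, a duplicated text in the same batch would act as a false negative for the contrastive loss.

```python
from typing import Dict, List


def deduplicate_by_key(items: List[Dict], key: str) -> List[Dict]:
    """Hypothetical helper: keep the first item for each distinct value of `key`."""
    seen = set()
    unique = []
    for item in items:
        value = item[key]
        if value not in seen:
            seen.add(value)
            unique.append(item)
    return unique


pairs = [
    {"text": "Eiffel Tower", "image": "a.jpg"},
    {"text": "Eiffel Tower", "image": "b.jpg"},  # duplicate text
    {"text": "Louvre", "image": "c.jpg"},
]
print([p["image"] for p in deduplicate_by_key(pairs, "text")])  # ['a.jpg', 'c.jpg']
```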
- class meerqat.train.data.QuestionAnsweringDataModule(*args, kb, image_kb=None, search_key='search', filter_train_rels=False, keep_kb_columns=None, kb_format=None, image_kb_format=None, kb_input_key='passage', **kwargs)
Bases: DataModule
Base class for Question Answering. Should work for both information retrieval (IR) and reading comprehension (RC).
The core idea is that it relies on a Knowledge Base (KB) to retrieve relevant and irrelevant passages for the questions in the dataset.
We need to create the batch of questions and passages on the fly. The inputs should be shaped like (N * M, L), where:
N - number of distinct questions (equal to the batch size)
M - number of passages per question in a batch
L - sequence length
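To make the (N * M, L) layout concrete: N questions with M passages each are flattened into a single batch of N * M sequences, and the per-question view can be recovered by reshaping to (N, M, L). The minimal sketch below uses nested lists as stand-ins for token ids and assumes question-major order (row i * M + j holds passage j of question i), which matches ReRankerDataModule; BiEncoderDataModule orders rows differently.

```python
from typing import List, TypeVar

T = TypeVar("T")


def flatten(batch: List[List[T]]) -> List[T]:
    """(N, M, ...) -> (N * M, ...); row i * M + j is passage j of question i."""
    return [passage for passages in batch for passage in passages]


def unflatten(rows: List[T], M: int) -> List[List[T]]:
    """(N * M, ...) -> (N, M, ...), the inverse of flatten."""
    if len(rows) % M:
        raise ValueError("number of rows must be a multiple of M")
    return [rows[i:i + M] for i in range(0, len(rows), M)]


N, M, L = 2, 3, 4
# Dummy token ids: passage j of question i is L copies of 10 * i + j.
batch = [[[10 * i + j] * L for j in range(M)] for i in range(N)]
rows = flatten(batch)
print(len(rows), len(rows[0]))  # 6 4
assert unflatten(rows, M) == batch
```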
- Parameters:
*args – additional arguments are passed to DataModule
**kwargs – additional arguments are passed to DataModule
kb (str) – Path to the knowledge base (Dataset) used to get the passages.
image_kb (str, optional) – Path to the KB that holds pre-computed image features. Can be mapped from kb using kb[‘index’].
search_key (str, optional) – Prefix of the dataset columns that hold retrieval results. Defaults to ‘search’.
Suffixed by ‘_indices’ and ‘_scores’, it should hold the result of information retrieval used during evaluation (e.g. the output of ir.search).
Suffixed by ‘_provenance_indices’ and ‘_irrelevant_indices’, it should hold, respectively, the union of relevant search results and provenance indices, and irrelevant results from the search. These are used during training (according to M and n_relevant_passages).
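filter_train_rels (bool, optional) – Whether to filter out training questions without any relevant passages (see filter_rels()). Defaults to False.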
keep_kb_columns (list, optional) – Keep only these features in kb and image_kb (defaults to keep everything)
kb_format (dict, optional) – See Dataset.set_format. Overrides keep_kb_columns.
image_kb_format (dict, optional) – See Dataset.set_format. Overrides keep_kb_columns.
kb_input_key (str, optional) – Column of kb that holds the passage text (analogous to input_key). Defaults to ‘passage’.
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
- filter_rels(subset='train')
Filter out questions of the dataset without any relevant passages.
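A minimal sketch of this filtering on plain dicts. It assumes that a question has relevant passages when its f'{search_key}_provenance_indices' list is non-empty; the column name follows the search_key convention above, but whether filter_rels uses exactly this criterion is an assumption, and keep_questions_with_rels is a hypothetical helper. With a datasets.Dataset, a similar predicate could be passed to Dataset.filter.

```python
from typing import Dict, List


def has_relevant_passages(item: Dict, search_key: str = "search") -> bool:
    """Assumed criterion: at least one relevant (provenance) passage index."""
    return len(item.get(f"{search_key}_provenance_indices") or []) > 0


def keep_questions_with_rels(items: List[Dict], search_key: str = "search") -> List[Dict]:
    """Hypothetical stand-alone counterpart of filter_rels on a list of questions."""
    return [item for item in items if has_relevant_passages(item, search_key)]


questions = [
    {"id": "q1", "search_provenance_indices": [4, 7]},
    {"id": "q2", "search_provenance_indices": []},
    {"id": "q3", "search_provenance_indices": None},
]
print([q["id"] for q in keep_questions_with_rels(questions)])  # ['q1']
```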
- get_training_passages(item, with_scores=False)
- Parameters:
item (dict) – item (e.g. question) from self.train_dataset or self.eval_dataset.
with_scores (bool, optional) – Also return the scores corresponding to the passages. Defaults to False.
- Returns:
relevant_passages, irrelevant_passages (list[dict]) – Lists of relevant and irrelevant passages selected from self.kb according to self.n_relevant_passages, self.M and self.search_key.
relevant_scores (np.ndarray, optional) – Shape (self.n_relevant_passages, ). Returned only if with_scores.
irrelevant_scores (np.ndarray, optional) – Shape (self.M-self.n_relevant_passages, ). Returned only if with_scores.
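The selection can be sketched as follows. This assumes that relevant indices are read from f'{search_key}_provenance_indices' and irrelevant ones from f'{search_key}_irrelevant_indices' (following the search_key description above), and that both are sampled at random without replacement; the random sampling, the list-based kb and the helper select_training_passages are assumptions, and scores are omitted.

```python
import random
from typing import Dict, List, Optional, Tuple


def select_training_passages(
    item: Dict,
    kb: List[Dict],
    M: int = 24,
    n_relevant_passages: int = 1,
    search_key: str = "search",
    rng: Optional[random.Random] = None,
) -> Tuple[List[Dict], List[Dict]]:
    """Hypothetical sketch: n_relevant_passages relevant + (M - n_relevant_passages) irrelevant passages."""
    if rng is None:
        rng = random.Random()
    relevant_indices = item[f"{search_key}_provenance_indices"]
    irrelevant_indices = item[f"{search_key}_irrelevant_indices"]
    # Sample without replacement; take fewer if not enough candidates are available.
    rel = rng.sample(relevant_indices, min(n_relevant_passages, len(relevant_indices)))
    n_irrelevant = M - n_relevant_passages
    irr = rng.sample(irrelevant_indices, min(n_irrelevant, len(irrelevant_indices)))
    return [kb[i] for i in rel], [kb[i] for i in irr]


kb = [{"passage": f"passage {i}"} for i in range(10)]
item = {"search_provenance_indices": [2, 5], "search_irrelevant_indices": [0, 1, 3, 4, 6]}
relevant, irrelevant = select_training_passages(item, kb, M=4, n_relevant_passages=1, rng=random.Random(0))
print(len(relevant), len(irrelevant))  # 1 3
```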
- class meerqat.train.data.BiEncoderDataModule(*args, passage_type_ids=False, **kwargs)
Bases: QuestionAnsweringDataModule
- Parameters:
*args – additional arguments are passed to QuestionAnsweringDataModule
**kwargs – additional arguments are passed to QuestionAnsweringDataModule
passage_type_ids (bool, optional) – Pass token_type_ids=1 for passages (see BertTokenizer for details). This might be useful if you use a shared encoder to encode questions and passages. Defaults to False (by default you use different models to encode questions and passages).
- collate_fn(items)
Collate a batch so that each question is associated with n_relevant_passages relevant passages and M - n_relevant_passages irrelevant ones. Also tokenizes input strings.
N - number of questions in a batch
M - number of passages per question
L - sequence length
d - dimension of the model/embeddings
- question_inputs: dict[torch.LongTensor]
- input_ids: torch.LongTensor
shape (N, L)
- **kwargs:
more tensors depending on the tokenizer, e.g. attention_mask
- context_inputs: dict[torch.LongTensor]
- input_ids: torch.LongTensor
shape (N*M, L). The first N rows correspond to the relevant contexts for the N questions. The remaining N*(M-1) rows are irrelevant contexts for all questions.
- **kwargs:
idem
- labels: torch.LongTensor
shape (N, ). Index of the relevant passage in context_inputs. Should be arange(N) except for padding passages.
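The layout of context_inputs and labels with n_relevant_passages=1 can be sketched with strings standing in for passages: the N relevant passages come first (row i is relevant to question i), followed by the N*(M-1) irrelevant ones. Since labels index into the whole context batch, each question is presumably scored against all N*M passages (in-batch negatives), and its target is row i. The helper biencoder_layout and the question-major order of irrelevant passages are assumptions; tokenization and padding passages are left out.

```python
from typing import List, Tuple


def biencoder_layout(relevant: List[str], irrelevant: List[List[str]]) -> Tuple[List[str], List[int]]:
    """Hypothetical sketch of the context order and labels for n_relevant_passages=1.

    relevant[i] is the relevant passage of question i;
    irrelevant[i] holds the M - 1 irrelevant passages of question i.
    """
    N = len(relevant)
    # First N rows: relevant passages, aligned with the N questions.
    contexts = list(relevant)
    # Remaining N * (M - 1) rows: irrelevant passages (question-major order assumed).
    contexts += [p for passages in irrelevant for p in passages]
    # The target of question i is row i, hence labels == arange(N).
    labels = list(range(N))
    return contexts, labels


contexts, labels = biencoder_layout(
    relevant=["rel q0", "rel q1"],
    irrelevant=[["irr q0 a", "irr q0 b"], ["irr q1 a", "irr q1 b"]],
)
print(contexts)  # ['rel q0', 'rel q1', 'irr q0 a', 'irr q0 b', 'irr q1 a', 'irr q1 b']
print(labels)  # [0, 1]
```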
- class meerqat.train.data.JointBiEncoderAndClipDataModule(*args, cm_tokenizer_class, cm_tokenizer_name_or_path, cm_tokenization_kwargs=None, **kwargs)
Bases: BiEncoderDataModule
- collate_fn(items)
Collate a batch so that each question is associated with n_relevant_passages relevant passages and M - n_relevant_passages irrelevant ones. Also tokenizes input strings.
N - number of questions in a batch
M - number of passages per question
L - sequence length
d - dimension of the model/embeddings
- question_inputs: dict[torch.LongTensor]
- input_ids: torch.LongTensor
shape (N, L)
- **kwargs:
more tensors depending on the tokenizer, e.g. attention_mask
- context_inputs: dict[torch.LongTensor]
- input_ids: torch.LongTensor
shape (N*M, L). The first N rows correspond to the relevant contexts for the N questions. The remaining N*(M-1) rows are irrelevant contexts for all questions.
- **kwargs:
idem
- labels: torch.LongTensor
shape (N, ). Index of the relevant passage in context_inputs. Should be arange(N) except for padding passages.
- class meerqat.train.data.ReRankerDataModule(*args, run_path=None, qrels_path=None, **kwargs)
Bases: QuestionAnsweringDataModule
- Parameters:
*args – additional arguments are passed to QuestionAnsweringDataModule
**kwargs – additional arguments are passed to QuestionAnsweringDataModule
run_path (str, optional) – Path to the ranx run stored in the TREC format that holds the IR results. Optional if you want to train only. Defaults to None.
qrels_path (str, optional) – Path to the ranx qrels stored in the TREC format. Used during evaluation. Optional if you want to train only. Defaults to None.
- collate_fn(items)
Collate a batch so that each question is associated with 1 relevant passage and M-1 irrelevant ones. Also tokenizes input strings.
- input_ids: Tensor[int]
shape (N * M, L): 1 relevant passage followed by M-1 irrelevant ones, N times. Note that this layout differs from BiEncoderDataModule.
- labels: torch.LongTensor, optional
shape (N, ). Index of the relevant passage among the M passages of each question in input_ids. Should be 0 except for padding passages. Returned only when training.
- **kwargs: more tensors depending on the tokenizer, e.g. attention_mask
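By contrast with BiEncoderDataModule, each question here comes with its own M passages, relevant first, so the N * M scores of a cross-encoder can be reshaped to (N, M), and the target of every question is column 0. A minimal pure-Python sketch of that reshaping; the scores, the argmax-based accuracy and the helper rerank_view are illustrative assumptions.

```python
from typing import List


def rerank_view(scores: List[float], M: int) -> List[List[float]]:
    """Reshape N * M flat scores (1 relevant + M - 1 irrelevant, N times) to (N, M)."""
    return [scores[i:i + M] for i in range(0, len(scores), M)]


N, M = 2, 3
labels = [0] * N  # the relevant passage always comes first
flat_scores = [2.0, 0.5, -1.0,  # question 0: relevant passage scored highest
               0.1, 0.9, 0.3]   # question 1: an irrelevant passage wins
per_question = rerank_view(flat_scores, M)
predictions = [row.index(max(row)) for row in per_question]
print(predictions)  # [0, 1]
accuracy = sum(p == y for p, y in zip(predictions, labels)) / N
print(accuracy)  # 0.5
```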
- class meerqat.train.data.ReaderDataModule(*args, max_n_answers=10, run_path=None, train_original_answer_only=True, oracle=False, extract_name=False, mapping_run=None, **kwargs)
Bases: QuestionAnsweringDataModule
- Parameters:
*args – additional arguments are passed to QuestionAnsweringDataModule
**kwargs – additional arguments are passed to QuestionAnsweringDataModule
max_n_answers (int, optional) – The answer might be found several times in the same passage; this is the maximum number of occurrences kept per passage, which enables batching. Defaults to 10.
train_original_answer_only (bool, optional) – Whether the model should be trained to predict only the original answer (default) or all alternative answers (limited only by max_n_answers). This has no effect on evaluation, where all alternative answers are always considered.
oracle (bool, optional) – Whether to use only relevant passages at inference (stored in {search_key}_provenance_indices). Enforces n_relevant_passages=M. Defaults to False (use IR passages at inference, stored in {search_key}_indices).
run_path (str, optional) – Path to the ranx run stored in the TREC format that holds the IR results. To be used instead of search_key at inference. Defaults to None.
extract_name (bool, optional) – Train the model to extract the name of the entity instead of the answer. Defaults to False.
mapping_run (str, optional) – Path to the mapping
- collate_fn(items)
Collate a batch so that each question is associated with n_relevant_passages relevant passages and M - n_relevant_passages irrelevant ones. Also tokenizes input strings.
- input_ids: Tensor[int]
shape (N * M, L)
- start_positions, end_positions: Tensor[int]
shape (N, M, max_n_answers)
- answer_mask: Tensor[int]
shape (N, M, max_n_answers)
- passage_scores: Tensor[float], optional
shape (N * M, ). Returned only in evaluation mode.
- **kwargs: more tensors depending on the tokenizer, e.g. attention_mask
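Since an answer may occur several times in a passage, start_positions, end_positions and answer_mask are fixed to max_n_answers slots per passage so that they form (N, M, max_n_answers) tensors. A minimal sketch for a single passage, assuming padding with 0 and dropping extra occurrences; the helper pad_answer_spans is hypothetical and the token-level search for answer spans is omitted.

```python
from typing import List, Tuple


def pad_answer_spans(
    spans: List[Tuple[int, int]], max_n_answers: int = 10
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: (start, end) token spans of one passage -> fixed-size lists."""
    spans = spans[:max_n_answers]  # assumption: extra occurrences are dropped
    n_pad = max_n_answers - len(spans)
    starts = [s for s, _ in spans] + [0] * n_pad
    ends = [e for _, e in spans] + [0] * n_pad
    # answer_mask tells the loss which of the max_n_answers slots hold real answers.
    answer_mask = [1] * len(spans) + [0] * n_pad
    return starts, ends, answer_mask


# The answer occurs twice in this passage; an irrelevant passage has no span.
print(pad_answer_spans([(5, 7), (20, 22)], max_n_answers=4))
# ([5, 20, 0, 0], [7, 22, 0, 0], [1, 1, 0, 0])
print(pad_answer_spans([], max_n_answers=4))
# ([0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0])
```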
- class meerqat.train.data.ICT(*args, biencoder=True, sentences_per_target=4, prepend_title=False, text_mask_rate=1.0, image_mask_rate=1.0, **kwargs)
Bases: DataModule
Extends the Inverse Cloze Task (ICT, [2]) to multimodal documents. Given a Wikipedia section, one sentence is considered as a pseudo-question and the nearby sentences as a relevant passage. In this multimodal setting, we also consider the image of the section in the query and the infobox/main image of the article in the visual passage.
The only point in common with QuestionAnsweringDataModule is the use of PreComputedImageFeatures.
- Parameters:
*args – additional arguments are passed to DataModule
**kwargs – additional arguments are passed to DataModule
biencoder (bool, optional) – Expected kind of model: bi-encoder (True) or cross-encoder (False), i.e. whether to leave questions and passages in separate tensors or to concatenate them. Defaults to True.
sentences_per_target (int, optional) – Number of sentences in the target passages. Defaults to 4.
prepend_title (bool, optional) – Whether to prepend the title of the article to the target passage. Defaults to False.
text_mask_rate (float, optional) – Rate at which the pseudo-question is masked in the target passage. Defaults to 1.0 (always masked).
image_mask_rate (float, optional) – Rate at which the infobox image is used as target (the input image is kept otherwise). Defaults to 1.0.
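A minimal text-only sketch of building one ICT example as described above: a sentence is picked as the pseudo-question, the neighbouring sentences form the target passage, and the pseudo-question is removed from the target with probability text_mask_rate. The placement of the window, the way the title is prepended and the helper make_ict_example are assumptions; images (image_mask_rate) are left out.

```python
import random
from typing import List, Optional, Tuple


def make_ict_example(
    sentences: List[str],
    sentences_per_target: int = 4,
    text_mask_rate: float = 1.0,
    title: Optional[str] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[str, str]:
    """Hypothetical sketch: build (pseudo_question, target_passage) from one section."""
    if rng is None:
        rng = random.Random()
    # Window of consecutive sentences containing the pseudo-question (assumed placement).
    size = min(sentences_per_target, len(sentences))
    if size < 2:
        raise ValueError("need a window of at least two sentences")
    start = rng.randrange(len(sentences) - size + 1)
    window = sentences[start:start + size]
    # One sentence of the window becomes the pseudo-question.
    q_index = rng.randrange(size)
    question = window[q_index]
    # With probability text_mask_rate, remove the pseudo-question from the target.
    if rng.random() < text_mask_rate:
        window = window[:q_index] + window[q_index + 1:]
    target = " ".join(window)
    if title is not None:  # stands in for prepend_title=True
        target = f"{title}. {target}"
    return question, target


section = ["S0.", "S1.", "S2.", "S3.", "S4."]
question, target = make_ict_example(section, sentences_per_target=3, rng=random.Random(0))
print(question, "->", target)
```

With the default text_mask_rate=1.0 the pseudo-question never appears in its target, so the model cannot rely on simple lexical overlap.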
References