meerqat.models.mm module#
Implements the two main architectures presented in the ECIR-2023 paper.
- class meerqat.models.mm.MMConfig(*args, n_images=1, n_faces=4, face_kwargs=None, image_kwargs=None, face_and_image_are_exclusive=False, no_text=False, gating=False, **kwargs)[source]#
Bases:
BertConfig
Base configuration class for multimodal models based on BertConfig.
- Parameters:
*args – additional arguments are passed to BertConfig.
**kwargs – additional arguments are passed to BertConfig.
n_images (int, optional) – Number of images to embed alongside the text. Each image can be mapped to multiple face features or image features. If greater than 1, each image is assigned a type embedding (analogous to BERT's token type embeddings). Defaults to 1.
n_faces (int, optional) – Number of faces that the multimodal model should take as input. Defaults to 4.
face_kwargs (dict, optional) – Keyword arguments used for the FaceEmbedding module. Defaults to dict(face_dim=512, bbox_dim=7).
image_kwargs (dict, optional) – Keyword arguments used to instantiate one ImageEmbedding module per key. Defaults to {"clip-RN50": {"input_dim": 1024}, "imagenet-RN50": {"input_dim": 2048}}.
face_and_image_are_exclusive (bool, optional) – Whether face and full-image representations should be combined (default) or mutually exclusive. Handled with attention masks in the transformer. Defaults to False.
no_text (bool, optional) – Whether to rely only on faces and images. In this case, only the [CLS] token embedding is concatenated to the image features. Defaults to False.
gating (bool, optional) – Whether to use Flamingo-style tanh gating, initialized at 0 [2]_. Defaults to False (no gating).
References
[2] Jean-Baptiste Alayrac et al. (2022). Flamingo: a Visual Language Model for Few-Shot Learning. ArXiv:2204.14198.
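A usage sketch (assumption: MMConfig is instantiated directly and passed to one of the model classes below; any BertConfig argument can also be passed through **kwargs):

from meerqat.models.mm import MMConfig

# Hypothetical example: 1 image, 4 faces, Flamingo-style gating enabled.
config = MMConfig(
    n_images=1,
    n_faces=4,
    face_kwargs=dict(face_dim=512, bbox_dim=7),
    image_kwargs={
        "clip-RN50": {"input_dim": 1024},
        "imagenet-RN50": {"input_dim": 2048},
    },
    gating=True,
)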
- class meerqat.models.mm.FlamantConfig(*args, multimodal_attention_every=1, image_num_attention_heads=12, image_intermediate_size=3072, image_hidden_dropout_prob=0.1, image_attention_probs_dropout_prob=0.1, **kwargs)[source]#
Bases:
MMConfig
Hyperparameters for multimodal cross-attention layers
Same defaults as BertConfig.
- meerqat.models.mm.overwrite_bert_config(flamant_config)[source]#
Overwrites the BERT parameters of the input flamant_config with the corresponding values prefixed by “image_” and returns a plain BertConfig. See usage in FlamantLayer.
- Parameters:
flamant_config (FlamantConfig) –
- Returns:
bert_config
- Return type:
BertConfig
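A minimal sketch of the assumed behaviour, based on the name and return type: copy the configuration and let every "image_"-prefixed hyperparameter overwrite its un-prefixed BERT counterpart (the actual implementation may differ):

from transformers import BertConfig

def overwrite_bert_config_sketch(flamant_config):
    # Hypothetical re-implementation for illustration only.
    bert_config = BertConfig(**flamant_config.to_dict())
    for key, value in flamant_config.to_dict().items():
        if key.startswith("image_"):
            # e.g. image_num_attention_heads overwrites num_attention_heads
            setattr(bert_config, key[len("image_"):], value)
    return bert_config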
- class meerqat.models.mm.FlamantLayer(config)[source]#
Bases:
Module
Adapted from transformers.BertLayer
- forward(hidden_states: Tensor, image_embeddings: Tensor, attention_mask: Optional[FloatTensor] = None, image_attention_mask: Optional[FloatTensor] = None, head_mask: Optional[FloatTensor] = None, encoder_hidden_states: Optional[FloatTensor] = None, encoder_attention_mask: Optional[FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[FloatTensor]]] = None, output_attentions: Optional[bool] = False) Tuple[Tensor] [source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class meerqat.models.mm.FlamantEncoder(config)[source]#
Bases:
Module
Like BertEncoder, but with a FlamantLayer instead of a BertLayer every n layers (see FlamantConfig.multimodal_attention_every).
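A sketch of the assumed layer layout, where multimodal_attention_every controls how often a cross-attention FlamantLayer takes the place of a plain BertLayer (illustrative only; the real constructor may interleave the layers differently):

import torch.nn as nn
from transformers.models.bert.modeling_bert import BertLayer
from meerqat.models.mm import FlamantLayer

def build_layers_sketch(config):
    # Hypothetical: one FlamantLayer every `multimodal_attention_every` layers,
    # a text-only BertLayer otherwise.
    layers = []
    for i in range(config.num_hidden_layers):
        if i % config.multimodal_attention_every == 0:
            layers.append(FlamantLayer(config))  # cross-attends over image_embeddings
        else:
            layers.append(BertLayer(config))
    return nn.ModuleList(layers)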
- forward(hidden_states: Tensor, image_embeddings: Tensor, attention_mask: Optional[FloatTensor] = None, image_attention_mask: Optional[FloatTensor] = None, head_mask: Optional[FloatTensor] = None, encoder_hidden_states: Optional[FloatTensor] = None, encoder_attention_mask: Optional[FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True)[source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class meerqat.models.mm.FlamantModel(config, add_pooling_layer=False)[source]#
Bases:
BertPreTrainedModel
Fuses modalities with gated cross-attention layers, as in Flamingo [2]_. Adapted from transformers.BertModel.
- config_class#
alias of
FlamantConfig
- load_tf_weights = None#
- get_input_embeddings()[source]#
Returns the model’s input embeddings.
- Returns:
A torch module mapping vocabulary to hidden states.
- Return type:
nn.Module
- set_input_embeddings(value)[source]#
Set model’s input embeddings.
- Parameters:
value (nn.Module) – A module mapping vocabulary to hidden states.
- forward(text_inputs, face_inputs, image_inputs, output_attentions=False, output_hidden_states=False, return_dict=True)[source]#
- Parameters:
text_inputs (dict[str, torch.LongTensor]) – usual BERT inputs, see transformers.BertModel
face_inputs (dict[str, torch.FloatTensor]) –
{
    "face": (batch_size, n_images, n_faces, face_dim),
    "bbox": (batch_size, n_images, n_faces, bbox_dim),
    "attention_mask": (batch_size, n_images, n_faces)
}
image_inputs (dict[str, dict[str, torch.FloatTensor]]) –
{
    model: {
        "input": (batch_size, n_images, image_dim),
        "attention_mask": (batch_size, n_images)
    }
}
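A hedged sketch of how these inputs might be assembled for a batch of two questions (the tokenizer checkpoint and zero-filled features are placeholders; shapes follow the documentation above):

import torch
from transformers import BertTokenizerFast

batch_size, n_images, n_faces = 2, 1, 4
face_dim, bbox_dim = 512, 7

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
text_inputs = dict(tokenizer(["Who is pictured?", "Where was this taken?"],
                             padding=True, return_tensors="pt"))

face_inputs = {
    "face": torch.zeros(batch_size, n_images, n_faces, face_dim),
    "bbox": torch.zeros(batch_size, n_images, n_faces, bbox_dim),
    "attention_mask": torch.zeros(batch_size, n_images, n_faces, dtype=torch.long),
}
image_inputs = {
    "clip-RN50": {
        "input": torch.zeros(batch_size, n_images, 1024),
        "attention_mask": torch.ones(batch_size, n_images, dtype=torch.long),
    },
}
# outputs = model(text_inputs, face_inputs, image_inputs)  # model: a FlamantModel instance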
- class meerqat.models.mm.ViltForIR(config, add_pooling_layer=False)[source]#
Bases:
ViltPreTrainedModel
Pools ViLT using the representation of the [CLS] token, i.e. DPR-style, not with ViltPooler (the ITM-pretrained pooling layer), unless add_pooling_layer=True.
- forward(*args, return_dict=True, **kwargs)[source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
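A sketch of the DPR-style pooling described above, as opposed to ViltPooler (illustrative only):

import torch

def cls_pool(last_hidden_state: torch.Tensor) -> torch.Tensor:
    # Keep the hidden state of the first ([CLS]) token instead of passing it
    # through ViltPooler's ITM-pretrained dense layer.
    return last_hidden_state[:, 0]  # (batch_size, hidden_size)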
- class meerqat.models.mm.CLIPForIR(config)[source]#
Bases:
PreTrainedModel
Fuses image and text embeddings simply by summing them to be compatible with BiEncoder.
- Because BiEncoder uses dot-product similarity, note that this is equivalent to computing:
i_q*i_p + i_q*t_p + t_q*t_p + t_q*i_p
where i and t stand for image and text, and the _q and _p suffixes stand for question and passage (or context), i.e. computing all mono-modal and cross-modal similarities.
However, it might be worth using another trainee than BiEncoder in order to be able to scale these similarities.
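A small numerical check of the identity above: with sum fusion, the dot product between fused question and passage embeddings expands into the four mono- and cross-modal terms (pure illustration, no CLIP weights involved):

import torch

i_q, t_q = torch.randn(512, dtype=torch.float64), torch.randn(512, dtype=torch.float64)  # question image / text
i_p, t_p = torch.randn(512, dtype=torch.float64), torch.randn(512, dtype=torch.float64)  # passage image / text

fused = torch.dot(i_q + t_q, i_p + t_p)
expanded = (torch.dot(i_q, i_p) + torch.dot(i_q, t_p)
            + torch.dot(t_q, t_p) + torch.dot(t_q, i_p))
assert torch.allclose(fused, expanded)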
- config_class#
alias of
CLIPConfig
- base_model_prefix = 'clip'#
- forward(*args, return_dict=True, return_loss=False, **kwargs)[source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class meerqat.models.mm.ECAEncoder(config, init_weights_like_bert=False)[source]#
Bases:
PreTrainedModel
- Text and image are fused by concatenating them at the sequence level and feeding them to BERT, à la UNITER [1]_:
one face ≃ one token
one image ≃ one token
The multimodal representation is obtained from the “[CLS]” token.
When using gating (see MMConfig), it is applied before the attention layer, unlike in Flamingo [2]_.
References
[1] Yen-Chun Chen et al. (2020). UNITER: UNiversal Image-TExt Representation Learning. ArXiv:1909.11740.
[2] Jean-Baptiste Alayrac et al. (2022). Flamingo: a Visual Language Model for Few-Shot Learning. ArXiv:2204.14198.
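A conceptual sketch of the sequence-level fusion described above (tensor names and shapes are illustrative; the actual embedding modules and attention masking are handled inside the model):

import torch

batch_size, seq_len, hidden_size = 2, 16, 768
n_faces, n_image_feats = 4, 2

text_emb = torch.zeros(batch_size, seq_len, hidden_size)         # [CLS] + text tokens
face_emb = torch.zeros(batch_size, n_faces, hidden_size)         # one face ≃ one token
image_emb = torch.zeros(batch_size, n_image_feats, hidden_size)  # one image feature ≃ one token

# UNITER-style fusion: concatenate along the sequence dimension, then run BERT layers.
multimodal_sequence = torch.cat([text_emb, face_emb, image_emb], dim=1)
# The multimodal representation is read from position 0 ([CLS]) of the output.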
- load_tf_weights = None#
- base_model_prefix = 'bert_model'#
- forward(text_inputs, face_inputs, image_inputs, output_attentions=False, output_hidden_states=False, return_dict=True)[source]#
- Parameters:
text_inputs (dict[str, torch.LongTensor]) – usual BERT inputs, see transformers.BertModel
face_inputs (dict[str, torch.FloatTensor]) –
{
    "face": (batch_size, n_images, n_faces, face_dim),
    "bbox": (batch_size, n_images, n_faces, bbox_dim),
    "attention_mask": (batch_size, n_images, n_faces)
}
image_inputs (dict[str, dict[str, torch.FloatTensor]]) –
{
    model: {
        "input": (batch_size, n_images, image_dim),
        "attention_mask": (batch_size, n_images)
    }
}
- class meerqat.models.mm.ILFConfig(*args, question_encoder=True, **kwargs)[source]#
Bases:
MMConfig
Same as MMConfig with an extra parameter:
question_encoder (bool, optional) – Whether to use DPRQuestionEncoder (default) or DPRContextEncoder. This makes no real difference in the architecture; only the name changes.
- class meerqat.models.mm.IntermediateLinearFusion(config)[source]#
Bases:
PreTrainedModel
Fuses DPR’s text representation with image embeddings by projecting them linearly into the same space.
- load_tf_weights = None#
- base_model_prefix = 'dpr_encoder'#
- forward(text_inputs, face_inputs, image_inputs)[source]#
- Parameters:
text_inputs (dict[str, torch.LongTensor]) – usual BERT inputs, see transformers.DPRQuestionEncoder
face_inputs (dict[str, torch.FloatTensor]) –
{
    "face": (batch_size, n_faces, face_dim),
    "bbox": (batch_size, n_faces, bbox_dim),
    "attention_mask": (batch_size, n_faces)
}
image_inputs (dict[str, dict[str, torch.FloatTensor]]) –
{
    model: {
        "input": (batch_size, image_dim),
        "attention_mask": (batch_size, )
    }
}
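A hedged sketch of the kind of fusion the class name and description suggest: pool the DPR text representation and combine it with linear projections of the face and image features (the real module's projection layers, pooling, and normalization may differ):

import torch
import torch.nn as nn

class LinearFusionSketch(nn.Module):
    # Illustrative only: projects each extra modality into the DPR hidden space
    # and sums it with the pooled text representation.
    def __init__(self, hidden_size=768, face_dim=512, image_dim=1024):
        super().__init__()
        self.face_proj = nn.Linear(face_dim, hidden_size)
        self.image_proj = nn.Linear(image_dim, hidden_size)

    def forward(self, text_pooled, face_feats, image_feats):
        # text_pooled: (batch_size, hidden_size), e.g. DPRQuestionEncoder's pooler_output
        # face_feats:  (batch_size, n_faces, face_dim), mean-pooled over faces here
        # image_feats: (batch_size, image_dim)
        fused = text_pooled + self.face_proj(face_feats).mean(dim=1)
        fused = fused + self.image_proj(image_feats)
        return fused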