meerqat.models.mm module

Implements the two main architectures presented in the ECIR-2023 paper.

class meerqat.models.mm.MMConfig(*args, n_images=1, n_faces=4, face_kwargs=None, image_kwargs=None, face_and_image_are_exclusive=False, no_text=False, gating=False, **kwargs)

Bases: BertConfig

Base configuration class for multimodal models based on BertConfig.

Parameters:
  • *args – additional arguments are passed to BertConfig.

  • **kwargs – additional arguments are passed to BertConfig.

  • n_images (int, optional) – Number of images to embed alongside the text. Each image can be mapped to multiple face features or image features. If greater than 1, each image is assigned its own type embedding (analogous to BERT's token type embeddings). Defaults to 1.

  • n_faces (int, optional) – Number of faces that the multimodal model should take as input. Defaults to 4.

  • face_kwargs (dict, optional) – Keyword arguments used for the FaceEmbedding module. Defaults to dict(face_dim=512, bbox_dim=7).

  • image_kwargs (dict, optional) – Keyword arguments for the ImageEmbedding modules, one module per key. Defaults to {"clip-RN50": {"input_dim": 1024}, "imagenet-RN50": {"input_dim": 2048}}.

  • face_and_image_are_exclusive (bool, optional) – Whether face and full-image representations should be combined (default) or mutually exclusive. Exclusivity is handled with attention masks in the transformer. Defaults to False.

  • no_text (bool, optional) – Whether to rely only on faces and images. In this case, only the [CLS] token embedding is concatenated to the image features. Defaults to False.

  • gating (bool, optional) – Whether to use Flamingo-style tanh gating, with the gate initialized at 0 [2]. Defaults to False (no gating).

References

[2] Jean-Baptiste Alayrac et al. (2022). Flamingo: a Visual Language Model for Few-Shot Learning. arXiv:2204.14198.
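To illustrate the gating option, here is a minimal sketch of Flamingo-style tanh gating on plain Python lists. The helper name gated_residual and the residual form x + tanh(alpha) * f(x) follow the Flamingo formulation; where and how meerqat applies its own gate is not shown here and is an assumption.

```python
import math
from typing import Callable, List

Vector = List[float]


def gated_residual(x: Vector, sublayer: Callable[[Vector], Vector], alpha: float) -> Vector:
    """Hypothetical helper: x + tanh(alpha) * sublayer(x), element-wise.

    In a real model alpha is a learnable scalar; initializing it at 0 makes
    tanh(alpha) == 0, so the gated block starts out as the identity.
    """
    gate = math.tanh(alpha)
    return [xi + gate * yi for xi, yi in zip(x, sublayer(x))]


def double(v: Vector) -> Vector:
    # Stand-in for a cross-attention or feed-forward sub-layer.
    return [2.0 * vi for vi in v]


x = [1.0, -0.5, 3.0]
print(gated_residual(x, double, alpha=0.0))  # [1.0, -0.5, 3.0]
print(gated_residual(x, double, alpha=1.0))  # sub-layer output now partly let through
```

Because tanh(0) = 0, a freshly initialized gate leaves the pre-trained text pathway unchanged, and training decides how much of the multimodal signal to let through.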

class meerqat.models.mm.FlamantConfig(*args, multimodal_attention_every=1, image_num_attention_heads=12, image_intermediate_size=3072, image_hidden_dropout_prob=0.1, image_attention_probs_dropout_prob=0.1, **kwargs)

Bases: MMConfig

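Hyperparameters for the multimodal cross-attention layers. The image_* parameters have the same defaults as their BertConfig counterparts (num_attention_heads=12, intermediate_size=3072, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1).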

meerqat.models.mm.overwrite_bert_config(flamant_config)

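Overwrites the BERT parameters of flamant_config with their counterparts prefixed by "image_" (e.g. image_num_attention_heads overwrites num_attention_heads) and returns the result as a BertConfig. See usage in FlamantLayer.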

Parameters:

flamant_config (FlamantConfig) –

Returns:

bert_config

Return type:

BertConfig
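The behaviour of overwrite_bert_config can be approximated on plain dictionaries. The sketch below is hypothetical: overwrite_image_params is not part of meerqat, it works on dicts rather than config objects, and it assumes that each image_<name> entry replaces <name> and is then dropped from the result.

```python
from typing import Any, Dict

PREFIX = "image_"


def overwrite_image_params(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical dict-based analogue of overwrite_bert_config.

    Every key "image_<name>" overwrites "<name>"; the "image_" keys
    themselves are left out of the returned dict (a choice of this sketch).
    """
    # Start from the non-prefixed parameters.
    result = {k: v for k, v in config.items() if not k.startswith(PREFIX)}
    # Let each "image_" parameter overwrite its BERT counterpart.
    for key, value in config.items():
        if key.startswith(PREFIX):
            result[key[len(PREFIX):]] = value
    return result


flamant = {
    "hidden_size": 768,
    "num_attention_heads": 12,
    "image_num_attention_heads": 8,
    "image_hidden_dropout_prob": 0.2,
}
print(overwrite_image_params(flamant))
# {'hidden_size': 768, 'num_attention_heads': 8, 'hidden_dropout_prob': 0.2}
```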

class meerqat.models.mm.FlamantLayer(config)

Bases: Module

Adapted from transformers.BertLayer

forward(hidden_states: Tensor, image_embeddings: Tensor, attention_mask: Optional[FloatTensor] = None, image_attention_mask: Optional[FloatTensor] = None, head_mask: Optional[FloatTensor] = None, encoder_hidden_states: Optional[FloatTensor] = None, encoder_attention_mask: Optional[FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[FloatTensor]]] = None, output_attentions: Optional[bool] = False) → Tuple[Tensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

feed_forward_chunk(attention_output)

class meerqat.models.mm.FlamantEncoder(config)

Bases: Module

Like BertEncoder, but with a FlamantLayer instead of a BertLayer every n layers.
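A sketch of which encoder layers would be multimodal. It assumes that n corresponds to FlamantConfig.multimodal_attention_every (suggested by the parameter name) and that the FlamantLayer sits at the last position of each group of n layers; both the helper layer_types and that placement are hypothetical.

```python
from typing import List


def layer_types(num_hidden_layers: int, every: int) -> List[str]:
    """Hypothetical schedule: one FlamantLayer every `every` layers.

    Assumes the multimodal layer is placed at indices every-1, 2*every-1, ...
    The actual placement in FlamantEncoder may differ.
    """
    if every < 1:
        raise ValueError("every must be >= 1")
    return [
        "FlamantLayer" if (i + 1) % every == 0 else "BertLayer"
        for i in range(num_hidden_layers)
    ]


print(layer_types(6, every=2))
# ['BertLayer', 'FlamantLayer', 'BertLayer', 'FlamantLayer', 'BertLayer', 'FlamantLayer']
```

Under this reading, the default multimodal_attention_every=1 would make every layer a FlamantLayer.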

forward(hidden_states: Tensor, image_embeddings: Tensor, attention_mask: Optional[FloatTensor] = None, image_attention_mask: Optional[FloatTensor] = None, head_mask: Optional[FloatTensor] = None, encoder_hidden_states: Optional[FloatTensor] = None, encoder_attention_mask: Optional[FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True)

class meerqat.models.mm.FlamantModel(config, add_pooling_layer=False)

Bases: BertPreTrainedModel

Fuses modalities with gated cross-attention layers, as in Flamingo [2]. Adapted from transformers.BertModel.

config_class

alias of FlamantConfig

load_tf_weights = None

get_input_embeddings()

Returns the model’s input embeddings.

Returns:

A torch module mapping vocabulary to hidden states.

Return type:

nn.Module

set_input_embeddings(value)

Set model’s input embeddings.

Parameters:

value (nn.Module) – A module mapping vocabulary to hidden states.

forward(text_inputs, face_inputs, image_inputs, output_attentions=False, output_hidden_states=False, return_dict=True)
Parameters:
  • text_inputs (dict[str, torch.LongTensor]) – usual BERT inputs, see transformers.BertModel

  • face_inputs (dict[str, torch.FloatTensor]) –

    {
        "face": (batch_size, n_images, n_faces, face_dim),
        "bbox": (batch_size, n_images, n_faces, bbox_dim),
        "attention_mask": (batch_size, n_images, n_faces)
    }

  • image_inputs (dict[str, dict[str, torch.FloatTensor]]) – one entry per image model, keyed by model name (e.g. "clip-RN50"):

    {
        model: {
            "input": (batch_size, n_images, image_dim),
            "attention_mask": (batch_size, n_images)
        }
    }
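The expected layout of face_inputs and image_inputs can be spelled out with a hypothetical helper, input_shapes, that returns shape tuples only (no tensors). Default dimensions mirror MMConfig's face_kwargs and image_kwargs; keying image_inputs by the image_kwargs names is an assumption. In practice, each tuple could be passed to torch.zeros to build dummy inputs for a smoke test.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]

# Defaults mirrored from MMConfig's image_kwargs (model name -> input_dim).
DEFAULT_IMAGE_DIMS = {"clip-RN50": 1024, "imagenet-RN50": 2048}


def input_shapes(
    batch_size: int,
    n_images: int = 1,
    n_faces: int = 4,
    face_dim: int = 512,
    bbox_dim: int = 7,
    image_dims: Optional[Dict[str, int]] = None,
) -> Tuple[Dict[str, Shape], Dict[str, Dict[str, Shape]]]:
    """Return the expected shapes of face_inputs and image_inputs."""
    if image_dims is None:
        image_dims = DEFAULT_IMAGE_DIMS
    face_shapes = {
        "face": (batch_size, n_images, n_faces, face_dim),
        "bbox": (batch_size, n_images, n_faces, bbox_dim),
        "attention_mask": (batch_size, n_images, n_faces),
    }
    # One nested dict per image model.
    image_shapes = {
        model: {
            "input": (batch_size, n_images, dim),
            "attention_mask": (batch_size, n_images),
        }
        for model, dim in image_dims.items()
    }
    return face_shapes, image_shapes


faces, images = input_shapes(batch_size=2)
print(faces["face"])                 # (2, 1, 4, 512)
print(images["clip-RN50"]["input"])  # (2, 1, 1024)
```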

class meerqat.models.mm.ViltForIR(config, add_pooling_layer=False)

Bases: ViltPreTrainedModel

Pools ViLT using the representation of the [CLS] token, DPR-style, rather than with ViltPooler (the layer pre-trained for image-text matching, ITM), unless add_pooling_layer=True.
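The difference between the two pooling modes can be sketched on nested lists. cls_pool keeps the [CLS] hidden state as is (DPR-style), while pooler applies a dense layer followed by tanh to that state, which is how BERT-style poolers such as ViltPooler work. The helper names and toy weights are illustrative only.

```python
import math
from typing import List

Vector = List[float]
Sequence = List[Vector]  # (seq_len, hidden_size) for one example


def cls_pool(batch: List[Sequence]) -> List[Vector]:
    """DPR-style pooling: the hidden state of the first ([CLS]) token, unchanged."""
    return [list(sequence[0]) for sequence in batch]


def pooler(batch: List[Sequence], weight: List[Vector], bias: Vector) -> List[Vector]:
    """Pooler-style alternative: tanh(W @ h_cls + b) on the [CLS] state."""
    pooled = []
    for h_cls in cls_pool(batch):
        pooled.append([
            math.tanh(sum(w_ij * h_j for w_ij, h_j in zip(row, h_cls)) + b_i)
            for row, b_i in zip(weight, bias)
        ])
    return pooled


batch = [[[0.5, -1.0], [9.0, 9.0]]]  # 1 example, 2 tokens, hidden_size 2
print(cls_pool(batch))  # [[0.5, -1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(pooler(batch, identity, [0.0, 0.0]))  # tanh applied to the [CLS] state
```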

forward(*args, return_dict=True, **kwargs)

class meerqat.models.mm.CLIPForIR(config)

Bases: PreTrainedModel

Fuses image and text embeddings by simply summing them, which keeps the output compatible with BiEncoder.

Because BiEncoder uses dot-product similarity, this is equivalent to computing:

(i_q + t_q) · (i_p + t_p) = i_q · i_p + i_q · t_p + t_q · t_p + t_q · i_p

where i and t stand for the image and text embeddings, and the suffixes _q and _p stand for question and passage (or context). In other words, the score sums all mono-modal and cross-modal similarities.

However, it might be worth using a trainee other than BiEncoder, so that these similarities can be scaled individually.
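The equivalence above can be checked numerically with toy 2-dimensional embeddings. The values are arbitrary; they only show that scoring the summed embeddings gives the same result as summing the four mono-modal and cross-modal dot products.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


# Toy image (i) and text (t) embeddings for a question (q) and a passage (p).
i_q, t_q = [1.0, 2.0], [0.5, -1.0]
i_p, t_p = [3.0, 0.0], [-2.0, 1.0]

# What BiEncoder sees: one fused vector per side, compared by dot product.
fused_score = dot(add(i_q, t_q), add(i_p, t_p))
# The same score written as four mono-modal and cross-modal similarities.
expanded = dot(i_q, i_p) + dot(i_q, t_p) + dot(t_q, t_p) + dot(t_q, i_p)
print(fused_score, expanded)  # 2.5 2.5
```

Since all four terms enter with the same weight, a dedicated trainee would be needed to scale them separately.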

config_class

alias of CLIPConfig

base_model_prefix = 'clip'

forward(*args, return_dict=True, return_loss=False, **kwargs)

class meerqat.models.mm.ECAEncoder(config, init_weights_like_bert=False)

Bases: PreTrainedModel

Text and image are fused by concatenating them at the sequence level and then feeding them to BERT, à la UNITER [1]:
  • one face ≃ one token

  • one image ≃ one token

The multimodal representation is obtained from the “[CLS]” token.

When gating is enabled (see MMConfig), it is applied before the attention layer, unlike in Flamingo [2].

References

config_class

alias of MMConfig

load_tf_weights = None

base_model_prefix = 'bert_model'

forward(text_inputs, face_inputs, image_inputs, output_attentions=False, output_hidden_states=False, return_dict=True)
Parameters:
  • text_inputs (dict[str, torch.LongTensor]) – usual BERT inputs, see transformers.BertModel

  • face_inputs (dict[str, torch.FloatTensor]) –

    {
        "face": (batch_size, n_images, n_faces, face_dim),
        "bbox": (batch_size, n_images, n_faces, bbox_dim),
        "attention_mask": (batch_size, n_images, n_faces)
    }

  • image_inputs (dict[str, dict[str, torch.FloatTensor]]) – one entry per image model, keyed by model name (e.g. "clip-RN50"):

    {
        model: {
            "input": (batch_size, n_images, image_dim),
            "attention_mask": (batch_size, n_images)
        }
    }
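A sketch of the sequence-level concatenation from the attention mask's point of view, for a single example. The helper eca_attention_mask is hypothetical, as are the text, faces, images order and the choice of one token per image model; face_and_image_are_exclusive is not modelled.

```python
from typing import Dict, List

Mask = List[int]


def eca_attention_mask(
    text_mask: Mask,
    face_masks: List[Mask],
    image_masks: Dict[str, Mask],
) -> Mask:
    """Concatenate masks for one example: text tokens, then faces, then images.

    face_masks has shape (n_images, n_faces); each image_masks[model] has
    shape (n_images,). The text -> faces -> images order is an assumption.
    """
    mask = list(text_mask)
    for faces_of_one_image in face_masks:
        mask.extend(faces_of_one_image)  # one face ~ one token
    for model in sorted(image_masks):    # sorted only to make the order deterministic
        mask.extend(image_masks[model])  # one image ~ one token (per image model)
    return mask


text = [1, 1, 1, 0]     # 3 real tokens + 1 padding token
faces = [[1, 1, 0, 0]]  # n_images=1, n_faces=4, 2 faces detected
images = {"clip-RN50": [1], "imagenet-RN50": [1]}
mask = eca_attention_mask(text, faces, images)
print(mask)       # [1, 1, 1, 0, 1, 1, 0, 0, 1, 1]
print(len(mask))  # 10 = 4 text + 1 * 4 faces + 2 image models * 1 image
```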

class meerqat.models.mm.ILFConfig(*args, question_encoder=True, **kwargs)

Bases: MMConfig

Same as MMConfig, with one extra parameter:

  • question_encoder (bool, optional) – Whether to use DPRQuestionEncoder (default) or DPRContextEncoder. This makes no real difference to the architecture; only the name changes.

class meerqat.models.mm.IntermediateLinearFusion(config)

Bases: PreTrainedModel

Fuses DPR’s text representation with image embeddings by projecting them linearly into the same space.

config_class

alias of ILFConfig

load_tf_weights = None

base_model_prefix = 'dpr_encoder'

forward(text_inputs, face_inputs, image_inputs)
Parameters:
  • text_inputs (dict[str, torch.LongTensor]) – usual BERT inputs, see transformers.DPRQuestionEncoder

  • face_inputs (dict[str, torch.FloatTensor]) –

    {
        "face": (batch_size, n_faces, face_dim),
        "bbox": (batch_size, n_faces, bbox_dim),
        "attention_mask": (batch_size, n_faces)
    }

  • image_inputs (dict[str, dict[str, torch.FloatTensor]]) – one entry per image model, keyed by model name (e.g. "clip-RN50"):

    {
        model: {
            "input": (batch_size, image_dim),
            "attention_mask": (batch_size,)
        }
    }
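A minimal sketch of linear fusion, assuming that both the text representation and an image embedding are projected to a common size by a linear layer and then summed. Whether IntermediateLinearFusion also projects the text, and how it combines the projections (sum, concatenation, masking by attention_mask), is not stated here, so treat linear_fusion as hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # (out_dim, in_dim)


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    """y = W x + b."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def linear_fusion(text: Vector, image: Vector,
                  w_text: Matrix, b_text: Vector,
                  w_image: Matrix, b_image: Vector) -> Vector:
    """Project both modalities to the same size, then sum them (assumed)."""
    t = linear(w_text, b_text, text)
    i = linear(w_image, b_image, image)
    return [a + b for a, b in zip(t, i)]


text = [1.0, 2.0, 3.0]                       # stand-in for DPR's output (size 3 here)
image = [0.5, 0.5]                           # stand-in for an image embedding (size 2)
w_text = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # projects 3 -> 2
w_image = [[2.0, 0.0], [0.0, -2.0]]          # projects 2 -> 2
print(linear_fusion(text, image, w_text, [0.0, 0.0], w_image, [0.0, 0.0]))  # [2.0, 1.0]
```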