Source code for meerqat.models.vilt

# -*- coding: utf-8 -*-
"""Copied from transformers because not accessible otherwise."""

# coding=utf-8
# Copyright 2022 NAVER AI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViLT model."""

import collections.abc

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.vilt import ViltLayer
from transformers.modeling_outputs import BaseModelOutput


class ViltEmbeddings(nn.Module):
    """
    Construct the text and patch embeddings.

    Text embeddings are equivalent to BERT embeddings. Patch embeddings are equivalent to ViT embeddings.
    """

    def __init__(self, config):
        super().__init__()

        # text embeddings
        self.text_embeddings = TextEmbeddings(config)
        # patch embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ViltPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # one position per patch, plus one leading slot for the [CLS] token
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        # modality type (text/patch) embeddings
        self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
        """
        Embed images into a fixed-length sequence of patch embeddings.

        Args:
            pixel_values: image batch forwarded as-is to `self.patch_embeddings`.
            pixel_mask: 3-D mask (batch, height, width); non-zero entries mark valid (non-padded) pixels.
            max_image_length: maximum number of patches kept per image. `None`, a negative value or a
                non-`int` value means "no limit": the largest valid patch count in the batch is used.

        Returns:
            A tuple `(x, x_mask, (patch_index, (height, width)))` where `x` has shape
            (batch, 1 + n, channels) with the [CLS] embedding first, `x_mask` has shape (batch, 1 + n),
            `patch_index` has shape (batch, n, 2) and holds the (row, column) grid index of every selected
            patch, and `(height, width)` is the patch-grid size produced by `self.patch_embeddings`.
        """
        x = self.patch_embeddings(pixel_values)
        # Downsample the pixel mask to the patch grid.
        x_mask = pixel_mask[:, None, :, :].float()
        x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
        # Number of valid patch rows / columns per image, read from the first column / first row of the mask.
        x_h = x_mask[:, 0].sum(dim=1)[:, 0]
        x_w = x_mask[:, 0].sum(dim=2)[:, 0]

        batch_size, num_channels, height, width = x.shape
        patch_dim = self.config.image_size // self.config.patch_size
        spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
        # Resize the position grid to each image's valid area, then zero-pad it to the full patch grid.
        pos_embed = torch.cat(
            [
                nn.functional.pad(
                    nn.functional.interpolate(
                        spatial_pos,
                        size=(h, w),
                        mode="bilinear",
                        align_corners=True,
                    ),
                    (0, width - w, 0, height - h),
                )
                for h, w in zip(x_h, x_w)
            ],
            dim=0,
        )

        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        x = x.flatten(2).transpose(1, 2)
        # Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
        patch_index = torch.stack(
            torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
        ).to(device=x_mask.device)
        patch_index = patch_index[None, None, :, :, :]
        patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
        patch_index = patch_index.flatten(1, 3)
        x_mask = x_mask.flatten(1)

        # suppose aug is 800 x 1333, then, maximum effective res is 800 x 1333 (if one side gets bigger, the other
        # will be constrained and be shrunk)
        # (800 // self.patch_size) * (1333 // self.patch_size) is the maximum number of patches that single image
        # can get.
        # if self.patch_size = 32, 25 * 41 = 1025
        # if res is 384 x 640, 12 * 20 = 240
        effective_resolution = x_h * x_w
        # `None` and non-int values must be ruled out *before* comparing with 0: `None < 0` raises a TypeError.
        no_limit = max_image_length is None or not isinstance(max_image_length, int) or max_image_length < 0
        if no_limit:
            max_image_length = effective_resolution.max()
        else:
            max_image_length = min(effective_resolution.max(), max_image_length)

        valid_idx = x_mask.nonzero(as_tuple=False)
        non_valid_idx = (1 - x_mask).nonzero(as_tuple=False)
        unique_rows = valid_idx[:, 0].unique()
        valid_row_idx = [valid_idx[valid_idx[:, 0] == u] for u in unique_rows]
        non_valid_row_idx = [non_valid_idx[non_valid_idx[:, 0] == u] for u in unique_rows]

        valid_nums = [v.size(0) for v in valid_row_idx]
        non_valid_nums = [v.size(0) for v in non_valid_row_idx]
        pad_nums = [max_image_length - v for v in valid_nums]

        select = []
        for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
            if p <= 0:
                # Enough valid patches: sample `max_image_length` of them without replacement.
                valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
                select.append(valid_row_idx[i][valid_choice])
            else:
                # Too few valid patches: keep all of them and fill up with (re-sampled) padded patches.
                pad_choice = torch.multinomial(torch.ones(nv).float(), p, replacement=True)
                select.append(torch.cat([valid_row_idx[i], non_valid_row_idx[i][pad_choice]], dim=0))

        select = torch.cat(select, dim=0)
        x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
        x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
        # `patch_index` should be on the same device as `select` (for torch>=1.13), which is ensured at definition
        # time.
        patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
        pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)

        # Prepend the [CLS] token and its position embedding.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = torch.cat(
            (self.position_embeddings[:, 0, :][:, None, :].expand(batch_size, -1, -1), pos_embed), dim=1
        )
        x = x + pos_embed
        x = self.dropout(x)

        # The [CLS] position is always attended to.
        x_mask = torch.cat([torch.ones(x_mask.shape[0], 1).to(x_mask), x_mask], dim=1)

        return x, x_mask, (patch_index, (height, width))

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        pixel_values,
        pixel_mask,
        inputs_embeds,
        image_embeds,
        image_token_type_idx=1,
    ):
        """
        Embed text and image inputs and concatenate them along the sequence dimension.

        Args:
            input_ids, token_type_ids, inputs_embeds: forwarded to `self.text_embeddings`.
            attention_mask: text attention mask; concatenated with the image mask in the returned `masks`.
            pixel_values, pixel_mask: forwarded to `visual_embed` when `image_embeds` is `None`.
            image_embeds: optional precomputed image embeddings. When given, `visual_embed` is skipped and
                `pixel_mask.flatten(1)` is used as the image mask.
            image_token_type_idx: modality-type index added to the image embeddings (`None` is treated as 1).

        Returns:
            A tuple `(embeddings, masks)`, text first then image, both concatenated along dimension 1.
        """
        # PART 1: text embeddings
        text_embeds = self.text_embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # PART 2: patch embeddings (with interpolated position encodings)
        if image_embeds is None:
            image_embeds, image_masks, patch_index = self.visual_embed(
                pixel_values, pixel_mask, max_image_length=self.config.max_image_length
            )
        else:
            image_masks = pixel_mask.flatten(1)

        # PART 3: add modality type embeddings
        # 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
        if image_token_type_idx is None:
            image_token_type_idx = 1
        text_embeds = text_embeds + self.token_type_embeddings(
            torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
        )
        image_embeds = image_embeds + self.token_type_embeddings(
            torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
        )

        # PART 4: concatenate
        embeddings = torch.cat([text_embeds, image_embeds], dim=1)
        masks = torch.cat([attention_mask, image_masks], dim=1)

        return embeddings, masks
class TextEmbeddings(nn.Module):
    """Build text embeddings as the sum of word, position and token-type embeddings."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # `LayerNorm` deliberately keeps its TensorFlow-style (non snake-case) name so that TensorFlow
        # checkpoint files can still be loaded.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)
        # All-zero default token types; registered as non-persistent.
        default_token_types = torch.zeros(self.position_ids.size(), dtype=torch.long)
        self.register_buffer("token_type_ids", default_token_types, persistent=False)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """
        Embed a batch of text.

        Either `input_ids` or `inputs_embeds` must be given. Missing `position_ids` default to the first
        `seq_length` registered positions; missing `token_type_ids` default to zeros.

        Returns:
            The layer-normalised, dropped-out sum of the word (or provided) embeddings, the token-type
            embeddings and, when `position_embedding_type` is "absolute", the position embeddings.
        """
        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # Fall back to the all-zero buffer registered in the constructor: it helps users trace the model
        # without passing token_type_ids (issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)

        return self.dropout(self.LayerNorm(embeddings))
class ViltPatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.

    Projects an image to a grid of patch embeddings with a single strided convolution whose kernel size
    and stride both equal the patch size.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Scalars are duplicated into a pair; iterables are used unchanged.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        # Number of whole patches that fit along each axis, multiplied together.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values):
        """
        Project `pixel_values` (a 4-D tensor whose second dimension is the channel dimension) to patches.

        Raises:
            ValueError: if the channel dimension differs from the configured `num_channels`.
        """
        _, num_channels, _, _ = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        return self.projection(pixel_values)
class ViltEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `ViltLayer` modules applied one after the other."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run `hidden_states` through every layer in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: passed unchanged to every layer.
            head_mask: optional; when given, `head_mask[i]` is passed to layer `i`.
            output_attentions: also collect the second output of every layer.
            output_hidden_states: also collect the input of every layer plus the final output.
            return_dict: return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            A `BaseModelOutput`, or (when `return_dict` is false) a tuple of the non-`None` values among
            the last hidden state, all hidden states and all attentions.
        """

        def _checkpointable(module):
            # `checkpoint` only forwards positional tensors, so `output_attentions` is bound here.
            def _run(*inputs):
                return module(*inputs, output_attentions)

            return _run

        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        for index, block in enumerate(self.layer):
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            block_head_mask = None if head_mask is None else head_mask[index]

            if self.gradient_checkpointing and self.training:
                block_outputs = torch.utils.checkpoint.checkpoint(
                    _checkpointable(block),
                    hidden_states,
                    attention_mask,
                    block_head_mask,
                )
            else:
                block_outputs = block(hidden_states, attention_mask, block_head_mask, output_attentions)

            hidden_states = block_outputs[0]

            if output_attentions:
                attention_trace.append(block_outputs[1])

        if output_hidden_states:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attentions = None if attention_trace is None else tuple(attention_trace)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class ViltPooler(nn.Module):
    """Pool a sequence by passing its first token through a linear layer followed by tanh."""

    def __init__(self, config):
        super().__init__()
        size = config.hidden_size
        self.dense = nn.Linear(size, size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return `tanh(dense(hidden_states[:, 0]))`, i.e. the transformed hidden state of the first token."""
        # "Pooling" here simply means selecting the hidden state at sequence position 0.
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))