# -*- coding: utf-8 -*-
"""Copied from transformers because not accessible otherwise."""
# coding=utf-8
# Copyright 2022 NAVER AI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViLT model."""
import collections.abc
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.vilt import ViltLayer
from transformers.modeling_outputs import BaseModelOutput


class ViltEmbeddings(nn.Module):
"""
Construct the text and patch embeddings.
Text embeddings are equivalent to BERT embeddings.
Patch embeddings are equivalent to ViT embeddings.
"""
def __init__(self, config):
super().__init__()
# text embeddings
self.text_embeddings = TextEmbeddings(config)
# patch embeddings
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = ViltPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
# modality type (text/patch) embeddings
self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config

    def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
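        """Embed `pixel_values` into a variable-length sequence of patch tokens.

        Interpolates the learned position embeddings to each image's valid patch grid,
        selects a fixed number of patch positions per image (at most `max_image_length`),
        and returns the patch embeddings, the matching patch-level mask, and the
        (row, col) index of every kept patch together with the padded patch-grid size.
        """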
_, _, ph, pw = self.patch_embeddings.projection.weight.shape
x = self.patch_embeddings(pixel_values)
x_mask = pixel_mask[:, None, :, :].float()
x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
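        # x_h / x_w: number of valid patch rows / columns per image, read off the down-sampled pixel mask.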
x_h = x_mask[:, 0].sum(dim=1)[:, 0]
x_w = x_mask[:, 0].sum(dim=2)[:, 0]
batch_size, num_channels, height, width = x.shape
patch_dim = self.config.image_size // self.config.patch_size
spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
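        # Interpolate the learned position embeddings to each image's valid patch grid and zero-pad to the full feature-map size.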
pos_embed = torch.cat(
[
nn.functional.pad(
nn.functional.interpolate(
spatial_pos,
size=(h, w),
mode="bilinear",
align_corners=True,
),
(0, width - w, 0, height - h),
)
for h, w in zip(x_h, x_w)
],
dim=0,
)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
x = x.flatten(2).transpose(1, 2)
# Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
patch_index = torch.stack(
torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
).to(device=x_mask.device)
patch_index = patch_index[None, None, :, :, :]
patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
patch_index = patch_index.flatten(1, 3)
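        # patch_index now holds the (row, col) grid coordinate of every patch position, flattened to align with x and x_mask.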
x_mask = x_mask.flatten(1)
        if max_image_length is None or not isinstance(max_image_length, int) or max_image_length < 0:
            # Suppose the augmented image is 800 x 1333; the maximum effective resolution is then 800 x 1333
            # (if one side gets bigger, the other is constrained and shrinks).
            # (800 // self.patch_size) * (1333 // self.patch_size) is the maximum number of patches a single image can yield:
            # if self.patch_size = 32, 25 * 41 = 1025
            # if the resolution is 384 x 640, 12 * 20 = 240
effective_resolution = x_h * x_w
max_image_length = effective_resolution.max()
else:
effective_resolution = x_h * x_w
max_image_length = min(effective_resolution.max(), max_image_length)
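        # Select a fixed number of patch positions per image: randomly subsample the valid patches when there
        # are too many, otherwise keep them all and pad with randomly repeated masked positions.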
valid_idx = x_mask.nonzero(as_tuple=False)
non_valid_idx = (1 - x_mask).nonzero(as_tuple=False)
unique_rows = valid_idx[:, 0].unique()
valid_row_idx = [valid_idx[valid_idx[:, 0] == u] for u in unique_rows]
non_valid_row_idx = [non_valid_idx[non_valid_idx[:, 0] == u] for u in unique_rows]
valid_nums = [v.size(0) for v in valid_row_idx]
non_valid_nums = [v.size(0) for v in non_valid_row_idx]
pad_nums = [max_image_length - v for v in valid_nums]
select = list()
for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
if p <= 0:
valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
select.append(valid_row_idx[i][valid_choice])
else:
pad_choice = torch.multinomial(torch.ones(nv).float(), p, replacement=True)
select.append(torch.cat([valid_row_idx[i], non_valid_row_idx[i][pad_choice]], dim=0))
select = torch.cat(select, dim=0)
x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
# `patch_index` should be on the same device as `select` (for torch>=1.13), which is ensured at definition time.
patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
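        # Prepend the learnable [CLS] token and its position embedding, then add position information and apply dropout.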
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
pos_embed = torch.cat(
(self.position_embeddings[:, 0, :][:, None, :].expand(batch_size, -1, -1), pos_embed), dim=1
)
x = x + pos_embed
x = self.dropout(x)
x_mask = torch.cat([torch.ones(x_mask.shape[0], 1).to(x_mask), x_mask], dim=1)
return x, x_mask, (patch_index, (height, width))

    def forward(
self,
input_ids,
attention_mask,
token_type_ids,
pixel_values,
pixel_mask,
inputs_embeds,
image_embeds,
image_token_type_idx=1,
):
# PART 1: text embeddings
text_embeds = self.text_embeddings(
input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
# PART 2: patch embeddings (with interpolated position encodings)
if image_embeds is None:
image_embeds, image_masks, patch_index = self.visual_embed(
pixel_values, pixel_mask, max_image_length=self.config.max_image_length
)
else:
image_masks = pixel_mask.flatten(1)
# PART 3: add modality type embeddings
# 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
if image_token_type_idx is None:
image_token_type_idx = 1
text_embeds = text_embeds + self.token_type_embeddings(
torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
)
image_embeds = image_embeds + self.token_type_embeddings(
torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
)
# PART 4: concatenate
embeddings = torch.cat([text_embeds, image_embeds], dim=1)
masks = torch.cat([attention_mask, image_masks], dim=1)
return embeddings, masks


class TextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
        # If token_type_ids was not passed, fall back to the all-zeros buffer registered in the constructor.
        # The registered buffer helps users trace the model without passing token_type_ids (solves issue #5664).
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings


class ViltPatchEmbeddings(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
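        # Non-overlapping convolution (kernel_size == stride == patch_size): the output is a
        # (batch_size, hidden_size, height // patch_size, width // patch_size) grid of patch embeddings.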
x = self.projection(pixel_values)
return x


class ViltEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

    def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)


class ViltPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

    def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output