"""Utility functions specific to Question Answering."""
import warnings
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import (
BertForQuestionAnswering, ViltPreTrainedModel, ViltModel
)
from transformers.modeling_outputs import BaseModelOutputWithPooling
from ..train.optim import multi_passage_rc_loss
from .outputs import ReaderOutput
from .vilt import ViltEmbeddings, ViltPooler, ViltEncoder
from .mm import ECAEncoder


def get_best_spans(start_probs, end_probs, weights=None, cannot_be_first_token=True):
"""
Get the best scoring spans from start and end probabilities
notations:
- N - number of distinct questions
- M - number of passages per question in a batch
- L - sequence length
Parameters
----------
start_probs, end_probs: Tensor
shape (N, M, L)
weights: Tensor, optional
shape (N, M)
Used to weigh the spans scores, e.g. might be BM25 scores from the retriever
cannot_be_first_token: bool, optional
        If True (default), null out the scores of spans starting at the first token
        (e.g. "[CLS]", which is used during training for irrelevant passages)
Returns
-------
passage_indices: Tensor
shape (N, )
start_indices, end_indices: Tensor
shape (N, )
start (inclusive) and end (exclusive) index of each span
"""
N, M, L = start_probs.shape
# 1. compute pairwise scores -> shape (N, M, L, L)
pairwise = start_probs.reshape(N, M, L, 1) @ end_probs.reshape(N, M, 1, L)
# fix scores where end < start
pairwise = torch.triu(pairwise)
# null out the scores of start in the first token (and thus end because of the upper triangle)
# (e.g. [CLS], used during training for irrelevant passages)
if cannot_be_first_token:
pairwise[:, :, 0, :] = 0
    # optionally weigh the scores (e.g. with retriever scores such as BM25)
    if weights is not None:
        minimum = weights.min()
        if minimum < 1:
            warnings.warn("weights should be >= 1; shifting them by (1 - minimum)")
            # out-of-place so the caller's tensor is not modified
            weights = weights + (1 - minimum)
        pairwise *= weights.reshape(N, M, 1, 1)
# 2. find the passages with the maximum score
pairwise = pairwise.reshape(N, M, L * L)
    max_per_passage = pairwise.max(dim=2).values
    passage_indices = max_per_passage.argmax(dim=1)
pairwise_best_passages = pairwise[torch.arange(N), passage_indices]
# 3. finally find the best spans for each question
    flat_argmaxes = pairwise_best_passages.argmax(dim=-1)
# convert from flat argmax to line index (start) and column index (end)
start_indices = torch.div(flat_argmaxes, L, rounding_mode='floor')
# add +1 to make end index exclusive so the spans can easily be used with slices
end_indices = (flat_argmaxes % L) + 1
return passage_indices, start_indices, end_indices
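
# The sketch below is not part of the original module: it only illustrates how
# ``get_best_spans`` is meant to be called. All tensors are random
# placeholders; real probabilities would come from a reader's softmaxed
# start/end logits reshaped to (N, M, L).
def _get_best_spans_sketch(N=2, M=3, L=5):
    start_probs = torch.rand(N, M, L).softmax(dim=-1)
    end_probs = torch.rand(N, M, L).softmax(dim=-1)
    # e.g. BM25 scores from the retriever; shifted above 1 inside the function
    weights = torch.rand(N, M)
    passage_indices, start_indices, end_indices = get_best_spans(
        start_probs, end_probs, weights=weights
    )
    # one best (passage, span) triple per question; the end index is exclusive
    assert passage_indices.shape == start_indices.shape == end_indices.shape == (N,)
    return passage_indices, start_indices, end_indices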


class MultiPassageBERT(BertForQuestionAnswering):
"""
    PyTorch/Transformers implementation of Multi-passage BERT [1]_ (based on global normalization [2]_),
    i.e. passages are grouped per question before computing the softmax (and the NLL loss),
    so that span scores are comparable across passages.
Code based on transformers.BertForQuestionAnswering, dpr.models.Reader
and https://github.com/allenai/document-qa/blob/master/docqa/nn/span_prediction.py
References
----------
.. [1] Zhiguo Wang, Patrick Ng, Xiaofei Ma, Ramesh Nallapati, and Bing Xiang.
2019. Multi-passage BERT: A Globally Normalized BERT Model for Open-domain Question Answering.
In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing
and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP),
pages 5878–5882, Hong Kong, China. Association for Computational Linguistics.
.. [2] Christopher Clark and Matt Gardner. 2018. Simple and Effective Multi-Paragraph Reading Comprehension.
In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers),
pages 845–855, Melbourne, Australia. Association for Computational Linguistics.
"""
def __init__(self, *args, fuse_ir_score=False, **kwargs):
super().__init__(*args, **kwargs)
self.fuse_ir_score = fuse_ir_score
if fuse_ir_score:
            # easier than overriding BERT weight initialization
self.score_proj_w = nn.Parameter(torch.ones((1,1)))
self.score_proj_b = nn.Parameter(torch.zeros(1))
self.weights_to_log = {
"score_proj_w": self.score_proj_w,
"score_proj_b": self.score_proj_b
}
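
    # Sketch of the intended fusion (see forward below): each retrieval score s
    # goes through a learned affine map w * s + b, initialized to the identity
    # (w=1, b=0), and is added to both start and end logits so the span softmax
    # can trade off reader confidence against retriever confidence.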

    def forward(self,
input_ids, passage_scores=None,
start_positions=None, end_positions=None, answer_mask=None,
return_dict=None, **kwargs):
"""
notations:
* N - number of distinct questions
* M - number of passages per question in a batch
* L - sequence length
Parameters
----------
input_ids: Tensor[int]
shape (N * M, L)
There should always be a constant number of passages (relevant or not) per question
passage_scores: FloatTensor, optional
shape (N * M, )
If self.fuse_ir_score, will be fused with start_logits and end_logits before computing loss
start_positions, end_positions: Tensor[int], optional
shape (N, M, max_n_answers)
            The answer may occur several times in the same passage; up to ``max_n_answers`` occurrences are stored.
            Defaults to None (i.e. the loss is not computed)
        answer_mask: Tensor[int], optional
            shape (N, M, max_n_answers)
            Masks the loss for padded positions when an answer occurs fewer than ``max_n_answers`` times in a passage.
            Required if start_positions and end_positions are specified.
        **kwargs: additional arguments are passed to BERT after being reshaped like ``input_ids`` (i.e. to shape (N * M, L))
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(input_ids, return_dict=True, **kwargs)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
if self.fuse_ir_score:
passage_scores = passage_scores.unsqueeze(1) @ self.score_proj_w + self.score_proj_b
start_logits += passage_scores
end_logits += passage_scores
# compute loss
if start_positions is not None and end_positions is not None:
pack = multi_passage_rc_loss(
input_ids,
start_positions,
end_positions,
start_logits,
end_logits,
answer_mask
)
            # unpack so that the line is not hundreds of columns long
total_loss, start_positions, end_positions, start_logits, end_logits, start_log_probs, end_log_probs = pack
else:
total_loss, start_log_probs, end_log_probs = None, None, None
if not return_dict:
output = (start_logits, end_logits, start_log_probs, end_log_probs) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return ReaderOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
start_log_probs=start_log_probs,
end_log_probs=end_log_probs,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
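
# Hypothetical calling convention for MultiPassageBERT, illustrating the
# (N * M, L) flattening its forward pass expects. ``model``, ``tokenizer``,
# ``questions`` and ``passages`` are placeholders supplied by the caller;
# nothing is loaded here.
def _multi_passage_bert_sketch(model, tokenizer, questions, passages):
    # ``passages`` is a list of N lists of M passage strings:
    # each question is paired with each of its M passages
    pairs = [(q, p) for q, ps in zip(questions, passages) for p in ps]
    inputs = tokenizer(
        [q for q, _ in pairs],
        [p for _, p in pairs],
        padding=True, truncation=True, return_tensors="pt"
    )
    # input_ids has shape (N * M, L): one row per question-passage pair
    outputs = model(**inputs)
    return outputs.start_logits, outputs.end_logits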


class MultiPassageECA(ECAEncoder):
"""Like MultiPassageBERT with a ECA backbone instead of BERT"""
def __init__(self, config, **kwargs):
assert not config.no_text, "no_text option is only for IR"
super().__init__(config, **kwargs)
self.fuse_ir_score = False
# like BertForQuestionAnswering
self.num_labels = config.num_labels
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()

    def forward(self, text_inputs, *args,
start_positions=None, end_positions=None, answer_mask=None,
return_dict=True, **kwargs):
input_ids = text_inputs['input_ids']
outputs = super().forward(text_inputs, *args, return_dict=return_dict, **kwargs)
# truncate to keep only text representations
# the answer is extracted from text and the sequence length must match start/end positions shape (L)
sequence_output = outputs.last_hidden_state[:, :input_ids.shape[1]]
# same as MultiPassageBERT
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
# compute loss
if start_positions is not None and end_positions is not None:
pack = multi_passage_rc_loss(
input_ids,
start_positions,
end_positions,
start_logits,
end_logits,
answer_mask
)
            # unpack so that the line is not hundreds of columns long
total_loss, start_positions, end_positions, start_logits, end_logits, start_log_probs, end_log_probs = pack
else:
total_loss, start_log_probs, end_log_probs = None, None, None
if not return_dict:
output = (start_logits, end_logits, start_log_probs, end_log_probs) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return ReaderOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
start_log_probs=start_log_probs,
end_log_probs=end_log_probs,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
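
# Shape note for the truncation used by MultiPassageECA above (and by
# MultiPassageVilt below): multimodal encoders append image-patch embeddings
# after the text tokens, but answers are extracted from text only, so the QA
# head must see a sequence whose length matches the start/end position labels.
# A toy check with assumed sizes, involving no real model:
def _text_truncation_sketch(n_pairs=6, n_tokens=4, n_patches=3, hidden=8):
    input_ids = torch.ones(n_pairs, n_tokens, dtype=torch.long)
    last_hidden_state = torch.rand(n_pairs, n_tokens + n_patches, hidden)
    # same slicing as in the forward passes: keep only text positions
    sequence_output = last_hidden_state[:, :input_ids.shape[1]]
    assert sequence_output.shape == (n_pairs, n_tokens, hidden)
    return sequence_output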


class ViltMultiImageEmbeddings(ViltEmbeddings):
"""
Similar to the 'triplet' strategy of UNITER,
patches of multiple images are concatenated in the sequence dimension.
    The resulting embeddings thus have a sequence length of #tokens + num_patches * num_images.
"""

    def forward(
self,
input_ids,
attention_mask,
token_type_ids,
pixel_values,
pixel_mask,
passage_pixel_values,
passage_pixel_mask,
inputs_embeds
):
"""
Parameters
----------
input_ids (`torch.LongTensor` of shape `(batch_size, #tokens)`):
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, #tokens)`):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
token_type_ids (`torch.LongTensor` of shape `(batch_size, #tokens)`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
pixel_values, passage_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`ViltFeatureExtractor`]. See
[`ViltFeatureExtractor.__call__`] for details.
pixel_mask, passage_pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, #tokens, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
"""
# PART 1: text embeddings
text_embeds = self.text_embeddings(
input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
# PART 2: patch embeddings (with interpolated position encodings)
image_embeds, image_masks, patch_index = self.visual_embed(
pixel_values, pixel_mask, max_image_length=self.config.max_image_length
)
passage_image_embeds, passage_image_masks, _ = self.visual_embed(
passage_pixel_values, passage_pixel_mask, max_image_length=self.config.max_image_length
)
# PART 3: add modality type embeddings
# 0 indicates text, 1 question image, 2 passage image
text_embeds = text_embeds + self.token_type_embeddings(
torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
)
image_embeds = image_embeds + self.token_type_embeddings(
torch.full_like(image_masks, 1, dtype=torch.long, device=text_embeds.device)
)
passage_image_embeds = passage_image_embeds + self.token_type_embeddings(
torch.full_like(passage_image_masks, 2, dtype=torch.long, device=text_embeds.device)
)
# PART 4: concatenate
embeddings = torch.cat([text_embeds, image_embeds, passage_image_embeds], dim=1)
masks = torch.cat([attention_mask, image_masks, passage_image_masks], dim=1)
return embeddings, masks
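
# Shape sketch for ViltMultiImageEmbeddings (illustrative only): with T text
# tokens and P patches per image, the concatenated sequence has length
# T + 2 * P, after modality type embeddings 0 (text), 1 (question image) and
# 2 (passage image) have been added to each part.
def _multi_image_concat_sketch(batch=2, T=4, P=3, hidden=8):
    text_embeds = torch.rand(batch, T, hidden)
    image_embeds = torch.rand(batch, P, hidden)
    passage_image_embeds = torch.rand(batch, P, hidden)
    embeddings = torch.cat([text_embeds, image_embeds, passage_image_embeds], dim=1)
    assert embeddings.shape == (batch, T + 2 * P, hidden)
    return embeddings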


class ViltMultiImageModel(ViltModel):
"""Same as ViltModel with ViltMultiImageEmbeddings instead of ViltEmbeddings"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = ViltMultiImageEmbeddings(config)
self.encoder = ViltEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = ViltPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
passage_pixel_values: Optional[torch.FloatTensor] = None,
passage_pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
image_token_type_idx: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[BaseModelOutputWithPooling, Tuple[torch.FloatTensor]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
text_batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
            attention_mask = torch.ones((text_batch_size, seq_length), device=device)
if pixel_values is not None and image_embeds is not None:
raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
elif pixel_values is None and image_embeds is None:
raise ValueError("You have to specify either pixel_values or image_embeds")
image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
if image_batch_size != text_batch_size:
raise ValueError("The text inputs and image inputs need to have the same batch size")
if pixel_mask is None:
pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, attention_mask = self.embeddings(
input_ids,
attention_mask,
token_type_ids,
pixel_values,
pixel_mask,
passage_pixel_values,
passage_pixel_mask,
inputs_embeds
)
        # broadcast attention mask to all heads. N.B. input_shape is only used for decoder models
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


# TODO: alternatively, subclass ViltForImagesAndTextClassification and feed it two text-image pairs
class MultiPassageVilt(ViltPreTrainedModel):
"""Like MultiPassageBERT with a ViLT backbone instead of BERT"""
def __init__(self, config, add_pooling_layer=False):
super().__init__(config)
self.vilt = ViltMultiImageModel(config, add_pooling_layer=add_pooling_layer)
self.fuse_ir_score = False
# like BertForQuestionAnswering
self.num_labels = config.num_labels
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()

    def forward(self, input_ids, *args,
start_positions=None, end_positions=None, answer_mask=None,
return_dict=True, **kwargs):
outputs = self.vilt(input_ids, *args, return_dict=return_dict, **kwargs)
sequence_output = outputs[0]
# truncate to keep only text representations
# the answer is extracted from text and the sequence length must match start/end positions shape (L)
sequence_output = sequence_output[:, :input_ids.shape[1]]
# same as MultiPassageBERT
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
# compute loss
if start_positions is not None and end_positions is not None:
pack = multi_passage_rc_loss(
input_ids,
start_positions,
end_positions,
start_logits,
end_logits,
answer_mask
)
            # unpack so that the line is not hundreds of columns long
total_loss, start_positions, end_positions, start_logits, end_logits, start_log_probs, end_log_probs = pack
else:
total_loss, start_log_probs, end_log_probs = None, None, None
if not return_dict:
output = (start_logits, end_logits, start_log_probs, end_log_probs) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return ReaderOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
start_log_probs=start_log_probs,
end_log_probs=end_log_probs,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
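
# End-to-end decoding sketch tying the readers above back to ``get_best_spans``.
# This is an assumption about intended usage, not part of the original API:
# ``outputs`` stands for a ReaderOutput over N questions with M passages each,
# and ``tokenizer`` for the matching HuggingFace tokenizer. Note that training
# uses globally normalized log-probabilities; the per-passage softmax below is
# a simple inference-time approximation.
def _decode_answers_sketch(outputs, input_ids, tokenizer, N, M):
    L = input_ids.shape[-1]
    start_probs = outputs.start_logits.softmax(dim=-1).reshape(N, M, L)
    end_probs = outputs.end_logits.softmax(dim=-1).reshape(N, M, L)
    passage_indices, start_indices, end_indices = get_best_spans(start_probs, end_probs)
    input_ids = input_ids.reshape(N, M, L)
    answers = []
    for i in range(N):
        # end index is exclusive, so the slice maps directly onto the span
        span = input_ids[i, passage_indices[i], start_indices[i]:end_indices[i]]
        answers.append(tokenizer.decode(span, skip_special_tokens=True))
    return answers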