Source code for meerqat.train.optim

"""Loss functions, optimizers, and schedulers."""
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR


class LinearLRWithWarmup(LambdaLR):
    """
    Linear learning rate scheduler with linear warmup.

    The factor multiplying the base learning rate rises linearly from 0 to 1
    during the first ``warmup_steps`` steps. It then decreases linearly, reaches
    0 at ``total_steps``, and stays at 0 afterwards.

    Adapted from https://github.com/huggingface/transformers/blob/v4.23.0/src/transformers/optimization.py#L75

    Parameters
    ----------
    *args, **kwargs:
        additional arguments are passed to LambdaLR (e.g. the optimizer).
        ``lr_lambda`` is provided by this class and must not be passed.
    warmup_steps: int
        number of warmup steps
    total_steps: int
        step at which the learning rate reaches 0
    """

    def __init__(self, *args, warmup_steps, total_steps, **kwargs):
        # Stored before calling LambdaLR.__init__, which already evaluates
        # lr_lambda to set the initial learning rate.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super().__init__(*args, lr_lambda=self.lr_lambda, **kwargs)

    def lr_lambda(self, current_step: int):
        """Return the multiplicative learning-rate factor for ``current_step``."""
        warmup, total = self.warmup_steps, self.total_steps
        if current_step < warmup:
            # linear warmup; max(1, .) avoids a division by zero
            return float(current_step) / float(max(1, warmup))
        # linear decay after warmup, clipped at 0 once total_steps is reached
        decay_length = float(max(1, total - warmup))
        return max(0.0, float(total - current_step) / decay_length)
def _calc_mml(loss_tensor):
    """
    Marginal maximum likelihood (MML) loss over candidate answer spans.

    Adapted from dpr.models.reader to avoid an extra dependency, but computed
    with ``torch.logsumexp`` for numerical stability.

    Parameters
    ----------
    loss_tensor: Tensor
        2D tensor of per-span negative log-likelihoods. Each row is marginalized
        independently over its columns. Entries exactly equal to 0 are treated
        as masked out (e.g. zeroed by an answer mask) and excluded from the sum.

    Returns
    -------
    loss: Tensor
        scalar equal to minus the mean, over rows, of
        ``log(sum_j exp(-loss_tensor[i, j]))`` restricted to non-masked entries.
        Rows whose entries are all masked contribute 0 to the mean.
    """
    is_candidate = loss_tensor != 0
    # Masked entries get the most negative *finite* value, so their exp() is 0.
    # -inf would make the logsumexp gradient NaN for fully masked rows.
    scores = (-loss_tensor).masked_fill(~is_candidate, torch.finfo(loss_tensor.dtype).min)
    # logsumexp instead of log(sum(exp(.))): the naive form underflows to 0 for
    # large losses, which turned such rows into a zero loss with zero gradient.
    log_marginal_likelihood = torch.logsumexp(scores, dim=1)
    # Rows without any candidate contribute 0, like log(0 + 1) in the original DPR code.
    log_marginal_likelihood = torch.where(
        is_candidate.any(dim=1), log_marginal_likelihood, torch.zeros_like(log_marginal_likelihood)
    )
    # Mean reduction: this is different from https://github.com/facebookresearch/DPR/blob/a31212dc0a54dfa85d8bfa01e1669f149ac832b7/dpr/models/reader.py#L180
    # which uses sum reduction.
    # By averaging, the loss does not depend on the batch size N (number of questions).
    # It might still depend on M, the number of passages, if `max_pooling=True` in `multi_passage_rc_loss` (not recommended).
    return -torch.mean(log_marginal_likelihood)
def multi_passage_rc_loss(input_ids, start_positions, end_positions, start_logits, end_logits, answer_mask,
                          max_pooling=False):
    """
    Multi-passage reading-comprehension loss with a softmax shared across passages.

    The M passages of a same question are normalized together. Start and end
    logits are viewed as (N, M * L) before the log-softmax, so that each question
    gets a single distribution over all tokens of all its passages. The masked
    span losses of every answer slot are then combined with ``_calc_mml``.

    Parameters
    ----------
    input_ids: Tensor
        shape (N * M, L); only its shape and device are used.
    start_positions, end_positions: Tensor
        gold start/end token indices. ``shape[1]`` gives M, and they are viewed as
        (N * M, max_n_answers) (presumably given as (N, M, max_n_answers) — TODO confirm).
        Values are clamped to [0, L]. A position equal to L, i.e. outside the
        model input, is ignored by the loss.
    start_logits, end_logits: Tensor
        N * M * L logits, viewable as (N, M * L)
    answer_mask: Tensor
        viewed as (N * M, max_n_answers) and cast to float32 on the device of
        ``input_ids``. It multiplies the loss of each answer slot, so 0 masks it out.
    max_pooling: bool, optional
        LEGACY: if True, keep, for each question and answer slot, the maximum loss
        over the M passages. This might be used to reproduce the experiments of
        the ViQuAE paper (Lerner et al. 2022) but hurts performance, see
        https://github.com/facebookresearch/DPR/issues/244. Defaults to False.

    Returns
    -------
    total_loss: Tensor
        scalar loss computed by ``_calc_mml``
    start_positions, end_positions: Tensor
        clamped positions viewed as (N * M, max_n_answers)
    start_logits, end_logits: Tensor
        input logits viewed as (N * M, L)
    start_log_probs, end_log_probs: Tensor
        shape (N * M, L), log-probabilities normalized over the M * L tokens of each question
    """
    n_times_m, L = input_ids.shape
    M = start_positions.shape[1]
    assert n_times_m % M == 0
    N = n_times_m // M

    # Start/end positions sometimes fall outside the model inputs. Clamping sends
    # them to L, an index that NLLLoss is told to ignore.
    out_of_input = L
    start_positions = start_positions.clamp(0, out_of_input)
    end_positions = end_positions.clamp(0, out_of_input)
    nll = nn.NLLLoss(reduction='none', ignore_index=out_of_input)
    log_softmax = nn.LogSoftmax(1)

    def question_log_probs(logits):
        # (N * M, L) -> (N, M * L): all passages of a question share one normalization.
        # Then go back to (N * M, L) so the last dimension matches the position indices.
        return log_softmax(logits.view(N, M * L)).view(N * M, L)

    start_log_probs = question_log_probs(start_logits)
    end_log_probs = question_log_probs(end_logits)

    # reshape everything else to match the (N * M, ...) model outputs
    start_logits = start_logits.view(N * M, L)
    end_logits = end_logits.view(N * M, L)
    start_positions = start_positions.view(N * M, -1)
    end_positions = end_positions.view(N * M, -1)
    answer_mask = answer_mask.to(device=input_ids.device, dtype=torch.float32).view(N * M, -1)

    def slot_losses(log_probs, positions, mask):
        # One masked NLL column per answer slot (range `max_n_answers`).
        # The columns are concatenated to (N * M, max_n_answers).
        return torch.cat(
            [(nll(log_probs, slot_positions) * slot_mask).unsqueeze(1)
             for slot_positions, slot_mask in zip(positions.unbind(dim=1), mask.unbind(dim=1))],
            dim=1,
        )

    loss_tensor = (slot_losses(start_log_probs, start_positions, answer_mask)
                   + slot_losses(end_log_probs, end_positions, answer_mask))

    if max_pooling:
        # LEGACY: keep the maximum over the M passages of each question
        loss_tensor = loss_tensor.view(N, M, -1).max(dim=1).values

    total_loss = _calc_mml(loss_tensor)
    return total_loss, start_positions, end_positions, start_logits, end_logits, start_log_probs, end_log_probs