"""Misc. utility functions."""
import torch
from torch import nn
import numpy as np
from transformers.tokenization_utils_base import BatchEncoding
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TanhGate(nn.Module):
    """
    Flamingo-style tanh gating (init at 0) [1]_

    References
    ----------
    .. [1] Jean-Baptiste Alayrac et al. (2022).
       Flamingo: a Visual Language Model for Few-Shot Learning. ArXiv:2204.14198.
    """
    def __init__(self):
        super().__init__()
        self.gate_param = nn.Parameter(torch.tensor([0.]))
        self.tanh = nn.Tanh()

    def forward(self, x):
        return x * self.tanh(self.gate_param)
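
# Illustrative sketch (not part of the original module; the helper name and shapes are
# hypothetical): TanhGate multiplies its input by tanh(gate_param). Since gate_param is
# initialized at 0, the gated branch contributes nothing at the start of training and is
# progressively "opened" as gate_param is learned, as in Flamingo's gated layers.
def _demo_tanh_gate():
    gate = TanhGate()
    x = torch.randn(2, 4)  # arbitrary example shape
    gated = gate(x)
    assert torch.equal(gated, torch.zeros_like(x))  # tanh(0) == 0, so the output starts at 0
    return gated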

def map_if_not_None(values, function, *args, default_value=None, **kwargs):
    """
    Map all values that are not None through function (along with additional arguments).
    Values that are None will output ``default_value``.

    Parameters
    ----------
    values: list
        of len batch_size
    function: callable
    default_value: optional
        Defaults to None
    *args, **kwargs:
        additional arguments are passed to function

    Returns
    -------
    output: list
        of len batch_size (same as values), with ``default_value`` where values are None
    """
    # 1. filter out values that are None
    output = []
    not_None_values, not_None_values_indices = [], []
    for i, value in enumerate(values):
        # will be overwritten for not_None_values
        output.append(default_value)
        if value is not None:
            not_None_values.append(value)
            not_None_values_indices.append(i)
    if not not_None_values:
        return output

    # 2. map the values that are not None through function
    not_None_output = function(not_None_values, *args, **kwargs)

    # 3. place the results back at their original indices
    for j, i in enumerate(not_None_values_indices):
        output[i] = not_None_output[j]
    return output
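
# Illustrative sketch (not part of the original module; the helper name is hypothetical):
# map_if_not_None calls ``function`` once on the sub-list of non-None values and scatters
# the results back to their original positions, leaving ``default_value`` (here None)
# where the input was None.
def _demo_map_if_not_None():
    values = ["foo", None, "bar"]
    output = map_if_not_None(values, lambda batch: [s.upper() for s in batch])
    assert output == ["FOO", None, "BAR"]
    return output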

def debug_shape(batch, prefix=""):
    """Recursively prints the shape of Tensor and ndarray values in a nested dict/BatchEncoding/tuple/list"""
    for key, data in batch.items():
        if isinstance(data, (dict, BatchEncoding)):
            debug_shape(data, prefix=f"{prefix}.{key}")
        elif isinstance(data, (tuple, list)):
            for i, v in enumerate(data):
                if isinstance(v, (dict, BatchEncoding, tuple, list)):
                    debug_shape(v, prefix=f"{prefix}.{key}.{i}")
                elif isinstance(v, (torch.Tensor, np.ndarray)):
                    print(f"{prefix}.{key}.{i} ({type(v)}): {v.shape}")
                else:
                    print(f"{prefix}.{key}.{i} ({type(v)})")
        elif isinstance(data, (torch.Tensor, np.ndarray)):
            print(f"{prefix}.{key} ({type(data)}): {data.shape}")
        else:
            print(f"{prefix}.{key} ({type(data)})")