meerqat.models.utils module

Misc. utility functions.

class meerqat.models.utils.TanhGate

Bases: Module

Flamingo-style tanh gating (gate initialized at 0) [1].

References

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass must be defined within this function, call the Module instance rather than forward directly: calling the instance runs the registered hooks, whereas calling forward directly silently ignores them.
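As an illustration of the gating described above, here is a minimal sketch. It assumes the gate is a single learnable scalar initialized at 0 that multiplies the input after a tanh; the class name TanhGateSketch and the attribute name gate are hypothetical and are not taken from the MEERQAT source.

```python
import torch
from torch import nn


class TanhGateSketch(nn.Module):
    """Hypothetical re-implementation for illustration, not the MEERQAT code."""

    def __init__(self):
        super().__init__()
        # Assumption: one learnable scalar, initialized at 0, so tanh(gate) == 0
        # and the gated branch contributes nothing at the start of training.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Scale the input by tanh(gate), which lies in (-1, 1).
        return torch.tanh(self.gate) * x


gate = TanhGateSketch()
x = torch.randn(2, 3)
# Call the module instance (not gate.forward) so that registered hooks run.
y = gate(x)
```

With this initialization the output is all zeros until the gate parameter is updated, which is why such a gate is typically placed on a newly added branch of a pretrained model.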

meerqat.models.utils.map_if_not_None(values, function, *args, default_value=None, **kwargs)

Maps every value that is not None through function (along with additional arguments).

Values that are None yield default_value in the output.

Parameters:
  • values (list) – of len batch_size

  • function (callable) –

  • default_value (optional) – Defaults to None

  • *args – additional positional arguments are passed to function

  • **kwargs – additional keyword arguments are passed to function

Returns:

Output – of len batch_size (same as values), with default_value where values are None

Return type:

list
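The documented behavior can be sketched in pure Python as follows. This sketch assumes function is applied to each non-None value individually, which the docstring does not state; the name map_if_not_none_sketch is hypothetical.

```python
def map_if_not_none_sketch(values, function, *args, default_value=None, **kwargs):
    """Illustrative stand-in for map_if_not_None, not the MEERQAT code."""
    output = []
    for value in values:
        if value is None:
            # None entries are replaced by default_value.
            output.append(default_value)
        else:
            # Assumption: function is called once per non-None value,
            # with the extra positional and keyword arguments forwarded.
            output.append(function(value, *args, **kwargs))
    return output


# Square every value that is present; missing values become 0.
result = map_if_not_none_sketch([1, None, 3], pow, 2, default_value=0)
print(result)  # [1, 0, 9]
```

The output list always has the same length as values, so it stays aligned with the batch.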

meerqat.models.utils.debug_shape(batch, prefix='')

Recursively prints the shapes of Tensor and ndarray objects in a nested dict/BatchEncoding/tuple/list.
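The recursion could look roughly like the sketch below. To stay free of dependencies, it treats any object exposing a .shape attribute as an array, whereas the real function checks for Tensor and ndarray; FakeArray, shape_lines, debug_shape_sketch and the exact output format are all hypothetical.

```python
from collections.abc import Mapping


class FakeArray:
    """Stand-in for a Tensor/ndarray: it only exposes a .shape attribute."""

    def __init__(self, *shape):
        self.shape = shape


def shape_lines(batch, prefix=""):
    """Yield one 'path: shape' line per array-like leaf found in batch."""
    if isinstance(batch, Mapping):
        # Covers dict and dict-like containers built on UserDict.
        for key, value in batch.items():
            yield from shape_lines(value, f"{prefix}.{key}")
    elif isinstance(batch, (list, tuple)):
        for i, value in enumerate(batch):
            yield from shape_lines(value, f"{prefix}[{i}]")
    elif hasattr(batch, "shape"):
        yield f"{prefix}: {tuple(batch.shape)}"
    # Anything else (None, str, int, ...) is silently skipped.


def debug_shape_sketch(batch, prefix=""):
    for line in shape_lines(batch, prefix):
        print(line)


batch = {"input_ids": FakeArray(2, 5), "images": [FakeArray(2, 3, 224, 224), None]}
debug_shape_sketch(batch, prefix="batch")
```

Separating the traversal (shape_lines) from the printing keeps the recursion easy to check in isolation.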

meerqat.models.utils.prepare_inputs(data)

Moves the tensors in data to device, whether data is a single tensor or a nested list/dictionary of tensors. Adapted from transformers.Trainer.
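A minimal sketch of this kind of recursive device transfer is shown below. The documented signature takes only data, so the explicit device argument here is an assumption made to keep the example self-contained, and the name prepare_inputs_sketch is hypothetical.

```python
from collections.abc import Mapping

import torch


def prepare_inputs_sketch(data, device="cpu"):
    """Illustrative stand-in for prepare_inputs, not the MEERQAT code."""
    if isinstance(data, Mapping):
        # Rebuild the same mapping type with every value moved recursively.
        return type(data)(
            {key: prepare_inputs_sketch(value, device) for key, value in data.items()}
        )
    if isinstance(data, (list, tuple)):
        # Preserve the container type (list stays list, tuple stays tuple).
        return type(data)(prepare_inputs_sketch(value, device) for value in data)
    if isinstance(data, torch.Tensor):
        return data.to(device)
    # Non-tensor leaves (None, str, int, ...) are returned unchanged.
    return data


batch = {"ids": torch.ones(2, 3), "nested": [torch.zeros(1), None], "label": "x"}
moved = prepare_inputs_sketch(batch, device="cpu")
```

Passing device="cuda" would move the tensors to a GPU when one is available; the structure of the returned object mirrors that of the input.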