Source code for meerqat.models.image
"""Building blocks for computer vision models."""
from torch import nn
class FaceEmbedding(nn.Module):
    """Embed a detected face from its feature vector and its bounding box.

    The face feature and the bounding box are each mapped to
    ``embedding_dim`` by their own linear layer; the two projections are
    summed, optionally shifted by an image-type embedding, layer-normalized
    and finally passed through dropout.

    Args:
        face_dim: Size of the last dimension of the face features.
        bbox_dim: Size of the last dimension of the bounding-box features.
        embedding_dim: Size of the produced embedding.
        dropout: Drop probability applied to the normalized embedding.
        layer_norm_eps: ``eps`` forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, face_dim, bbox_dim, embedding_dim, dropout=0.1, layer_norm_eps=1e-12):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialized.
        self.face_proj = nn.Linear(in_features=face_dim, out_features=embedding_dim)
        self.bbox_proj = nn.Linear(in_features=bbox_dim, out_features=embedding_dim)
        self.LayerNorm = nn.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, face, bbox, image_type_embeddings=None):
        """Return the normalized, dropout-regularized face embedding.

        Args:
            face: Face features; last dimension must be ``face_dim``.
            bbox: Bounding-box features; last dimension must be ``bbox_dim``.
            image_type_embeddings: Optional tensor added to the summed
                projections before normalization. It is added in place, so
                it must broadcast to the shape of the projections.

        Returns:
            Tensor whose last dimension is ``embedding_dim``.
        """
        fused = self.face_proj(face) + self.bbox_proj(bbox)
        if image_type_embeddings is not None:
            # In-place add on the freshly created sum; inputs are untouched.
            fused += image_type_embeddings
        return self.dropout(self.LayerNorm(fused))
class ImageEmbedding(nn.Module):
    """Projects an image feature in the embedding space using a linear layer.

    The projection is optionally shifted by an image-type embedding, then
    optionally layer-normalized, and finally passed through dropout.

    Args:
        input_dim: Size of the last dimension of the image features.
        embedding_dim: Size of the produced embedding.
        dropout: Drop probability applied to the final embedding.
        layer_norm_eps: ``eps`` for ``nn.LayerNorm``. When ``None`` (the
            default) no normalization is applied.
    """

    def __init__(self, input_dim, embedding_dim, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        self.linear = nn.Linear(input_dim, embedding_dim)
        if layer_norm_eps is not None:
            self.LayerNorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps)
        else:
            # forward() always calls self.LayerNorm, so the attribute must
            # exist even when normalization is disabled; without it the
            # default configuration raised AttributeError. nn.Identity has no
            # parameters, so state_dict keys are unaffected.
            self.LayerNorm = nn.Identity()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, image_type_embeddings=None):
        """Return the embedded image feature.

        Args:
            input: Image features; last dimension must be ``input_dim``.
            image_type_embeddings: Optional tensor added to the projection
                before normalization. It is added in place, so it must
                broadcast to the shape of the projection.

        Returns:
            Tensor whose last dimension is ``embedding_dim``.
        """
        embedding = self.linear(input)
        if image_type_embeddings is not None:
            embedding += image_type_embeddings
        embedding = self.LayerNorm(embedding)
        return self.dropout(embedding)