"""Usage: embedding.py <dataset> [<config> --disable_caching --output=<path>]
Options:
--disable_caching Disables Dataset caching (useless when using save_to_disk), see datasets.set_caching_enabled()
--output=<path> Optionally save the resulting dataset there instead of overwriting the input dataset.
"""
from docopt import docopt
import json
from collections import OrderedDict
from multiprocessing import Pool
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import FeatureExtractionMixin
from datasets import load_from_disk, set_caching_enabled
from ..models.utils import device
from ..data.loading import load_image_batch, get_pretrained


class ImageEncoder(nn.Module):
    """Simply encode and pool the features."""
    def __init__(self, encoder, pool):
        super().__init__()
        self.encoder = encoder
        self.pool = pool

    def forward(self, x):
        features = self.encoder(x)
        return self.pool(features)


def get_nn_module(Class_name, *args, **kwargs):
    """Instantiate a torch.nn module from its class name."""
    return getattr(nn, Class_name)(*args, **kwargs)
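
# Example (illustrative): get_nn_module("AdaptiveAvgPool2d", output_size=(1, 1))
# builds the global pooling layer typically paired with a ResNet encoder in ImageEncoder.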


def from_pretrained(model_name='resnet50', imagenet=True, pretrained_model_path=None, **kwargs):
    """
    Notes
    -----
    For models trained on a dataset other than ImageNet, don't forget to pass
    the right num_classes in kwargs.

    Examples
    --------
    To load from a Places365 checkpoint, first process the state_dict like this:

    >>> checkpoint = torch.load("resnet50_places365.pth.tar", map_location="cpu")
    >>> state_dict = {str.replace(k, 'module.', ''): v for k, v in checkpoint['state_dict'].items()}
    >>> torch.save(state_dict, "resnet50_places365_state_dict.pth")
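    >>> model = from_pretrained(model_name="resnet50", imagenet=False,  # hypothetical follow-up call
    ...                         pretrained_model_path="resnet50_places365_state_dict.pth",
    ...                         num_classes=365)  # Places365 has 365 classes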
"""
model = getattr(torchvision.models, model_name)(pretrained=imagenet, **kwargs)
if not imagenet:
print(f"Loading pre-trained model from '{pretrained_model_path}'")
state_dict = torch.load(pretrained_model_path)
model.load_state_dict(state_dict)
else:
print("Loaded pre-trained model on ImageNet")
return model


def get_encoder(torchvision_model):
    """Keep only convolutional layers (i.e. remove the final pooling and classification layers)."""
    if isinstance(torchvision_model, (torchvision.models.ResNet, )):
        # for ResNet, the last two children are the average-pooling and fully-connected layers
        cutoff = -2
    else:
        raise NotImplementedError(f"Don't know where the convolutional layers end for {type(torchvision_model).__name__}")
    return nn.Sequential(OrderedDict(list(torchvision_model.named_children())[:cutoff]))


def get_torchvision_model(pretrained_kwargs=None, pool_kwargs=None):
    """Get a model pre-trained on ImageNet or Places365 and wrap it in an ImageEncoder."""
    torchvision_model = from_pretrained(**(pretrained_kwargs or {}))
    encoder = get_encoder(torchvision_model)
    pool = get_nn_module(**(pool_kwargs or {}))
    return ImageEncoder(encoder, pool)
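

# NOTE: get_transform was missing from this source but is needed by
# get_model_and_transform below, and the Compose/Resize/CenterCrop/ToTensor/Normalize
# imports above are otherwise unused. This is a reconstruction sketch: the 224x224
# sizes and the ImageNet normalization statistics are assumed defaults.
def get_transform(resize_kwargs=None, crop_kwargs=None, normalize_kwargs=None):
    """Standard torchvision preprocessing: resize, center-crop, to tensor, normalize."""
    if resize_kwargs is None:
        resize_kwargs = dict(size=224)
    if crop_kwargs is None:
        crop_kwargs = dict(size=224)
    if normalize_kwargs is None:
        # ImageNet statistics (assumed)
        normalize_kwargs = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return Compose([
        Resize(**resize_kwargs),
        CenterCrop(**crop_kwargs),
        ToTensor(),
        Normalize(**normalize_kwargs),
    ])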


def embed(batch, model, transform, save_as='image_embedding', image_key='image',
          call=None, pool=None):
    images = load_image_batch(batch[image_key], pool=pool)
    # keep track of images that could not be loaded so that their embedding stays None
    not_None_images, not_None_indices, output = [], [], []
    for i, image in enumerate(images):
        output.append(None)
        if image is not None:
            not_None_images.append(image)
            not_None_indices.append(i)
    if len(not_None_images) == 0:
        batch[save_as] = output
        return batch
    images = not_None_images
    if isinstance(transform, FeatureExtractionMixin):
        # transformers feature extractor: yields a dict of arrays/tensors
        if pool is not None:
            image_list = pool.map(transform, images)
            images = {}
            for k in image_list[0].keys():
                images[k] = torch.tensor(
                    np.concatenate([image[k] for image in image_list]),
                    device=device
                )
        else:
            images = transform(images, return_tensors="pt")
            images = {k: v.to(device) for k, v in images.items()}
    else:
        # torchvision transform: yields a (C, H, W) tensor per image
        if pool is not None:
            images = pool.map(transform, images)
            images = torch.stack(images).contiguous().to(device)
        else:
            images = [transform(image).unsqueeze(0) for image in images]
            images = torch.cat(images).to(device)
    # optionally call a specific method of the model, e.g. 'get_image_features'
    method = model if call is None else getattr(model, call)
    with torch.no_grad():
        if isinstance(images, dict):
            image_embeddings = method(**images)
        else:
            image_embeddings = method(images)
    # flatten trailing dimensions, e.g. (B, C, 1, 1) -> (B, C), without
    # collapsing the batch dimension when the batch holds a single image
    not_None_output = image_embeddings.reshape(image_embeddings.shape[0], -1).cpu().numpy()
    for i, j in enumerate(not_None_indices):
        output[j] = not_None_output[i]
    batch[save_as] = output
    return batch
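

# NOTE: get_model_and_transform is called by dataset_embed below but was missing from
# this source. The sketch below is a reconstruction under assumptions: the "type" key
# and the torchvision/transformers branching are inferred from the otherwise-unused
# get_pretrained import and from embed() dispatching on the transform type; they are
# not guaranteed to match the original implementation.
def get_model_and_transform(model_kwargs=None, transform_kwargs=None):
    model_kwargs = dict(model_kwargs or {})
    transform_kwargs = dict(transform_kwargs or {})
    # assumed convention: "type" selects between a torchvision and a transformers model
    model_type = model_kwargs.pop("type", "torchvision")
    if model_type == "torchvision":
        model = get_torchvision_model(**model_kwargs)
        transform = get_transform(**transform_kwargs)
    elif model_type == "transformers":
        # assumed: get_pretrained loads both the model and its feature extractor
        model = get_pretrained(**model_kwargs)
        transform = get_pretrained(**transform_kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model = model.to(device).eval()
    return dict(model=model, transform=transform)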


def dataset_embed(dataset_path, map_kwargs=None, model_kwargs=None, transform_kwargs=None,
                  output_path=None, keep_columns=None, processes=None, **fn_kwargs):
    dataset = load_from_disk(dataset_path)
    # defaults to overwriting the input dataset
    if output_path is None:
        output_path = dataset_path
        assert keep_columns is None, f"You probably don't want to overwrite {dataset_path} by keeping only {keep_columns}"
    elif keep_columns is not None:
        keep_columns = set(keep_columns)
        dataset = dataset.remove_columns([c for c in dataset.column_names if c not in keep_columns])
    fn_kwargs.update(get_model_and_transform(model_kwargs=model_kwargs, transform_kwargs=transform_kwargs))
    fn_kwargs['pool'] = None if processes is None else Pool(processes=processes)
    dataset = dataset.map(embed, batched=True, fn_kwargs=fn_kwargs, **(map_kwargs or {}))
    dataset.save_to_disk(output_path)


if __name__ == '__main__':
    args = docopt(__doc__)
    set_caching_enabled(not args['--disable_caching'])
    config_path = args['<config>']
    if config_path is not None:
        with open(config_path, 'rt') as file:
            config = json.load(file)
    else:
        config = {}
    dataset_embed(args['<dataset>'], output_path=args['--output'], **config)