| import time |
| import os |
| import torch |
| import numpy as np |
| import torchvision |
| import torch.nn.functional as F |
| from torchvision.datasets import ImageFolder |
| import torchvision.transforms as transforms |
| from tqdm import tqdm |
| import pickle |
| import argparse |
| from PIL import Image |
|
|
def concat(x):
    """Concatenate a sequence of arrays along the first axis (``axis=0``)."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Return tensor ``x`` as a NumPy array on the CPU, detached from autograd.

    ``detach()`` is the documented replacement for the legacy ``.data``
    attribute. As with ``.data``, the result shares storage with ``x`` when
    ``x`` already lives on the CPU, so the array is a view rather than a copy.
    """
    return x.detach().cpu().numpy()
|
|
|
|
class Wrapper(torch.nn.Module):
    """Return the output of ``model.avgpool`` instead of the model's output.

    A forward hook registered on ``model.avgpool`` captures that layer's
    output on every forward pass. ``forward`` runs the full wrapped model,
    discards its return value, and returns the captured tensor.

    Args:
        model: A ``torch.nn.Module`` with an ``avgpool`` submodule that is
            called during its forward pass.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Most recent squeezed ``avgpool`` output; written by the hook below.
        self.avgpool_output = None
        # Not read or written elsewhere in this class; presumably used by
        # callers outside this view. Verify before removing.
        self.query = None
        self.cossim_value = {}

        def fw_hook(_module, _inputs, output):
            # NOTE: squeeze() drops *every* size-1 dim, including the batch
            # dim when the batch holds one sample:
            # (1, C, 1, 1) -> (C,) but (B, C, 1, 1) -> (B, C) for B > 1.
            # Callers rely on this shape, so it is kept as-is.
            self.avgpool_output = output.squeeze()

        # Keep the handle so the hook can be detached via ``remove_hook``.
        # Otherwise it stays registered on ``model.avgpool`` for the model's
        # lifetime.
        self._hook_handle = self.model.avgpool.register_forward_hook(fw_hook)

    def remove_hook(self):
        """Detach the ``avgpool`` forward hook from the wrapped model."""
        self._hook_handle.remove()

    def forward(self, input):
        """Run the wrapped model on ``input`` and return the pooled features."""
        _ = self.model(input)
        return self.avgpool_output

    def __repr__(self):
        return type(self).__name__
|
|
|
|
def QueryToEmbedding(query_path):
    """Embed one query image with an ImageNet-pretrained ResNet-50.

    The image is resized to 256, center-cropped to 224x224, converted to a
    tensor, and normalised with the ImageNet mean/std. It is then passed
    through ``Wrapper``, so the result is the ``avgpool`` activation rather
    than the classifier logits.

    Args:
        query_path: Anything accepted by ``PIL.Image.open`` (a path or a
            binary file object).

    Returns:
        A ``numpy.ndarray`` with a leading axis of length 1 wrapping the
        embedding produced by ``Wrapper``. This is presumably shape
        (1, 2048) for ResNet-50, given that ``Wrapper`` squeezes the
        (1, 2048, 1, 1) pooled output; re-check if ``Wrapper`` changes.
    """
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # NOTE(review): the weights are reloaded on every call, so embedding many
    # queries pays this cost each time. `pretrained=True` is deprecated in
    # torchvision >= 0.13 in favour of `weights=`; it is kept here for
    # compatibility with older torchvision.
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    myw = Wrapper(model)

    # `with` closes the file handle deterministically. convert("RGB") forces
    # three channels: grayscale / RGBA / palette images would otherwise reach
    # Normalize with the wrong channel count and fail.
    with Image.open(query_path) as query_pil:
        query_pt = dataset_transform(query_pil.convert("RGB"))

    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects.
        embedding = to_np(myw(query_pt.unsqueeze(0)))

    return np.asarray([embedding])
|
|