-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimg2vec.py
More file actions
86 lines (65 loc) · 2.44 KB
/
img2vec.py
File metadata and controls
86 lines (65 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import pickle
import faiss
import torch
import torch.nn as nn
import argparse
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm
# Load an ImageNet-pretrained ResNet-50 backbone. ``pretrained=True`` is
# deprecated since torchvision 0.13 in favour of the ``weights`` enum, so use
# the enum when it exists and fall back to the old flag on older releases.
# IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` selects.
if hasattr(models, "ResNet50_Weights"):
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    resnet = models.resnet50(pretrained=True)

# Replace the classifier with a linear projection to EMBEDDING_DIM features.
# NOTE: this layer is never trained; its weights come from PyTorch's default
# random initialisation. Seed that initialisation so every run builds the same
# projection -- otherwise vectors appended to an existing on-disk index would
# live in a different embedding space from the vectors already stored there.
# fork_rng(devices=[]) restores the global CPU RNG state afterwards so the
# seeding does not leak into the rest of the program.
EMBEDDING_DIM = 256
PROJECTION_SEED = 0
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(PROJECTION_SEED)
    resnet.fc = nn.Linear(resnet.fc.in_features, EMBEDDING_DIM)
resnet.eval()

# Preprocessing for input images: resize the short side to 256, centre-crop
# to 224x224, convert to a tensor and normalise with the per-channel
# mean/std values below.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def image_embedding(image_path):
    """Return the embedding vector of the image stored at *image_path*.

    The image is converted to RGB and transformed with the module-level
    ``preprocess`` pipeline. It is then passed through the module-level
    ``resnet`` model with autograd disabled.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        A 1-D ``torch.Tensor`` with ``resnet.fc.out_features`` elements.
    """
    # The context manager releases the file handle PIL keeps open; the
    # original never closed it, which leaks descriptors over large folders.
    with Image.open(image_path) as img:
        image_tensor = preprocess(img.convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        output = resnet(image_tensor)

    # Remove only the batch dimension added by unsqueeze(0); a bare squeeze()
    # would also collapse any other size-1 dimension. (The per-image debug
    # print of the shape was dropped: it interleaved with the tqdm bar.)
    return output.squeeze(0)
def registration_P(directory_path, index_file="img_index.index",
                   names_file="img_names.pkl"):
    """Embed every image in *directory_path* and register it in a FAISS index.

    Files ending in .jpg, .jpeg or .png (case-insensitive) are processed in
    sorted order. Each embedding is appended to the FAISS index at
    *index_file*; a new flat L2 index is created if that file does not exist.
    Each file name is appended to the pickled list at *names_file*, so the
    name at position ``i`` corresponds to vector ``i`` in the index.

    Args:
        directory_path: Directory to scan (not recursive).
        index_file: Path of the FAISS index to load (if present) and write.
        names_file: Path of the pickled name list to load (if present) and write.
    """
    import warnings

    image_extensions = ('.jpg', '.jpeg', '.png')

    names = []
    if os.path.exists(index_file):
        index = faiss.read_index(index_file)
        # Keep previously registered names. Resetting the list here (as the
        # original did) and rewriting names_file would drop the old names and
        # break the position <-> index-id correspondence.
        if os.path.exists(names_file):
            # SECURITY: pickle.load can execute arbitrary code; only load a
            # names file that this script wrote itself.
            with open(names_file, 'rb') as f:
                names = pickle.load(f)
        if len(names) != index.ntotal:
            warnings.warn(
                f"{names_file} lists {len(names)} names but {index_file} "
                f"holds {index.ntotal} vectors; name lookups will be misaligned."
            )
    else:
        # Use the model's actual output width instead of repeating a literal.
        index = faiss.IndexFlatL2(resnet.fc.out_features)

    # Sorting makes registration order (and therefore index ids) reproducible;
    # os.listdir order is arbitrary. Directories are skipped.
    file_list = sorted(
        file_name for file_name in os.listdir(directory_path)
        if file_name.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(directory_path, file_name))
    )

    for file_name in tqdm(file_list, desc="Processing images", unit="image"):
        emb_tensor = image_embedding(os.path.join(directory_path, file_name))
        # FAISS expects a 2-D (n, d) float32 array; add one row per image.
        index.add(emb_tensor.detach().numpy().reshape(1, -1))
        names.append(file_name)

    with open(names_file, 'wb') as f:
        pickle.dump(names, f)
    faiss.write_index(index, index_file)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Embed the images in a directory and add them to the FAISS index."
    )
    # required=True: without it a missing --file yields None, which only fails
    # later inside registration_P with an obscure TypeError.
    parser.add_argument("--file", type=str, required=True,
                        help="Directory containing the images to register")
    args = parser.parse_args()
    registration_P(args.file)