Simplify find data initialization #1032

Merged · merged 5 commits on Feb 22, 2024

Changes from 3 commits
118 changes: 48 additions & 70 deletions deepface/modules/recognition.py
@@ -100,6 +100,7 @@ def find(
     file_name = f"representations_{model_name}.pkl"
     file_name = file_name.replace("-", "_").lower()
     datastore_path = os.path.join(db_path, file_name)
+    representations = []
 
     df_cols = [
         "identity",
@@ -110,93 +111,71 @@ def find(
         "target_h",
     ]
 
-    if os.path.exists(datastore_path):
-        with open(datastore_path, "rb") as f:
-            representations = pickle.load(f)
-
-        if len(representations) > 0 and len(representations[0]) != len(df_cols):
-            raise ValueError(
-                f"Seems existing {datastore_path} is out-of-the-date."
-                "Please delete it and re-run."
-            )
-
-        alpha_employees = __list_images(path=db_path)
-        beta_employees = [representation[0] for representation in representations]
-
-        newbies = list(set(alpha_employees) - set(beta_employees))
-        oldies = list(set(beta_employees) - set(alpha_employees))
+    # Ensure the proper pickle file exists
+    if not os.path.exists(datastore_path):
+        with open(datastore_path, "wb") as f:
+            pickle.dump([], f)
+            f.close()
 
-        if newbies:
-            logger.warn(
-                f"Items {newbies} were added into {db_path}"
-                f" just after data source {datastore_path} created!"
-            )
-            newbies_representations = __find_bulk_embeddings(
-                employees=newbies,
-                model_name=model_name,
-                target_size=target_size,
-                detector_backend=detector_backend,
-                enforce_detection=enforce_detection,
-                align=align,
-                normalization=normalization,
-                silent=silent,
-            )
-            representations = representations + newbies_representations
+    # Load the representations from the pickle file
+    with open(datastore_path, "rb") as f:
+        representations = pickle.load(f)
+        f.close()
 
-        if oldies:
-            logger.warn(
-                f"Items {oldies} were dropped from {db_path}"
-                f" just after data source {datastore_path} created!"
-            )
-            representations = [rep for rep in representations if rep[0] not in oldies]
+    # Check if the representations are out-of-date
+    if len(representations) > 0:
+        if len(representations[0]) != len(df_cols):
+            raise ValueError(
+                f"Seems existing {datastore_path} is out-of-the-date."
+                "Please delete it and re-run."
+            )
 
-        if newbies or oldies:
-            if len(representations) == 0:
-                raise ValueError(f"There is no image in {db_path} anymore!")
-
-            # save new representations
-            with open(datastore_path, "wb") as f:
-                pickle.dump(representations, f)
+        pickled_images = [representation[0] for representation in representations]
+    else:
+        pickled_images = []
 
-            if not silent:
-                logger.info(
-                    f"{len(newbies)} new representations are just added"
-                    f" whereas {len(oldies)} represented one(s) are just dropped"
-                    f" in {os.path.join(db_path,file_name)} file."
-                )
+    # Get the list of images on storage
+    storage_images = __list_images(path=db_path)
 
-        if not silent:
-            logger.info(f"There are {len(representations)} representations found in {file_name}")
+    # Enforce data consistency amongst on disk images and pickle file
+    must_save_pickle = False
+    new_images = list(set(storage_images) - set(pickled_images))  # images added to storage
+    old_images = list(set(pickled_images) - set(storage_images))  # images removed from storage
 
-    else:  # create representation.pkl from scratch
-        employees = __list_images(path=db_path)
+    if not silent:
+        logger.info(f"Found {len(new_images)} new images and {len(old_images)} removed images")
 
-        if len(employees) == 0:
-            raise ValueError(
-                f"Could not find any valid image in {db_path} folder!"
-                "Valid images are .jpg, .jpeg or .png files.",
-            )
+    # remove old images first
+    if len(old_images)>0:
+        representations = [rep for rep in representations if rep[0] not in old_images]
+        must_save_pickle = True
 
-        # ------------------------
-        # find representations for db images
-        representations = __find_bulk_embeddings(
-            employees=employees,
+    # find representations for new images
+    if len(new_images)>0:
+        representations += __find_bulk_embeddings(
+            employees=new_images,
             model_name=model_name,
             target_size=target_size,
             detector_backend=detector_backend,
            enforce_detection=enforce_detection,
             align=align,
             normalization=normalization,
             silent=silent,
-        )
-
-        # -------------------------------
+        ) # add new images
+        must_save_pickle = True
 
+    if must_save_pickle:
+        with open(datastore_path, "wb") as f:
+            pickle.dump(representations, f)
+            f.close()
+        if not silent:
+            logger.info(f"There are now {len(representations)} representations in {file_name}")
+
+    # Should we have no representations bailout
+    if len(representations) == 0:
        if not silent:
-            logger.info(f"Representations stored in {datastore_path} file.")
+            toc = time.time()
+            logger.info(f"find function duration {toc - tic} seconds")
+        return []
 
     # ----------------------------
     # now, we got representations for facial database
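
Taken together, the hunk above collapses two divergent code paths (pickle file exists vs. missing) into one linear flow: ensure the pickle exists, load it, diff the cached identities against the images on disk, and rewrite the pickle only when that diff is non-empty. Below is a minimal standalone sketch of that idea, not deepface's actual API: embed is a hypothetical stand-in for __find_bulk_embeddings, and list_images is a simplified flat-directory version of the module's __list_images.

import os
import pickle

def list_images(db_path):
    # Simplified stand-in for __list_images: scan a single directory level
    return [
        os.path.join(db_path, name)
        for name in os.listdir(db_path)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    ]

def sync_representations(db_path, datastore_path, embed):
    # Ensure the pickle exists so the load below never needs a branch
    if not os.path.exists(datastore_path):
        with open(datastore_path, "wb") as f:
            pickle.dump([], f)

    with open(datastore_path, "rb") as f:
        representations = pickle.load(f)

    pickled_images = [rep[0] for rep in representations]
    storage_images = list_images(db_path)

    new_images = set(storage_images) - set(pickled_images)  # added on disk
    old_images = set(pickled_images) - set(storage_images)  # removed from disk

    must_save = False
    if old_images:
        # Drop cached entries whose source image no longer exists
        representations = [rep for rep in representations if rep[0] not in old_images]
        must_save = True
    if new_images:
        # Embed only the images that are not cached yet
        representations += [(img, embed(img)) for img in sorted(new_images)]
        must_save = True

    if must_save:
        with open(datastore_path, "wb") as f:
            pickle.dump(representations, f)
    return representations

Because an unchanged database triggers neither branch, a repeated call touches neither the embedding model nor the pickle file.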
@@ -287,10 +266,9 @@ def find(
 
     # -----------------------------------
 
-    toc = time.time()
-
     if not silent:
-        logger.info(f"find function lasts {toc - tic} seconds")
+        toc = time.time()
+        logger.info(f"find function duration {toc - tic} seconds")
 
     return resp_obj
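
None of this changes the public API; it only changes when the work happens. A hypothetical usage sketch: with the default VGG-Face model, the code above derives the cache name representations_vgg_face.pkl from the representations_{model_name}.pkl pattern in the first hunk.

from deepface import DeepFace

# First call scans db_path, embeds every image, and writes the pickle.
dfs = DeepFace.find(img_path="dataset/img1.jpg", db_path="dataset")

# A later call loads the pickle, embeds only images added since, and
# drops cached entries for images deleted in the meantime.
dfs = DeepFace.find(img_path="dataset/img1.jpg", db_path="dataset")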

7 changes: 4 additions & 3 deletions tests/test_find.py
@@ -1,3 +1,4 @@
+import os
 import cv2
 import pandas as pd
 from deepface import DeepFace
@@ -10,7 +11,7 @@
 
 
 def test_find_with_exact_path():
-    img_path = "dataset/img1.jpg"
+    img_path = os.path.join("dataset","img1.jpg")
     dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
     assert len(dfs) > 0
     for df in dfs:
@@ -30,7 +31,7 @@ def test_find_with_exact_path():
 
 
 def test_find_with_array_input():
-    img_path = "dataset/img1.jpg"
+    img_path = os.path.join("dataset","img1.jpg")
     img1 = cv2.imread(img_path)
     dfs = DeepFace.find(img1, db_path="dataset", silent=True)
     assert len(dfs) > 0
@@ -52,7 +53,7 @@ def test_find_with_array_input():
 
 
 def test_find_with_extracted_faces():
-    img_path = "dataset/img1.jpg"
+    img_path = os.path.join("dataset","img1.jpg")
     face_objs = DeepFace.extract_faces(img_path)
     img = face_objs[0]["face"]
     dfs = DeepFace.find(img, db_path="dataset", detector_backend="skip", silent=True)
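
The test changes are pure path-portability fixes: os.path.join picks the platform's separator, so the same tests pass on POSIX and Windows. A one-liner shows the difference.

import os

# Prints "dataset/img1.jpg" on POSIX and "dataset\img1.jpg" on Windows.
print(os.path.join("dataset", "img1.jpg"))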