169 changes: 106 additions & 63 deletions grounded_sam_demo.py
@@ -1,11 +1,13 @@
 import argparse
 import os
+from os import listdir
 import sys
 
 import numpy as np
 import json
 import torch
 from PIL import Image
+from PIL import UnidentifiedImageError
 
 sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
 sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
@@ -28,10 +30,9 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
-
-def load_image(image_path):
+def load_image(data_path):
     # load image
-    image_pil = Image.open(image_path).convert("RGB")  # load image
+    image_pil = Image.open(data_path).convert("RGB")  # load image
 
     transform = T.Compose(
         [
@@ -62,7 +63,8 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     if not caption.endswith("."):
         caption = caption + "."
     model = model.to(device)
-    image = image.to(device)
+    # wrapping in torch.Tensor() means the call also accepts a plain numpy array
+    image = torch.Tensor(image).to(device)
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
         logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
@@ -108,7 +110,7 @@ def show_box(box, ax, label):
     ax.text(x0, y0, label)
 
 
-def save_mask_data(output_dir, mask_list, box_list, label_list):
+def save_mask_data(output_dir, mask_list, box_list, label_list, datapath, image_name):
     value = 0  # 0 for background
 
     mask_img = torch.zeros(mask_list.shape[-2:])
@@ -117,7 +119,10 @@
     plt.figure(figsize=(10, 10))
     plt.imshow(mask_img.numpy())
     plt.axis('off')
-    plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
+    os.makedirs(os.path.join(output_dir, datapath), exist_ok=True)
+    plt.savefig(os.path.join(output_dir, datapath, f"mask_{image_name}"), bbox_inches="tight", dpi=300, pad_inches=0.0)
+    plt.clf()  # clear and close the figure so memory does not grow over the batch
+    plt.close()
 
     json_data = [{
         'value': value,
@@ -133,8 +138,8 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
             'logit': float(logit),
             'box': box.numpy().tolist(),
         })
-    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
-        json.dump(json_data, f)
+    # with open(os.path.join(output_dir, datapath, f"mask_{os.path.splitext(image_name)[0]}.json"), 'w') as f:
+    #     json.dump(json_data, f)
 
 
 if __name__ == "__main__":
@@ -156,7 +161,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     parser.add_argument(
         "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
     )
-    parser.add_argument("--input_image", type=str, required=True, help="path to image file")
+    parser.add_argument("--data_path", type=str, required=True, help="path to a directory of image sub-folders")
     parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
     parser.add_argument(
         "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
@@ -176,7 +181,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     sam_checkpoint = args.sam_checkpoint
     sam_hq_checkpoint = args.sam_hq_checkpoint
     use_sam_hq = args.use_sam_hq
-    image_path = args.input_image
+    data_path = args.data_path
     text_prompt = args.text_prompt
     output_dir = args.output_dir
     box_threshold = args.box_threshold
@@ -187,56 +192,94 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     # make dir
     os.makedirs(output_dir, exist_ok=True)
     # load image
-    image_pil, image = load_image(image_path)
-    # load model
-    model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)
-
-    # visualize raw image
-    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
-
-    # run grounding dino model
-    boxes_filt, pred_phrases = get_grounding_output(
-        model, image, text_prompt, box_threshold, text_threshold, device=device
-    )
-
-    # initialize SAM
-    if use_sam_hq:
-        predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
-    else:
-        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
-    image = cv2.imread(image_path)
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    predictor.set_image(image)
-
-    size = image_pil.size
-    H, W = size[1], size[0]
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
-
-    boxes_filt = boxes_filt.cpu()
-    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
-
-    masks, _, _ = predictor.predict_torch(
-        point_coords = None,
-        point_labels = None,
-        boxes = transformed_boxes.to(device),
-        multimask_output = False,
-    )
-
-    # draw output image
-    plt.figure(figsize=(10, 10))
-    plt.imshow(image)
-    for mask in masks:
-        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
-    for box, label in zip(boxes_filt, pred_phrases):
-        show_box(box.numpy(), plt.gca(), label)
-
-    plt.axis('off')
-    plt.savefig(
-        os.path.join(output_dir, "grounded_sam_output.jpg"),
-        bbox_inches="tight", dpi=300, pad_inches=0.0
-    )
-
-    save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
+    for datapath in listdir(data_path):
+        img = listdir(f"{data_path}/{datapath}")
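+        # assumption: every entry directly under data_path is itself a sub-folder containing image files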
+        for im in img:
+            try:
+                image_pil, image = load_image(f"{data_path}/{datapath}/{im}")
+                print(f"processing:{data_path}/{datapath}/{im}")
+            except UnidentifiedImageError:
+                print(f"error:{data_path}/{datapath}/{im}")
+                error_data = {
+                    'error_image': im,
+                }
+
+                os.makedirs("PI_output", exist_ok=True)  # the log folder is hardcoded and may not exist yet
+                error_file_path = os.path.join("PI_output", "error.json")
+                if os.path.exists(error_file_path):
+                    with open(error_file_path, 'r') as e:
+                        json_data = json.load(e)
+                else:
+                    json_data = []
+
+                json_data.append(error_data)
+
+                with open(error_file_path, 'w') as f:
+                    json.dump(json_data, f, indent=4, ensure_ascii=False)
+                continue
+            # load model
+            model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)
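+            # NOTE: the model is re-loaded here for every image; hoisting load_model above the loops would avoid repeated initialization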

+            # visualize raw image (currently disabled)
+            # os.makedirs(f"{output_dir}/raw_image/{datapath}", exist_ok=True)
+            # image_pil.save(os.path.join(f"{output_dir}/raw_image/{datapath}/raw_image_{im}"))

+            # run grounding dino model
+            boxes_filt, pred_phrases = get_grounding_output(
+                model, image, text_prompt, box_threshold, text_threshold, device=device
+            )
+
+            # initialize SAM (likewise re-created per image)
+            if use_sam_hq:
+                predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
+            else:
+                predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
+            try:
+                image = cv2.imread(f"{data_path}/{datapath}/{im}")  # cv2.imread returns None on failure instead of raising
+                if image is None:
+                    raise ValueError(f"Image '{data_path}/{datapath}/{im}' loading failed. Check the file path and format.")
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                predictor.set_image(image)

+                size = image_pil.size
+                H, W = size[1], size[0]
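+                # GroundingDINO returns boxes as normalized (cx, cy, w, h);
+                # the loop below scales them to pixels and converts to (x0, y0, x1, y1) corners for SAM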
+                for i in range(boxes_filt.size(0)):
+                    boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+                    boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+                    boxes_filt[i][2:] += boxes_filt[i][:2]
+
+                boxes_filt = boxes_filt.cpu()
+                transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

+                masks, _, _ = predictor.predict_torch(
+                    point_coords=None,
+                    point_labels=None,
+                    boxes=transformed_boxes.to(device),
+                    multimask_output=False,
+                )
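+                # with multimask_output=False, predict_torch yields exactly one mask per input box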

+                # draw output image
+                plt.figure(figsize=(10, 10))
+                plt.imshow(image)
+                for mask in masks:
+                    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+                for box, label in zip(boxes_filt, pred_phrases):
+                    show_box(box.numpy(), plt.gca(), label)
+
+                plt.axis('off')
+                # plt.savefig(
+                #     os.path.join(f"{output_dir}/{datapath}/grounded_sam_output_{im}"),
+                #     bbox_inches="tight", dpi=300, pad_inches=0.0
+                # )
+                plt.clf()  # the figure is cleared and closed even though saving is disabled above
+                plt.close()
+
+                save_mask_data(output_dir, masks, boxes_filt, pred_phrases, datapath, im)
+                torch.cuda.empty_cache()
+            except Exception as e:
+                print(f"An error occurred while processing {data_path}/{datapath}/{im}: {e}")