diff --git a/grounded_sam_demo.py b/grounded_sam_demo.py
index 732abcaa..a28d128a 100644
--- a/grounded_sam_demo.py
+++ b/grounded_sam_demo.py
@@ -1,11 +1,13 @@
 import argparse
 import os
+from os import listdir
 import sys
 
 import numpy as np
 import json
 import torch
 from PIL import Image
+from PIL import UnidentifiedImageError
 
 sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
 sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
@@ -28,10 +30,9 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
-
-def load_image(image_path):
+def load_image(data_path):
     # load image
-    image_pil = Image.open(image_path).convert("RGB")  # load image
+    image_pil = Image.open(data_path).convert("RGB")  # load image
 
     transform = T.Compose(
         [
@@ -62,7 +63,10 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     if not caption.endswith("."):
         caption = caption + "."
     model = model.to(device)
-    image = image.to(device)
+    # torch.Tensor(image)
+    image = torch.Tensor(image).to(device)
+    # image = image.to(device)
+    # print(model)
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
     logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
@@ -108,7 +112,7 @@ def show_box(box, ax, label):
     ax.text(x0, y0, label)
 
 
-def save_mask_data(output_dir, mask_list, box_list, label_list):
+def save_mask_data(output_dir, mask_list, box_list, label_list, datapath, i):
     value = 0  # 0 for background
 
     mask_img = torch.zeros(mask_list.shape[-2:])
@@ -117,7 +121,10 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     plt.figure(figsize=(10, 10))
     plt.imshow(mask_img.numpy())
     plt.axis('off')
-    plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
+    os.makedirs(f"{output_dir}/{datapath}", exist_ok=True)
+    plt.savefig(os.path.join(f"{output_dir}/{datapath}/mask_{i}"), bbox_inches="tight", dpi=300, pad_inches=0.0)
+    plt.clf()
+    plt.close()
 
     json_data = [{
         'value': value,
@@ -133,8 +140,8 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
             'logit': float(logit),
             'box': box.numpy().tolist(),
         })
-    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
-        json.dump(json_data, f)
+    # with open(os.path.join(f"{output_dir}/{f}/mask_" + i.split(".")[0] + ".json"), 'w') as f:
+    #     json.dump(json_data, f)
 
 
 if __name__ == "__main__":
@@ -156,8 +163,8 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     parser.add_argument(
         "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
    )
-    parser.add_argument("--input_image", type=str, required=True, help="path to image file")
+    parser.add_argument("--data_path", type=str, required=True, help="path to image file")
     parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
     parser.add_argument(
         "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
     )
@@ -176,7 +183,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     sam_checkpoint = args.sam_checkpoint
     sam_hq_checkpoint = args.sam_hq_checkpoint
     use_sam_hq = args.use_sam_hq
-    image_path = args.input_image
+    data_path = args.data_path
     text_prompt = args.text_prompt
     output_dir = args.output_dir
     box_threshold = args.box_threshold
@@ -187,56 +194,89 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
     # make dir
     os.makedirs(output_dir, exist_ok=True)
     # load image
-    image_pil, image = load_image(image_path)
-    # load model
-    model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)
-
-    # visualize raw image
-    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
-
-    # run grounding dino model
-    boxes_filt, pred_phrases = get_grounding_output(
-        model, image, text_prompt, box_threshold, text_threshold, device=device
-    )
-
-    # initialize SAM
-    if use_sam_hq:
-        predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
-    else:
-        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
-    image = cv2.imread(image_path)
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    predictor.set_image(image)
-
-    size = image_pil.size
-    H, W = size[1], size[0]
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
-
-    boxes_filt = boxes_filt.cpu()
-    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
-
-    masks, _, _ = predictor.predict_torch(
-        point_coords = None,
-        point_labels = None,
-        boxes = transformed_boxes.to(device),
-        multimask_output = False,
-    )
-
-    # draw output image
-    plt.figure(figsize=(10, 10))
-    plt.imshow(image)
-    for mask in masks:
-        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
-    for box, label in zip(boxes_filt, pred_phrases):
-        show_box(box.numpy(), plt.gca(), label)
-
-    plt.axis('off')
-    plt.savefig(
-        os.path.join(output_dir, "grounded_sam_output.jpg"),
-        bbox_inches="tight", dpi=300, pad_inches=0.0
-    )
-
-    save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
+    for datapath in listdir(data_path):
+        img = listdir(f"{data_path}/{datapath}")
+        # print(img)
+        for im in img:
+            try:
+                image_pil, image = load_image(f"{data_path}/{datapath}/{im}")
+                print(f"processing:{data_path}/{datapath}/{im}")
+            except UnidentifiedImageError:
+                print(f"error:{data_path}/{datapath}/{im}")
+                error_data = {
+                    'error_image': im,
+                }
+
+                error_file_path = os.path.join("PI_output/error.json")
+                if os.path.exists(error_file_path):
+                    with open(error_file_path, 'r') as e:
+                        json_data = json.load(e)
+                else:
+                    json_data = []
+
+                json_data.append(error_data)
+
+                with open(error_file_path, 'w') as f:
+                    json.dump(json_data, f, indent=4, ensure_ascii=False)
+                continue
+            # load model
+            model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)
+
+            # visualize raw image
+            # os.makedirs(f"{output_dir}/raw_image/{f}", exist_ok=True)
+            # image_pil.save(os.path.join(f"{output_dir}/raw_image/{f}/raw_image_{im}"))
+
+            # run grounding dino model
+            boxes_filt, pred_phrases = get_grounding_output(
+                model, image, text_prompt, box_threshold, text_threshold, device=device
+            )
+
+            # initialize SAM
+            if use_sam_hq:
+                predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
+            else:
+                predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
+            try:
+                image = cv2.imread(f"{data_path}/{datapath}/{im}")
+                if image is None:
+                    raise ValueError(f"Image '{data_path}/{datapath}/{im}' loading failed. Check the file path and format.")
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                predictor.set_image(image)
+
+                size = image_pil.size
+                H, W = size[1], size[0]
+                for i in range(boxes_filt.size(0)):
+                    boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+                    boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+                    boxes_filt[i][2:] += boxes_filt[i][:2]
+
+                boxes_filt = boxes_filt.cpu()
+                transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+
+                masks, _, _ = predictor.predict_torch(
+                    point_coords = None,
+                    point_labels = None,
+                    boxes = transformed_boxes.to(device),
+                    multimask_output = False,
+                )
+
+                # draw output image
+                plt.figure(figsize=(10, 10))
+                plt.imshow(image)
+                for mask in masks:
+                    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+                for box, label in zip(boxes_filt, pred_phrases):
+                    show_box(box.numpy(), plt.gca(), label)
+
+                plt.axis('off')
+                plt.clf()
+                plt.close()
+                # plt.savefig(
+                #     os.path.join(f"{output_dir}/{f}/grounded_sam_output_{i}"),
+                #     bbox_inches="tight", dpi=300, pad_inches=0.0
+                # )
+
                save_mask_data(output_dir, masks, boxes_filt, pred_phrases, datapath, im)
+                torch.cuda.empty_cache()
+            except Exception as e:
+                print(f"An error occurred: {e}")