ValueError: Dataloader
returned 0 length. Please make sure that it returns at least 1 batch
#9478
Answered
by
tchaton
morestart
asked this question in
code help: CV
-
Using this code, I can load the test data. But when I use the PL data module to fit (train) the model, I get a "dataloader returned 0 length" error.
import os
from typing import Optional
import PIL
import cv2
import json
import copy
import numpy as np
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torch.utils.data import Dataset, random_split, DataLoader
from det.det_modules import ResizeShortSize, IaaAugment, EastRandomCropData, MakeBorderMap, MakeShrinkMap
def load_json(file_path: str):
    """Read a UTF-8 encoded JSON file and return its decoded content.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.
    """
    with open(file_path, 'r', encoding='utf8') as f:
        return json.load(f)
class ICDARDataset(Dataset):
    """Text-detection dataset built from a JSON annotation file and an image folder.

    The annotation file is read as a mapping from an image name (without the
    ``.jpg`` extension) to a list of annotation dicts; each annotation is
    accessed through its ``points``, ``transcription`` and ``illegibility``
    keys.
    """

    def __init__(self, json_path, img_path, is_train=True):
        """Index the annotations and create the per-sample transforms.

        Args:
            json_path: Path of the JSON annotation file.
            img_path: Directory that contains the ``<name>.jpg`` images.
            is_train: When true, ``__getitem__`` runs the augmentation / crop /
                border-map / shrink-map pipeline; otherwise it only applies
                the short-side resize.
        """
        # NOTE(review): ignore_tags and load_char_annotation are not read by
        # any method of this class; presumably kept for compatibility - confirm.
        self.ignore_tags = ['*', '###']
        self.load_char_annotation = False
        self.data_list = self.load_data(json_path, img_path)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.iaa_augment = IaaAugment()
        self.east_random_crop_data = EastRandomCropData()
        self.make_border_map = MakeBorderMap()
        self.make_shrink_map = MakeShrinkMap()
        self.resize = ResizeShortSize(short_size=736, resize_text_polys=False)
        self.is_train = is_train

    def load_data(self, json_path: str, img_path) -> list:
        """Build the list of samples described by the annotation file.

        Annotations whose ``points`` or ``transcription`` are empty are
        dropped. An image is still listed when none of its annotations
        survive that filter.

        Args:
            json_path: Path of the JSON annotation file.
            img_path: Directory that contains the images.

        Returns:
            One dict per image with the keys ``img_path``, ``text_polys``,
            ``texts`` and ``ignore_tags`` (the ``illegibility`` values).
        """
        data_list = []
        for name, annotations in load_json(json_path).items():
            kept = [
                annotation for annotation in annotations
                if len(annotation['points']) > 0 and len(annotation['transcription']) > 0
            ]
            data_list.append({
                'img_path': os.path.join(img_path, name + '.jpg'),
                # dtype=object because polygons may have differing point counts
                'text_polys': np.array([a['points'] for a in kept], dtype=object),
                'texts': [a['transcription'] for a in kept],
                'ignore_tags': [a['illegibility'] for a in kept],
            })
        return data_list

    def __getitem__(self, index):
        """Load, transform and return the sample stored at ``index``.

        Raises:
            FileNotFoundError: If the image is missing or cannot be decoded.
        """
        # Copy first: everything below writes into ``data``, and the indexed
        # entry must stay pristine (no cached image, no transformed polygons)
        # so that later reads of the same index start from the raw annotation.
        data = copy.deepcopy(self.data_list[index])
        im = cv2.imread(data['img_path'])
        if im is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(
                'image is missing or unreadable: {}'.format(data['img_path'])
            )
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        data['img'] = im
        data['shape'] = [im.shape[0], im.shape[1]]
        if self.is_train:
            data = self.iaa_augment(data)
            data = self.east_random_crop_data(data)
            data = self.make_border_map(data)
            data = self.make_shrink_map(data)
        else:
            data = self.resize(data)
        data['img'] = self.transform(data['img'])
        return data

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data_list)
class DetCollectFN:
    """Collate callable that merges a list of sample dicts into one batch dict.

    Values that are tensors, numeric NumPy arrays or PIL images are converted
    to tensors and stacked along a new leading dimension. Everything else,
    including object-dtype arrays, is returned as a plain list with one entry
    per sample.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the collator keeps no state."""

    @staticmethod
    def _to_tensor(value):
        """Convert ``value`` to a tensor when possible, else return it unchanged."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, np.ndarray):
            # torch.tensor() rejects object arrays, so leave those untouched
            return value if value.dtype == object else torch.tensor(value)
        if isinstance(value, PIL.Image.Image):
            return transforms.ToTensor()(value)
        return value

    def __call__(self, batch):
        """Collate ``batch`` (a sequence of dicts) into a single dict.

        Args:
            batch: Samples to merge; each one is a mapping from key to value.

        Returns:
            A dict with the union of the sample keys. A key maps to a stacked
            tensor when every collected value for it is a tensor, otherwise
            to the list of collected values.
        """
        collected = {}
        for sample in batch:
            for key, value in sample.items():
                collected.setdefault(key, []).append(self._to_tensor(value))
        return {
            key: torch.stack(values, 0)
            if all(isinstance(item, torch.Tensor) for item in values)
            else values
            for key, values in collected.items()
        }
class DBDataModule(pl.LightningDataModule):
    """Lightning data module that serves the train and validation ICDAR datasets."""

    def __init__(self, train_json_path, train_img_path, val_json_path, val_img_path,
                 batch_size=32, num_workers=0):
        """Create both datasets eagerly.

        Args:
            train_json_path: Annotation file of the training split.
            train_img_path: Image directory of the training split.
            val_json_path: Annotation file of the validation split.
            val_img_path: Image directory of the validation split.
            batch_size: Batch size used by both dataloaders.
            num_workers: Worker process count used by both dataloaders.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = ICDARDataset(train_json_path, train_img_path, is_train=True)
        self.val = ICDARDataset(val_json_path, val_img_path, is_train=False)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training dataset."""
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            # collate_fn must be a callable taking the batch: pass an
            # instance, not the class (calling the class would just build a
            # DetCollectFN object and return it as the "batch").
            collate_fn=DetCollectFN(),
        )

    def val_dataloader(self):
        """Return the dataloader over the validation dataset."""
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=DetCollectFN(),
        )
if __name__ == '__main__':
    # Manual smoke test: load one training sample and display the image
    # together with its shrink and threshold label maps.
    from matplotlib import pyplot as plt

    def show_img(imgs: np.ndarray, title='img'):
        """Open a matplotlib figure showing ``imgs``.

        The array is shown in colour when it is 3-D with a last dimension of
        3, otherwise with the gray colormap.
        """
        color = (len(imgs.shape) == 3 and imgs.shape[-1] == 3)
        imgs = np.expand_dims(imgs, axis=0)
        for i, img in enumerate(imgs):
            plt.figure()
            plt.title('{}_{}'.format(title, i))
            plt.imshow(img, cmap=None if color else 'gray')

    def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
        """Return a copy of the image with every polygon in ``result`` drawn on it.

        ``img_path`` may be a file path (read with ``cv2.imread``) or an
        already loaded image array.
        """
        if isinstance(img_path, str):
            img_path = cv2.imread(img_path)
            # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
        img_path = img_path.copy()
        for point in result:
            # point = point.astype(int)
            cv2.polylines(img_path, [point], True, color, thickness)
        return img_path

    dataset = ICDARDataset('/home/data/OCRData/icdar2019/train/train.json', '/home/data/OCRData/icdar2019/train/images')
    print(len(dataset))
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
    for i, data in enumerate(train_loader):
        img = data['img']
        shrink_label = data['shrink_map']
        threshold_label = data['threshold_map']
        # Report the shrink map's shape too (the threshold shape used to be printed twice)
        print(shrink_label.shape, threshold_label.shape, img.shape)
        show_img(img[0].numpy().transpose(1, 2, 0), title='img')
        show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label')
        show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label')
        # img = draw_bbox(img[0].numpy().transpose(1, 2, 0), np.array(data['text_polys']))
        # show_img(img, title='draw_bbox')
        plt.show()
        # Only the first batch is inspected
        break
Beta Was this translation helpful? Give feedback.
Answered by
tchaton
Sep 13, 2021
Replies: 1 comment 5 replies
-
Dear @morestart, Would you mind unit-testing your code? Can you check that the lengths of your DBDataModule train and val ICDARDataset aren't 0? Best, T.C
Beta Was this translation helpful? Give feedback.
5 replies
Answer selected by
morestart
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Dear @morestart,
Would you mind unit-testing your code? Can you check that the lengths of the train and val ICDARDataset in your DBDataModule aren't 0? Lightning doesn't manipulate your datasets / dataloaders, so maybe your datasets are empty.
Best,
T.C