import os

from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomCocoDataset(Dataset):
    def __init__(self, json_file, img_folder, common_transform=None):
        self.coco = COCO(json_file)
        self.img_folder = img_folder
        # Only keep images that have at least one caption annotation.
        self.ids = list(self.coco.imgToAnns.keys())
        self.common_transform = common_transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_folder, img_info['file_name'])
        image = Image.open(img_path).convert('RGB')

        # Sample a single random square crop so both outputs see the same region.
        i, j, h, w = transforms.RandomResizedCrop.get_params(
            image, scale=(0.95, 1.0), ratio=(1.0, 1.0))
        cropped_image = transforms.functional.crop(image, i, j, h, w)

        # Resize the shared crop to the two target resolutions.
        jpg_image = transforms.functional.resize(
            cropped_image, 512, interpolation=transforms.InterpolationMode.BICUBIC)
        hint_image = transforms.functional.resize(
            cropped_image, 448, interpolation=transforms.InterpolationMode.BICUBIC)

        # Apply the common transformations (e.g. ToTensor) to both images.
        if self.common_transform is not None:
            jpg_image = self.common_transform(jpg_image)
            hint_image = self.common_transform(hint_image)

        # Join all captions for the image into one prompt, stripping newlines.
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        captions = [ann['caption'].replace('\n', ' ') for ann in anns]
        combined_caption = ' '.join(captions)

        return dict(jpg=jpg_image, txt=combined_caption, hint=hint_image)
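
# Sketch (assumption, not part of the training code above): the jpg/txt/hint keys
# resemble the dict format used by ControlNet-style trainers. If the samples feed
# such a pipeline, the shared transform would typically also normalize pixel
# values to [-1, 1] (the Normalize commented out in main() below). Note that the
# single common_transform applies to BOTH jpg and hint here.
def make_normalized_transform():
    return transforms.Compose([
        transforms.ToTensor(),                      # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5],  # shift/scale each channel to [-1, 1]
                             std=[0.5, 0.5, 0.5]),
    ])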

def main():
    # Common transformations applied to both the jpg and hint images.
    common_transform = transforms.Compose([
        transforms.ToTensor(),  # converts to a tensor and scales values to [0, 1]
        # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # would map to [-1, 1]
    ])

    # Instantiate the dataset.
    dataset = CustomCocoDataset(
        json_file='/home/t2vg-a100-G4-1/projects/dataset/annotations/captions_train2017.json',
        img_folder='/home/t2vg-a100-G4-1/projects/dataset/train2017',
        common_transform=common_transform
    )

    # Create the DataLoader.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Inspect the value range of the jpg tensor in each batch.
    for batch in dataloader:
        jpg_image = batch['jpg']
        print(f'JPG Image Min Value: {jpg_image.min().item()}')
        print(f'JPG Image Max Value: {jpg_image.max().item()}')
        # break  # uncomment to stop after the first batch
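
# Sketch (assumption: the annotation file and image folder used above exist):
# with the default collate_fn, batch['jpg'] and batch['hint'] stack into tensors
# of shape (B, 3, 512, 512) and (B, 3, 448, 448) -- the crops are square because
# ratio=(1.0, 1.0) -- while batch['txt'] is collated into a list of B strings.
def inspect_first_batch(dataset, batch_size=4):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    batch = next(iter(loader))
    print(batch['jpg'].shape)   # e.g. torch.Size([4, 3, 512, 512])
    print(batch['hint'].shape)  # e.g. torch.Size([4, 3, 448, 448])
    print(len(batch['txt']))    # 4 combined captions, one per image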

if __name__ == "__main__":
    main()