import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# NOTE: this is a basic solution for loading and transforming the data.
# Change it directly if you find something wrong or not good enough.

# Mean and standard deviation of the ImageNet dataset (three values each,
# passed to ``transforms.Normalize``).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_transforms(input_size, augmentation):
    """Return the ``{'train': ..., 'valid': ...}`` transform pipelines.

    The two ``augmentation`` settings differ only in how the training
    ``RandomResizedCrop`` is configured; everything else is shared.
    """
    if augmentation:
        # Square crops (aspect ratio fixed at 1) with scale range 0.64-1.
        train_crop = transforms.RandomResizedCrop(
            input_size, scale=(0.64, 1), ratio=(1, 1)
        )
    else:
        # torchvision's default scale / ratio ranges.
        train_crop = transforms.RandomResizedCrop(input_size)

    return {
        'train': transforms.Compose([
            train_crop,
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]),
        'valid': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]),
    }


def load_data(data_dir="../data/", input_size=224, batch_size=36,
              augmentation=False, num_workers=4):
    """Build the training and validation data loaders.

    Images are read with ``datasets.ImageFolder`` from the ``train`` and
    ``valid`` sub-directories of ``data_dir``.

    Note that the training pipeline applies a random resized crop and a
    random horizontal flip for both values of ``augmentation``; the flag
    only selects the crop configuration.

    Args:
        data_dir: Directory containing the ``train`` and ``valid`` folders.
        input_size: Size passed to the resize / crop transforms.
        batch_size: Batch size used by both loaders.
        augmentation: If true, training crops use ``scale=(0.64, 1)`` and
            ``ratio=(1, 1)``; otherwise ``RandomResizedCrop`` defaults.
        num_workers: Number of worker processes for each ``DataLoader``.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    data_transforms = _build_transforms(input_size, augmentation)

    image_dataset_train = datasets.ImageFolder(
        os.path.join(data_dir, 'train'), data_transforms['train']
    )
    image_dataset_valid = datasets.ImageFolder(
        os.path.join(data_dir, 'valid'), data_transforms['valid']
    )

    train_loader = DataLoader(
        image_dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    valid_loader = DataLoader(
        image_dataset_valid,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, valid_loader