from __future__ import print_function from PIL import Image import os import os.path import numpy as np import sys if sys.version_info[0] == 2: import cPickle as pickle else: import pickle import torch.utils.data as data from torchvision.datasets.utils import download_url, check_integrity class FRDEEPN(data.Dataset): """`FRDEEP-N `_Dataset Inspired by `HTRU1 `_ Dataset. Args: root (string): Root directory of dataset where directory ``htru1-batches-py`` exists or will be saved to if download is set to True. train (bool, optional): If True, creates dataset from training set, otherwise creates from test set. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ base_folder = 'NVSS_PNG_dataset' url = "http://www.jb.man.ac.uk/research/ascaife/NVSS_PNG_dataset.tar.gz" filename = "NVSS_PNG_dataset.tar.gz" tgz_md5 = '2584ed1e174ea71f581d0e0d6f32ef38' train_list = [ ['data_batch_1', '3a2a15d88756ba61c796378fc8574540'], ['data_batch_2', '6a04e3985397e1f67f0ad42153dca64e'], ['data_batch_3', 'd852c8200f3bbb63beacf31f3e954f9a'], ['data_batch_4', 'a5739996ca44a1a1841f2d0e6b844dd6'], ['data_batch_5', '8e2fdb3f60bf7541ca135fc8e2407f7a'], ['data_batch_6', '9e5a82500bd9742f8fefe412ada95336'], ['data_batch_7', 'f66af7795265fbe24376f669200412c4'], ['data_batch_8', '75982afc09bf480ecc521acdb39cbe46'], ['data_batch_9', '72aee306fef9acee21a0e5537bb681e4'], ['data_batch_10', '7a039ce8062a533b23b401a612c5f9b7'], ['data_batch_11', 'c0013314098c96ca4c7c20c0f17abcd3'], ] test_list = [ ['test_batch', '39fd167b9a7df12cee1ef9a804f9fa86'], ] meta = { 'filename': 'batches.meta', 'key': 'label_names', 'md5': '655493bdee948954f3939727b3f9e735', } def __init__(self, root, train=True, transform=None, target_transform=None, download=False): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set if download: self.download() if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') if self.train: downloaded_list = self.train_list else: downloaded_list = self.test_list self.data = [] self.targets = [] # now load the picked numpy arrays for file_name, checksum in downloaded_list: file_path = os.path.join(self.root, self.base_folder, file_name) with open(file_path, 'rb') as f: if sys.version_info[0] == 2: entry = pickle.load(f) else: entry = pickle.load(f, encoding='latin1') self.data.append(entry['data']) if 'labels' in entry: self.targets.extend(entry['labels']) else: self.targets.extend(entry['fine_labels']) self.data = np.vstack(self.data).reshape(-1, 1, 150, 150) self.data = self.data.transpose((0, 2, 3, 1)) self._load_meta() def _load_meta(self): path = os.path.join(self.root, self.base_folder, self.meta['filename']) if not check_integrity(path, self.meta['md5']): raise RuntimeError('Dataset metadata file not found or corrupted.' + ' You can use download=True to download it') with open(path, 'rb') as infile: if sys.version_info[0] == 2: data = pickle.load(infile) else: data = pickle.load(infile, encoding='latin1') self.classes = data[self.meta['key']] self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = np.reshape(img,(150,150)) img = Image.fromarray(img,mode='L') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.data) def _check_integrity(self): root = self.root for fentry in (self.train_list + self.test_list): filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): return False return True def download(self): import tarfile if self._check_integrity(): print('Files already downloaded and verified') return download_url(self.url, self.root, self.filename, self.tgz_md5) # extract file with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: tar.extractall(path=self.root) def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) tmp = 'train' if self.train is True else 'test' fmt_str += ' Split: {}\n'.format(tmp) fmt_str += ' Root Location: {}\n'.format(self.root) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) return fmt_str class FRDEEPF(data.Dataset): """`FRDEEP-F `_Dataset Inspired by `HTRU1 `_ Dataset. Args: root (string): Root directory of dataset where directory ``htru1-batches-py`` exists or will be saved to if download is set to True. train (bool, optional): If True, creates dataset from training set, otherwise creates from test set. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ base_folder = 'FIRST_PNG_dataset' url = "http://www.jb.man.ac.uk/research/ascaife/FIRST_PNG_dataset.tar.gz" filename = "FIRST_PNG_dataset.tar.gz" tgz_md5 = '2f39461e6c62fb45289559915106013a' train_list = [ ['data_batch_1', 'f34da44757c7fa3f6e6cd3d0839a4634'], ['data_batch_2', 'f56cda0d9a99305fee2bad7de0560f95'], ['data_batch_3', '93265dd849331af4e1b092f74b06450b'], ['data_batch_4', '0de8f4c18b775251f4e553e2990cd446'], ['data_batch_5', 'c6aa87400a1be6007da7cfcefd2c3e5c'], ['data_batch_6', 'cebd3fdea93abbc048a3a4d5e58528e0'], ['data_batch_7', '49497445e9380f157e78cf8d74fca1eb'], ['data_batch_8', '88e298eed2d87bbdddad83fef1482723'], ['data_batch_9', '8c40117dbf4d456e63a8a665b245aa63'], ['data_batch_10', 'f24d110cc5811ba4651630b9ee9b2989'], ['data_batch_11', 'b843dc3b7f48606235029f135d41c85e'], ] test_list = [ ['test_batch', '4e06889b1e7713deb46e62887eb37727'], ] meta = { 'filename': 'batches.meta', 'key': 'label_names', 'md5': '655493bdee948954f3939727b3f9e735', } def __init__(self, root, train=True, transform=None, target_transform=None, download=False): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set if download: self.download() if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') if self.train: downloaded_list = self.train_list else: downloaded_list = self.test_list self.data = [] self.targets = [] # now load the picked numpy arrays for file_name, checksum in downloaded_list: file_path = os.path.join(self.root, self.base_folder, file_name) with open(file_path, 'rb') as f: if sys.version_info[0] == 2: entry = pickle.load(f) else: entry = pickle.load(f, encoding='latin1') self.data.append(entry['data']) if 'labels' in entry: self.targets.extend(entry['labels']) else: self.targets.extend(entry['fine_labels']) self.data = np.vstack(self.data).reshape(-1, 1, 150, 150) self.data = self.data.transpose((0, 2, 3, 1)) self._load_meta() def _load_meta(self): path = os.path.join(self.root, self.base_folder, self.meta['filename']) if not check_integrity(path, self.meta['md5']): raise RuntimeError('Dataset metadata file not found or corrupted.' + ' You can use download=True to download it') with open(path, 'rb') as infile: if sys.version_info[0] == 2: data = pickle.load(infile) else: data = pickle.load(infile, encoding='latin1') self.classes = data[self.meta['key']] self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = np.reshape(img,(150,150)) img = Image.fromarray(img,mode='L') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.data) def _check_integrity(self): root = self.root for fentry in (self.train_list + self.test_list): filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): return False return True def download(self): import tarfile if self._check_integrity(): print('Files already downloaded and verified') return download_url(self.url, self.root, self.filename, self.tgz_md5) # extract file with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: tar.extractall(path=self.root) def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) tmp = 'train' if self.train is True else 'test' fmt_str += ' Split: {}\n'.format(tmp) fmt_str += ' Root Location: {}\n'.format(self.root) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) return fmt_str