# Source code for torch.utils.model_zoo

import torch

import hashlib
import os
import re
import shutil
import sys
import tempfile
if sys.version_info[0] == 2:
    from urlparse import urlparse
    from urllib2 import urlopen
else:
    from urllib.request import urlopen
    from urllib.parse import urlparse
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None  # defined below

# Captures the run of lowercase hex digits that sits between a '-' and the
# following '.' in a file name, e.g. matches bfd8deac from resnet18-bfd8deac.pth.
# Because of the `*` quantifier the captured group may be empty (e.g. "name-.pth").
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def load_url(url, model_dir=None, map_location=None):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file. If the filename
    carries no such hash, the file is downloaded without verification.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap
            storage locations (see torch.load)

    Returns:
        The object produced by ``torch.load`` on the cached file.

    Raises:
        OSError: if `model_dir` cannot be created.
        RuntimeError: if the downloaded file does not match the hash embedded
            in its filename.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))

    # Create the directory unconditionally and tolerate "already exists":
    # a separate exists() check races when several processes start together.
    try:
        os.makedirs(model_dir)
    except OSError:
        if not os.path.isdir(model_dir):
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # A filename without a "-<hash>." part yields no match; fall back to an
        # empty prefix (which always verifies) instead of crashing on None.
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match else ''
        _download_url_to_file(url, cached_file, hash_prefix)
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only pass URLs that point to trusted sources.
    return torch.load(cached_file, map_location=map_location)
def _download_url_to_file(url, dst, hash_prefix):
    """Download ``url`` to the path ``dst``, verifying its SHA256 prefix.

    The payload is streamed into a temporary file created next to ``dst`` and
    only moved into place after the whole body has been read and verified, so
    a failed or interrupted download never leaves a partial file at ``dst``.

    Args:
        url (string): URL to fetch.
        dst (string): destination file path.
        hash_prefix (string or None): expected leading hex digits of the
            SHA256 digest of the payload. Verification is skipped when this
            is empty or ``None``.

    Raises:
        RuntimeError: if the digest of the payload does not start with
            ``hash_prefix``.
    """
    u = urlopen(url)
    try:
        meta = u.info()
        # Python 2 header objects expose ``getheaders``; Python 3 ones ``get_all``.
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        # The header is optional; without it the total size is unknown.
        file_size = int(content_length[0]) if content_length else None

        # Keep the temporary file on the same filesystem as ``dst`` so that the
        # final move is a rename rather than a copy that could be interrupted.
        dst_dir = os.path.dirname(os.path.abspath(dst))
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            sha256 = hashlib.sha256()
            with tqdm(total=file_size) as pbar:
                while True:
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            digest = sha256.hexdigest()
            if hash_prefix and not digest.startswith(hash_prefix):
                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                   .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            # Closing twice is harmless; after a successful move the temporary
            # name no longer exists and there is nothing left to remove.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        u.close()


if tqdm is None:
    # fake tqdm if it's not installed
    class tqdm(object):
        """Minimal stand-in for ``tqdm.tqdm`` that reports progress on stderr."""

        def __init__(self, total=None):
            self.total = total
            self.n = 0

        def update(self, n):
            """Advance the counter by ``n`` and redraw the progress line."""
            self.n += n
            if self.total:
                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
            else:
                # Total size unknown (or zero): a percentage cannot be computed.
                sys.stderr.write("\r{0} bytes".format(self.n))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            sys.stderr.write('\n')