Caffe2 - Python API
A deep-learning, cross-platform ML framework
download.py
1 ## @package download
2 # Module caffe2.python.models.download
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import argparse
8 import os
9 import sys
10 import signal
11 import re
12 
13 # Import urllib
14 try:
15  import urllib.error as urlliberror
16  import urllib.request as urllib
17  HTTPError = urlliberror.HTTPError
18  URLError = urlliberror.URLError
19 except ImportError:
20  import urllib2 as urllib
21  HTTPError = urllib.HTTPError
22  URLError = urllib.URLError
23 
# Models are fetched straight from the S3 bucket instead of the project's
# vanity URL: urllib requires more work to deal with that URL's redirect.
# The value must end with "/" because getURLFromName appends
# "<name>/<filename>" to it directly.
DOWNLOAD_BASE_URL = "https://s3.amazonaws.com/download.caffe2.ai/models/"
# Width of the text progress bar drawn by progressBar, in characters.
DOWNLOAD_COLUMNS = 70
27 
28 
# Don't let urllib hang up on big downloads: on Ctrl-C, print a notice and
# exit the process.
def signalHandler(signal, frame):
    """SIGINT handler that announces the abort and exits with status 0.

    Args:
        signal: signal number passed in by the ``signal`` module (unused;
            note that the name shadows the module inside this function).
        frame: interrupted stack frame (unused).
    """
    print("Killing download...")
    # sys.exit is always available, unlike the site-installed exit()
    # builtin, which is missing under ``python -S`` and in some embedded or
    # frozen interpreters.
    sys.exit(0)


signal.signal(signal.SIGINT, signalHandler)
36 
37 
def deleteDirectory(top_dir):
    """Recursively delete the directory ``top_dir`` and everything in it.

    Delegates to ``shutil.rmtree``. Unlike a manual ``os.walk`` loop,
    ``rmtree`` unlinks symbolic links found inside the tree instead of
    calling ``os.rmdir`` on them, which fails. It also refuses to operate
    when ``top_dir`` itself is a symbolic link, rather than emptying the
    link's target.

    Raises:
        OSError: if ``top_dir`` does not exist, is not a directory, is a
            symbolic link, or an entry cannot be removed.
    """
    import shutil  # local import keeps this block self-contained

    shutil.rmtree(top_dir)
45 
46 
def progressBar(percentage):
    """Redraw a one-line text progress bar on stdout.

    The bar is DOWNLOAD_COLUMNS characters wide. It is followed by the
    percentage and no newline, so successive calls overwrite each other.

    Args:
        percentage: completion percentage. Values outside [0, 100] are
            clamped so the bar never grows past its column width (e.g. when
            a server reports a wrong Content-Length).
    """
    percentage = max(0, min(100, percentage))
    filled = int(DOWNLOAD_COLUMNS * percentage / 100)
    bar = "#" * filled + " " * (DOWNLOAD_COLUMNS - filled)
    # "\r" moves the cursor back to the start of the line. Unlike an ANSI
    # cursor-movement escape, it also works on consoles without VT100
    # escape processing.
    sys.stdout.write(u"\r[" + bar + "] " + str(percentage) + "%")
    sys.stdout.flush()
52 
53 
def downloadFromURLToFile(url, filename, show_progress=True):
    """Download ``url`` and write the response body to ``filename``.

    The body is streamed in 8192-byte reads.

    Args:
        url: URL to fetch with ``urllib``'s ``urlopen``.
        filename: local path the body is written to (overwritten if present).
        show_progress: draw a progress bar with ``progressBar``. The bar is
            only drawn when the server reports a Content-Length. Responses
            without one (e.g. chunked transfer encoding) are still
            downloaded, just without a percentage.

    Raises:
        Exception: with a readable message when ``urlopen`` fails with an
            HTTP or URL error; any other error propagates unchanged.
    """
    chunk_size = 8192
    try:
        print("Downloading from {url}".format(url=url))
        response = urllib.urlopen(url)
        try:
            length_header = response.info().get('Content-Length')
            # Content-Length is optional, so the total size may be unknown.
            size = int(length_header.strip()) if length_header else None
            track_progress = show_progress and size is not None
            print("Writing to {filename}".format(filename=filename))
            downloaded_size = 0
            if track_progress:
                progressBar(0)
            with open(filename, "wb") as local_file:
                while True:
                    data_chunk = response.read(chunk_size)
                    if not data_chunk:
                        break
                    local_file.write(data_chunk)
                    downloaded_size += len(data_chunk)
                    # size == 0 would divide by zero; nothing to report then.
                    if track_progress and size > 0:
                        progressBar(int(100 * downloaded_size / size))
        finally:
            # Release the connection even if writing the file fails.
            response.close()
        print("")  # New line to fix for progress bar
    except HTTPError as e:
        raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
                        .format(code=e.code, reason=e.reason))
    except URLError as e:
        raise Exception("Could not download model. [URL Error] {reason}."
                        .format(reason=e.reason))
82 
83 
def getURLFromName(name, filename):
    """Build the download URL for one file of a model.

    Args:
        name: model name, used verbatim as the URL path segment(s).
        filename: file to fetch inside that model's directory.

    Returns:
        ``DOWNLOAD_BASE_URL`` followed by ``<name>/<filename>``.
    """
    url_template = "{0}{1}/{2}"
    return url_template.format(DOWNLOAD_BASE_URL, name, filename)
87 
88 
def _confirmOverwrite():
    """Ask the user whether an existing model folder may be overwritten.

    Returns:
        True only for an explicit "y"/"yes" answer (case-insensitive,
        surrounding whitespace ignored). Anything else, including an empty
        answer or "no", means no, matching the "[y/N]" default in the prompt.
    """
    query = "Model already exists, continue? [y/N] "
    try:
        read_answer = raw_input  # Python 2
    except NameError:
        read_answer = input  # Python 3
    answer = read_answer(query)
    return answer.strip().lower() in ('y', 'yes')


def downloadModel(model, args):
    """Download (and optionally install) the pretrained model ``model``.

    Fetches ``predict_net.pb`` and ``init_net.pb`` into a folder named after
    the model. The folder is relative to the current working directory, or,
    when ``args.install`` is set, located next to this script. In the
    install case, ``__init__.py`` is created there as a symlink to this
    script directory's ``__sym_init__.py``.

    Args:
        model: model name; used as the folder path and as the URL path.
        args: parsed command-line namespace; reads ``args.install`` and
            ``args.force``.

    Raises:
        Exception: if a non-directory file occupies the model folder path
            and ``args.force`` is not set.

    Exits the process with status 0 when the user declines to overwrite an
    existing model, and with status 1 when a download fails.
    """
    # Figure out where to store the model
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args.install:
        model_folder = '{dir_path}/{folder}'.format(dir_path=dir_path,
                                                    folder=model)
    else:
        model_folder = '{folder}'.format(folder=model)

    # A plain file where the folder should go: refuse unless forced.
    if os.path.exists(model_folder) and not os.path.isdir(model_folder):
        if not args.force:
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            raise Exception("Cannot create folder for storing the model, "
                            "there exists a file of the same name.")
        print("Overwriting existing file! ({filename})"
              .format(filename=model_folder))
        os.remove(model_folder)

    # An existing model folder: ask before deleting it unless forced.
    if os.path.isdir(model_folder):
        if not args.force and not _confirmOverwrite():
            print("Cancelling download...")
            sys.exit(0)
        print("Overwriting existing folder! ({filename})".format(filename=model_folder))
        deleteDirectory(model_folder)

    # Now we can safely create the folder and download the model
    os.makedirs(model_folder)
    try:
        for f in ['predict_net.pb', 'init_net.pb']:
            downloadFromURLToFile(getURLFromName(model, f),
                                  '{folder}/{f}'.format(folder=model_folder,
                                                        f=f))
    except Exception as e:
        print("Abort: {reason}".format(reason=str(e)))
        print("Cleaning up...")
        deleteDirectory(model_folder)
        # A failed download must not report success through the exit status.
        sys.exit(1)
    except (KeyboardInterrupt, SystemExit):
        # Interrupted mid-download (e.g. signalHandler exiting on Ctrl-C):
        # drop the partial folder so it is not mistaken for a complete
        # model on the next run, then let the interruption continue.
        print("Cleaning up...")
        deleteDirectory(model_folder)
        raise

    if args.install:
        os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),
                   "{folder}/__init__.py".format(folder=model_folder))
136 
137 
# One or more "/"-separated segments of ASCII letters, digits, "_" or "-".
# \Z (not $) so that a trailing newline is not accepted.
_VALID_MODEL_NAME = re.compile(r"[0-9a-zA-Z_-]+(?:/[0-9a-zA-Z_-]+)*\Z")


def validModelName(name):
    """Return True if ``name`` is an acceptable model name.

    A valid name consists of one or more "/"-separated segments, each made
    only of ASCII letters, digits, "_" and "-" (e.g. "squeezenet" or
    "detectron/e2e_faster_rcnn_R-50-C4_1x").

    The following are rejected:
      * a leading "/" (an absolute path would place, and with --force
        delete, a folder outside the intended location);
      * a trailing "/" or empty segments ("a//b"), which yield broken URLs;
      * a trailing newline;
      * the reserved name "__init__".
    """
    invalid_names = ['__init__']
    if name in invalid_names:
        return False
    return _VALID_MODEL_NAME.match(name) is not None
145 
146 
if __name__ == "__main__":
    # Command-line entry point: download (or install) every model named on
    # the command line, skipping names that fail validModelName.
    arg_parser = argparse.ArgumentParser(
        description='Download or install pretrained models.')
    arg_parser.add_argument('model', nargs='+',
                            help='Model to download/install.')
    arg_parser.add_argument('-i', '--install', action='store_true',
                            help='Install the model.')
    arg_parser.add_argument('-f', '--force', action='store_true',
                            help='Force a download/installation.')
    args = arg_parser.parse_args()
    for requested in args.model:
        if not validModelName(requested):
            print("'{}' is not a valid model name.".format(requested))
            continue
        downloadModel(requested, args)