from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import os
import re
import signal
import sys

# Prefer the Python 3 urllib layout; fall back to urllib2 on Python 2.
try:
    import urllib.error as urlliberror
    import urllib.request as urllib
    HTTPError = urlliberror.HTTPError
    URLError = urlliberror.URLError
except ImportError:
    import urllib2 as urllib
    HTTPError = urllib.HTTPError
    URLError = urllib.URLError
# Base URL under which each model's files are published.
DOWNLOAD_BASE_URL = "https://s3.amazonaws.com/download.caffe2.ai/models/"
# Width, in characters, of the textual progress bar.
# NOTE(review): the definition of this constant is not visible in this
# listing; 70 columns is an assumed value -- confirm against the original.
DOWNLOAD_COLUMNS = 70


def signalHandler(signal, frame):
    """Handle SIGINT by announcing the abort and exiting the process.

    Args:
        signal: Number of the signal that was delivered (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("Killing download...")
    # Without an explicit exit the handler would return and the blocking
    # download would simply resume after Ctrl-C.
    sys.exit(0)


# Route Ctrl-C through the handler above.
signal.signal(signal.SIGINT, signalHandler)
def deleteDirectory(top_dir):
    """Recursively delete ``top_dir`` together with everything inside it.

    The tree is walked bottom-up so that every directory is already empty
    by the time it is removed.

    Args:
        top_dir: Path of the directory to delete.

    Raises:
        OSError: If an entry cannot be removed.
    """
    for root, dirs, files in os.walk(top_dir, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            path = os.path.join(root, name)
            # os.walk reports symlinks to directories in ``dirs``; they must
            # be unlinked (rmdir fails on them) and their targets left alone.
            if os.path.islink(path):
                os.remove(path)
            else:
                os.rmdir(path)
    # Finally remove the (now empty) top-level directory itself.
    os.rmdir(top_dir)
def progressBar(percentage):
    """Redraw a one-line textual progress bar on stdout.

    Args:
        percentage: Completed share of the work, expected in 0..100.
    """
    full = int(DOWNLOAD_COLUMNS * percentage / 100)
    bar = full * "#" + (DOWNLOAD_COLUMNS - full) * " "
    # "\u001b[1000D" is the ANSI "cursor back" sequence; moving 1000 columns
    # left returns to the start of the line so the bar overwrites itself.
    sys.stdout.write(u"\u001b[1000D[" + bar + "] " + str(percentage) + "%")
    # No newline is written, so flush explicitly to make the update visible.
    sys.stdout.flush()
def downloadFromURLToFile(url, filename, show_progress=True):
    """Download ``url`` and write the response body to ``filename``.

    Args:
        url: Address to fetch.
        filename: Path of the local file to create or overwrite.
        show_progress: When true, and the server reports a positive
            Content-Length, draw a progress bar while downloading.

    Raises:
        Exception: With a readable message if the request fails with an
            HTTPError or URLError. Other errors propagate unchanged.
    """
    try:
        print("Downloading from {url}".format(url=url))
        response = urllib.urlopen(url)
        try:
            # Content-Length is optional; treat a missing header as
            # "size unknown" instead of crashing on None.
            length_header = response.info().get('Content-Length')
            size = int(length_header.strip()) if length_header else 0
            chunk = min(size, 8192) if size > 0 else 8192
            print("Writing to {filename}".format(filename=filename))
            # A percentage can only be computed for a known, non-zero size.
            report_progress = bool(show_progress) and size > 0
            downloaded_size = 0
            if report_progress:
                progressBar(0)
            with open(filename, "wb") as local_file:
                while True:
                    data_chunk = response.read(chunk)
                    if not data_chunk:
                        break
                    local_file.write(data_chunk)
                    downloaded_size += len(data_chunk)
                    if report_progress:
                        progressBar(
                            min(100, int(100 * downloaded_size / size)))
            if report_progress:
                # Terminate the progress-bar line.
                print("")
        finally:
            response.close()
    except HTTPError as e:
        # HTTPError derives from URLError, so it has to be handled first.
        raise Exception(
            "Could not download model. [HTTP Error] {code}: {reason}."
            .format(code=e.code, reason=e.reason))
    except URLError as e:
        raise Exception(
            "Could not download model. [URL Error] {reason}."
            .format(reason=e.reason))
def getURLFromName(name, filename):
    """Build the download URL of one file belonging to a model.

    Args:
        name: Model name, used as the folder below ``DOWNLOAD_BASE_URL``.
        filename: Name of the file inside that folder.

    Returns:
        The string ``<DOWNLOAD_BASE_URL><name>/<filename>``.
    """
    return "{base_url}{name}/{filename}".format(base_url=DOWNLOAD_BASE_URL,
                                                name=name, filename=filename)
def downloadModel(model, args):
    """Download the ``predict_net.pb`` and ``init_net.pb`` files of a model.

    The files go into a folder named after the model: relative to the
    current working directory by default, or next to this script when
    ``args.install`` is set. An existing file or folder of that name is
    replaced only when ``args.force`` is set or the user confirms.

    Args:
        model: Name of the model; also used as the target folder name.
        args: Parsed command-line options providing the boolean attributes
            ``install`` and ``force``.

    Raises:
        Exception: If a regular file occupies the target path and
            ``args.force`` is not set.
        SystemExit: If the user declines to overwrite (status 0) or a
            download fails (status 1).
    """
    # Figure out where to store the model.
    model_folder = '{folder}'.format(folder=model)
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args.install:
        model_folder = '{dir_path}/{folder}'.format(dir_path=dir_path,
                                                    folder=model)

    # A regular file occupying the target path blocks folder creation.
    if os.path.exists(model_folder) and not os.path.isdir(model_folder):
        if not args.force:
            raise Exception(
                "Cannot create folder for storing the model, "
                "there exists a file of the same name.")
        print("Overwriting existing file! ({filename})"
              .format(filename=model_folder))
        os.remove(model_folder)

    if os.path.isdir(model_folder):
        if not args.force:
            query = "Model already exists, continue? [y/N] "
            try:
                response = raw_input(query)  # Python 2
            except NameError:
                response = input(query)  # Python 3
            # The default is "N": only an explicit yes may delete the folder.
            if response.strip().lower() not in ('y', 'yes'):
                print("Cancelling download...")
                sys.exit(0)
        print("Overwriting existing folder! ({filename})".format(filename=model_folder))
        deleteDirectory(model_folder)

    # The path is free now; create the folder and fetch the files.
    os.makedirs(model_folder)
    for f in ['predict_net.pb', 'init_net.pb']:
        try:
            downloadFromURLToFile(getURLFromName(model, f),
                                  '{folder}/{f}'.format(folder=model_folder,
                                                        f=f))
        except Exception as e:
            print("Abort: {reason}".format(reason=str(e)))
            print("Cleaning up...")
            deleteDirectory(model_folder)
            # Nothing usable is left behind, so report failure to the shell.
            sys.exit(1)

    if args.install:
        # NOTE(review): presumably this makes the installed folder importable
        # as a package by linking in the shared __sym_init__.py -- confirm.
        os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),
                   "{folder}/__init__.py".format(folder=model_folder))
def validModelName(name):
    """Check whether ``name`` is acceptable as a model name.

    A name is rejected when it is reserved (``__init__``) or contains any
    character other than ASCII letters, digits, ``/``, ``_`` and ``-``.

    Args:
        name: Candidate model name.

    Returns:
        True if the name is valid, False otherwise.
    """
    invalid_names = ['__init__']
    if name in invalid_names:
        return False
    if not re.match("^[/0-9a-zA-Z_-]+$", name):
        return False
    return True
if __name__ == "__main__":
    # Command-line entry point: download (or install) each requested model.
    parser = argparse.ArgumentParser(
        description='Download or install pretrained models.')
    parser.add_argument('model', nargs='+',
                        help='Model to download/install.')
    parser.add_argument('-i', '--install', action='store_true',
                        help='Install the model.')
    parser.add_argument('-f', '--force', action='store_true',
                        help='Force a download/installation.')
    args = parser.parse_args()
    for model in args.model:
        if validModelName(model):
            downloadModel(model, args)
        else:
            # Report the bad name and keep going with the remaining models.
            print("'{}' is not a valid model name.".format(model))