import os import sys import requests from tqdm import tqdm if len(sys.argv) != 2: print("You must enter the model name as a parameter, e.g.: download_model.py 124M") sys.exit(1) model = sys.argv[1] subdir = os.path.join("models", model) if not os.path.exists(subdir): os.makedirs(subdir) subdir = subdir.replace("\\", "/") # needed for Windows for filename in [ "checkpoint", "encoder.json", "hparams.json", "model.ckpt.data-00000-of-00001", "model.ckpt.index", "model.ckpt.meta", "vocab.bpe", ]: r = requests.get( "https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True ) with open(os.path.join(subdir, filename), "wb") as f: file_size = int(r.headers["content-length"]) chunk_size = 1000 with tqdm( ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True ) as pbar: # 1k for chunk_size, since Ethernet packet size is around 1500 bytes for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) pbar.update(chunk_size)