This repository has been archived on 2025-03-12. You can view files and clone it, but cannot push or open issues or pull requests.
AIDungeon/generator/gpt2/download_model.py

42 lines
1.1 KiB
Python
Raw Permalink Normal View History

2025-03-11 22:26:45 -04:00
import os
import sys
import requests
from tqdm import tqdm
# Download the GPT-2 checkpoint files for the model named on the command line
# into models/<model>/, showing one progress bar per file.
#
# Usage: download_model.py <model>   (e.g. download_model.py 124M)
# Exits with status 1 when the model name is missing; raises
# requests.HTTPError when the server answers with an error status.
if len(sys.argv) != 2:
    print("You must enter the model name as a parameter, e.g.: download_model.py 124M")
    sys.exit(1)

model = sys.argv[1]

subdir = os.path.join("models", model)
# exist_ok avoids the race between a separate existence check and the creation.
os.makedirs(subdir, exist_ok=True)
# The same string is reused as the URL path below, so it must use forward
# slashes even on Windows.
subdir = subdir.replace("\\", "/")

for filename in [
    "checkpoint",
    "encoder.json",
    "hparams.json",
    "model.ckpt.data-00000-of-00001",
    "model.ckpt.index",
    "model.ckpt.meta",
    "vocab.bpe",
]:
    url = "https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename

    # The context manager releases the streamed connection even if the
    # download is interrupted part-way through.
    with requests.get(url, stream=True) as r:
        # Fail before touching the local file, so an HTTP error page is never
        # saved in place of a model file.
        r.raise_for_status()

        # Content-Length may be absent (e.g. chunked transfer encoding); tqdm
        # accepts total=None and then shows a counter without a percentage.
        content_length = r.headers.get("content-length")
        file_size = int(content_length) if content_length is not None else None

        # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
        chunk_size = 1000

        with open(os.path.join(subdir, filename), "wb") as f:
            with tqdm(
                ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True
            ) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    # Advance by the bytes actually received: the final chunk
                    # is usually shorter than chunk_size.
                    pbar.update(len(chunk))