Initial commit
Commit f58a2e46bd
44 changed files with 4051 additions and 0 deletions
generator/__init__.py (new file, 0 lines)
generator/ctrl/training_utils/action_results.tfrecords (new file, 0 lines)
generator/gpt2/.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
*.pyc
__pycache__
generator/gpt2/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
generator/gpt2/__init__.py (new file, 0 lines)
generator/gpt2/download_model.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import os
import sys

import requests
from tqdm import tqdm

if len(sys.argv) != 2:
    print("You must enter the model name as a parameter, e.g.: download_model.py 124M")
    sys.exit(1)

model = sys.argv[1]

subdir = os.path.join("models", model)
if not os.path.exists(subdir):
    os.makedirs(subdir)
subdir = subdir.replace("\\", "/")  # needed for Windows

for filename in [
    "checkpoint",
    "encoder.json",
    "hparams.json",
    "model.ckpt.data-00000-of-00001",
    "model.ckpt.index",
    "model.ckpt.meta",
    "vocab.bpe",
]:

    r = requests.get(
        "https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True
    )

    with open(os.path.join(subdir, filename), "wb") as f:
        file_size = int(r.headers["content-length"])
        chunk_size = 1000
        with tqdm(
            ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True
        ) as pbar:
            # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                pbar.update(chunk_size)
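Usage note (not part of the commit): the script takes the model name as its single argument and fetches the checkpoint files listed above from the public gpt-2 storage bucket, for example:

    python generator/gpt2/download_model.py 124M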
generator/gpt2/gpt2_generator.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import json
import os
import warnings

import numpy as np

import tensorflow as tf
from generator.gpt2.src import encoder, model, sample
from story.utils import *

warnings.filterwarnings("ignore")

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


class GPT2Generator:
    def __init__(self, generate_num=60, temperature=0.4, top_k=40, top_p=0.9, censor=True, force_cpu=False):
        self.generate_num = generate_num
        self.temp = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.censor = censor

        self.model_name = "model_v5"
        self.model_dir = "generator/gpt2/models"
        self.checkpoint_path = os.path.join(self.model_dir, self.model_name)

        models_dir = os.path.expanduser(os.path.expandvars(self.model_dir))
        self.batch_size = 1
        self.samples = 1

        self.enc = encoder.get_encoder(self.model_name, models_dir)
        hparams = model.default_hparams()
        with open(os.path.join(models_dir, self.model_name, "hparams.json")) as f:
            hparams.override_from_dict(json.load(f))
        seed = np.random.randint(0, 100000)

        config = None
        if force_cpu:
            config = tf.compat.v1.ConfigProto(
                device_count={"GPU": 0}
            )
        else:
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
        self.sess = tf.compat.v1.Session(config=config)

        self.context = tf.placeholder(tf.int32, [self.batch_size, None])
        # np.random.seed(seed)
        # tf.set_random_seed(seed)
        self.output = sample.sample_sequence(
            hparams=hparams,
            length=self.generate_num,
            context=self.context,
            batch_size=self.batch_size,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, self.model_name))
        saver.restore(self.sess, ckpt)

    def prompt_replace(self, prompt):
        # print("\n\nBEFORE PROMPT_REPLACE:")
        # print(repr(prompt))
        if len(prompt) > 0 and prompt[-1] == " ":
            prompt = prompt[:-1]

        # prompt = second_to_first_person(prompt)

        # print("\n\nAFTER PROMPT_REPLACE")
        # print(repr(prompt))
        return prompt

    def result_replace(self, result):
        # print("\n\nBEFORE RESULT_REPLACE:")
        # print(repr(result))

        result = cut_trailing_sentence(result)
        if len(result) == 0:
            return ""
        first_letter_capitalized = result[0].isupper()
        result = result.replace('."', '".')
        result = result.replace("#", "")
        result = result.replace("*", "")
        result = result.replace("\n\n", "\n")
        # result = first_to_second_person(result)
        if self.censor:
            result = remove_profanity(result)

        if not first_letter_capitalized:
            result = result[0].lower() + result[1:]

        #
        # print("\n\nAFTER RESULT_REPLACE:")
        # print(repr(result))

        return result

    def generate_raw(self, prompt):
        context_tokens = self.enc.encode(prompt)
        generated = 0
        for _ in range(self.samples // self.batch_size):
            out = self.sess.run(
                self.output,
                feed_dict={
                    self.context: [context_tokens for _ in range(self.batch_size)]
                },
            )[:, len(context_tokens) :]
            for i in range(self.batch_size):
                generated += 1
                text = self.enc.decode(out[i])
        return text

    def generate(self, prompt, options=None, seed=1):

        debug_print = False
        prompt = self.prompt_replace(prompt)

        if debug_print:
            print("******DEBUG******")
            print("Prompt is: ", repr(prompt))

        text = self.generate_raw(prompt)

        if debug_print:
            print("Generated result is: ", repr(text))
            print("******END DEBUG******")

        result = text
        result = self.result_replace(result)
        if len(result) == 0:
            return self.generate(prompt)

        return result
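A minimal usage sketch of the class above, not part of the commit; it assumes the model_v5 checkpoint is present under generator/gpt2/models and a TensorFlow 1.x environment matching the tf.compat.v1 / tf.placeholder style used in the file:

    # Hypothetical usage sketch (assumes generator/gpt2/models/model_v5 exists).
    from generator.gpt2.gpt2_generator import GPT2Generator

    generator = GPT2Generator(generate_num=60, temperature=0.4, censor=True)
    prompt = "You enter the dark cave and light your torch."
    result = generator.generate(prompt)  # cleaned, censored continuation text
    print(result)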
generator/gpt2/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
fire>=0.1.3
regex==2018.1.10
requests==2.21.0
tqdm==4.31.1
generator/gpt2/src/__init__.py (new file, 0 lines)
generator/gpt2/src/encoder.py (new file, 131 lines)
@@ -0,0 +1,131 @@
"""Byte pair encoding utilities"""

import json
import os
from functools import lru_cache

import regex as re


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", errors=self.errors
        )
        return text


def get_encoder(model_name, models_dir):
    with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
        encoder = json.load(f)
    with open(
        os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8"
    ) as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges,)
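A hypothetical round-trip sketch of the BPE encoder above, not part of the commit; it assumes the encoder.json and vocab.bpe files for a model named 124M have already been fetched with download_model.py:

    # Hypothetical sketch: encode text to BPE token ids and decode them back.
    from generator.gpt2.src import encoder

    enc = encoder.get_encoder("124M", "models")
    ids = enc.encode("You enter the dark cave.")  # list of integer token ids
    text = enc.decode(ids)                        # recovers the original string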
generator/gpt2/src/model.py (new file, 205 lines)
@@ -0,0 +1,205 @@
import numpy as np

import tensorflow as tf
from tensorflow.contrib.training import HParams


def default_hparams():
    return HParams(n_vocab=0, n_ctx=1024, n_embd=768, n_head=12, n_layer=12,)


def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def softmax(x, axis=-1):
    x = x - tf.reduce_max(x, axis=axis, keepdims=True)
    ex = tf.exp(x)
    return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)


def gelu(x):
    return 0.5 * x * (1 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))


def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
    with tf.variable_scope(scope):
        n_state = x.shape[-1].value
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + epsilon)
        x = x * g + b
        return x


def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
    *start, m = shape_list(x)
    return tf.reshape(x, start + [n, m // n])


def merge_states(x):
    """Smash the last two dimensions of x into a single dimension."""
    *start, a, b = shape_list(x)
    return tf.reshape(x, start + [a * b])


def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            "w",
            [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev),
        )
        b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0))
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf],
        )
        return c


def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)


def attn(x, scope, n_state, *, past, hparams):
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert (
            past.shape.ndims == 5
        )  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w * b - tf.cast(1e10, w.dtype) * (1 - b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        c = conv1d(x, "c_attn", n_state * 3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, "c_proj", n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        h = gelu(conv1d(x, "c_fc", n_state))
        h2 = conv1d(h, "c_proj", nx)
        return h2


def block(x, scope, *, past, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        a, present = attn(norm(x, "ln_1"), "attn", nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, "ln_2"), "mlp", nx * 4, hparams=hparams)
        x = x + m
        return x, present


def past_shape(*, hparams, batch_size=None, sequence=None):
    return [
        batch_size,
        hparams.n_layer,
        2,
        hparams.n_head,
        sequence,
        hparams.n_embd // hparams.n_head,
    ]


def expand_tile(value, size):
    """Add a new axis of given size."""
    value = tf.convert_to_tensor(value, name="value")
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1] * ndims)


def positions_for(tokens, past_length):
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)


def model(hparams, X, past=None, scope="model", reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        wpe = tf.get_variable(
            "wpe",
            [hparams.n_ctx, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.01),
        )
        wte = tf.get_variable(
            "wte",
            [hparams.n_vocab, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02),
        )
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = (
            tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        )
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, "h%d" % layer, past=past, hparams=hparams)
            presents.append(present)
        results["present"] = tf.stack(presents, axis=1)
        h = norm(h, "ln_f")

        # Language model loss. Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch * sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results["logits"] = logits
        return results
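A sketch of how this module is meant to be driven, not part of the commit; it mirrors the TF 1.x graph-building style used in gpt2_generator.py, and the hparams.json path is illustrative (assumes a downloaded model):

    # Hypothetical sketch: build the forward graph and inspect its outputs.
    import json
    import tensorflow as tf
    from generator.gpt2.src import model

    hparams = model.default_hparams()
    with open("models/124M/hparams.json") as f:
        hparams.override_from_dict(json.load(f))

    tokens = tf.placeholder(tf.int32, [1, None])  # [batch, sequence] of token ids
    out = model.model(hparams=hparams, X=tokens)
    # out["logits"]  -> [batch, sequence, n_vocab] next-token logits
    # out["present"] -> cached keys/values, fed back in as `past` for fast decoding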
generator/gpt2/src/sample.py (new file, 123 lines)
@@ -0,0 +1,123 @@
import tensorflow as tf
from generator.gpt2.src import model


def penalize_used(logits, output):

    # I want to change the indices of logits wherever the index is found in output
    change_tensor = tf.zeros_like(logits, dtype=logits.dtype)
    unique = tf.unique(output[0])[0]
    ones = tf.ones_like(unique, dtype=unique.dtype)
    indices = tf.expand_dims(unique, 1)

    updates = tf.scatter_nd(indices, ones, [logits.shape[1]])

    bool_tensor = tf.expand_dims(tf.cast(updates, tf.bool), 0)

    return tf.compat.v1.where(bool_tensor, logits * 0.85, logits)


def top_k_logits(logits, k):
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

    return tf.cond(tf.equal(k, 0), lambda: logits, lambda: _top_k(),)


def top_p_logits(logits, p):
    """Nucleus sampling"""
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction="DESCENDING", axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack(
        [
            tf.range(0, batch),
            # number of indices to include
            tf.maximum(
                tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0
            ),
        ],
        axis=-1,
    )
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(logits < min_values, tf.ones_like(logits) * -1e10, logits,)


def sample_sequence(
    *,
    hparams,
    length,
    start_token=None,
    batch_size=None,
    context=None,
    temperature=1,
    top_k=0,
    top_p=1
):
    if start_token is None:
        assert context is not None, "Specify exactly one of start_token and context!"
    else:
        assert context is None, "Specify exactly one of start_token and context!"
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(
            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE
        )

        logits = lm_output["logits"][:, :, : hparams.n_vocab]
        presents = lm_output["present"]
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            "logits": logits,
            "presents": presents,
        }

    with tf.name_scope("sample_sequence"):

        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs["logits"][:, -1, :] / tf.to_float(temperature)
            logits = penalize_used(logits, output)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                next_outputs["presents"]
                if past is None
                else tf.concat([past, next_outputs["presents"]], axis=-2),
                samples,
                tf.concat([output, samples], axis=1),
            ]

        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)
                ),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
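For intuition, a rough NumPy restatement of the nucleus (top-p) truncation idea implemented in top_p_logits above; it is illustrative only, not part of the commit, and not numerically identical to the TensorFlow version:

    # Illustrative only: keep the smallest high-probability prefix of the
    # vocabulary whose cumulative probability reaches p, mask out the rest.
    import numpy as np

    def top_p_filter(logits, p=0.9):
        shifted = logits - logits.max()
        probs = np.exp(shifted) / np.exp(shifted).sum()
        order = np.argsort(-probs)                 # indices, most probable first
        cumulative = np.cumsum(probs[order])
        keep = np.searchsorted(cumulative, p) + 1  # size of the nucleus
        filtered = logits.copy()
        filtered[order[keep:]] = -1e10             # suppress the long tail
        return filtered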
generator/human_dm.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from story.utils import *


class HumanDM:
    def generate(self, prompt, options=None, seed=None):
        return input()
generator/simple/.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
models
*.pyc
__pycache__
checkpoint
generator/simple/__init__.py (new file, 0 lines)
generator/simple/finetune.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import os
import tarfile

import gpt_2_simple as gpt2

model_name = "1558M"
if not os.path.isdir(os.path.join("models", model_name)):
    print("Downloading ", model_name, " model...")
    gpt2.download_gpt2(
        model_name=model_name
    )  # model is saved into current directory under /models/1558M/

file_name = "text_adventures.txt"

sess = gpt2.start_tf_sess()
gpt2.finetune(
    sess,
    file_name,
    multi_gpu=True,
    batch_size=32,
    learning_rate=0.0001,
    model_name=model_name,
    sample_every=10000,
    max_checkpoints=8,
    save_every=200,
    steps=1000,
)

gpt2.generate(sess)