import numpy as np
import tensorflow as tf
from tensorflow.contrib.training import HParams


def default_hparams():
    """Return the default model hyperparameters.

    n_vocab is 0 here; `model` sizes the token-embedding table from it, so it
    presumably must be overridden before building a model (TODO confirm where).
    """
    return HParams(
        n_vocab=0,
        n_ctx=1024,
        n_embd=768,
        n_head=12,
        n_layer=12,
    )


def shape_list(x):
    """Return the shape of `x` as a list, mixing static and dynamic sizes.

    Each dimension whose size is statically known is returned as a Python
    int; each unknown (None) dimension is replaced by the matching scalar
    tensor taken from tf.shape(x).
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def softmax(x, axis=-1):
    """Numerically stable softmax of `x` along `axis`."""
    # Subtracting the per-slice max leaves the result unchanged but keeps
    # tf.exp from overflowing on large logits.
    x = x - tf.reduce_max(x, axis=axis, keepdims=True)
    ex = tf.exp(x)
    return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)


def gelu(x):
    """GELU activation, tanh approximation:
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """
    return 0.5 * x * (1 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))


def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Normalize to mean = 0, std = 1, then do a diagonal affine transform.

    Creates (or reuses) variables "g" (scale, initialized to 1) and "b"
    (shift, initialized to 0) under variable scope `scope`.

    Args:
        x: input tensor; its last dimension must be statically known, since
            its size is read via x.shape[-1].value to size g and b.
        scope: variable-scope name for g and b.
        axis: axis over which the mean and variance are computed.
        epsilon: added to the variance before the reciprocal square root.

    Returns:
        Tensor with the same shape as `x`.
    """
    with tf.variable_scope(scope):
        n_state = x.shape[-1].value
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + epsilon)
        # g and b are sized by the LAST dimension, so the affine step always
        # broadcasts over axis -1, whatever `axis` was used for the statistics.
        x = x * g + b
        return x


def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
    *start, m = shape_list(x)
    # Integer division: assumes m is divisible by n (attn asserts this for
    # its use with n_head).
    return tf.reshape(x, start + [n, m // n])


def merge_states(x):
    """Smash the last two dimensions of x into a single dimension."""
    *start, a, b = shape_list(x)
    return tf.reshape(x, start + [a * b])


def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Project the last dimension of `x` to `nf` features at every position.

    Despite the name, this is applied as a plain matmul: `x` is flattened to
    [-1, nx], multiplied by the weight, biased, and reshaped back to
    `start + [nf]`. The weight variable "w" is stored as [1, nx, nf], a
    width-1 convolution kernel. Presumably this layout is kept for checkpoint
    compatibility (TODO confirm). The bias "b" has shape [nf] and is
    initialized to 0.

    Args:
        x: input tensor whose last dimension nx is used to size "w"; it
            presumably must be statically known for tf.get_variable.
        scope: variable-scope name for "w" and "b".
        nf: number of output features.
        w_init_stdev: stddev of the normal initializer for "w".
    """
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            "w",
            [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev),
        )
        b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0))
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf],
        )
        return c


def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.

    Entry [i, j] is 1 when j <= i + (ns - nd), and 0 otherwise. With nd
    destination positions and ns source positions (ns - nd of them earlier,
    already-cached steps), destination i may see every earlier source
    position plus those up to and including itself.
    """
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)


def attn(x, scope, n_state, *, past, hparams):
    """Causal multi-head self-attention with an optional key/value cache.

    Args:
        x: [batch, sequence, features] input (rank 3 is asserted).
        scope: variable-scope name; creates sub-scopes "c_attn" and "c_proj".
        n_state: total attention width, split evenly across hparams.n_head
            heads (divisibility is asserted).
        past: None, or cached keys/values stacked as
            [batch, 2, heads, sequence, features] with 2 = [k, v] (rank 5 is
            asserted).
        hparams: supplies n_head.

    Returns:
        (a, present): `a` is the [batch, sequence, n_state] attention output.
        `present` stacks ONLY this call's keys and values as
        [batch, 2, heads, sequence, features]; it is built before `past` is
        concatenated in.
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert (
            past.shape.ndims == 5
        )  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # w * b keeps allowed logits and zeroes masked ones. Subtracting
        # 1e10 * (1 - b) then drives masked logits to -1e10, so softmax gives
        # them ~0 weight.
        # NOTE(review): 1e10 exceeds the float16 range, so for a float16 w
        # this cast gives inf, and inf * 0 on allowed entries yields NaN.
        # Confirm only float32 (or wider) is used here.
        w = w * b - tf.cast(1e10, w.dtype) * (1 - b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        # Scale the logits by 1 / sqrt(per-head feature size).
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        # One projection produces q, k and v concatenated along features.
        c = conv1d(x, "c_attn", n_state * 3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            # Prepend the cached keys/values along the sequence axis.
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, "c_proj", n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    """Position-wise feed-forward layer: project x -> n_state -> original width.

    Applies gelu between two conv1d projections ("c_fc", then "c_proj" back
    to x's last-dimension size).
    NOTE(review): `hparams` is accepted but not used in this body.
    """
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        h = gelu(conv1d(x, "c_fc", n_state))
        h2 = conv1d(h, "c_proj", nx)
        return h2


def block(x, scope, *, past, hparams):
    """One transformer layer: attention then MLP, each with a residual add.

    Each sub-layer sees norm(x) ("ln_1" / "ln_2"), and its output is added
    back onto the un-normalized x. The MLP hidden width is 4 * the input
    width.

    Returns:
        (x, present), where `present` is the key/value stack from attn.
    """
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        a, present = attn(norm(x, "ln_1"), "attn", nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, "ln_2"), "mlp", nx * 4, hparams=hparams)
        x = x + m
        return x, present


def past_shape(*, hparams, batch_size=None, sequence=None):
    """Shape of the all-layer key/value cache.

    Layout: [batch, layer, 2 (k, v), heads, sequence, per-head features].
    batch_size and sequence default to None (unspecified).
    """
    return [
        batch_size,
        hparams.n_layer,
        2,
        hparams.n_head,
        sequence,
        hparams.n_embd // hparams.n_head,
    ]


def expand_tile(value, size):
    """Add a new axis of given size.

    The new axis is the leading one (axis 0); `value` is repeated `size`
    times along it.
    """
    value = tf.convert_to_tensor(value, name="value")
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1] * ndims)


def positions_for(tokens, past_length):
    """Position ids for `tokens`, offset by `past_length`.

    Returns ids past_length .. past_length + nsteps - 1, where nsteps is
    tokens' second dimension, tiled to [batch, nsteps].
    """
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)


def model(hparams, X, past=None, scope="model", reuse=False):
    """Build the transformer graph for token ids `X` under variable scope `scope`.

    Args:
        hparams: supplies n_ctx, n_vocab, n_embd, n_head, n_layer.
        X: token ids; shape_list(X) is unpacked into exactly (batch, sequence),
            so X must be rank 2.
        past: None, or the cache for all layers laid out as in past_shape:
            [batch, n_layer, 2, heads, sequence, per-head features].
        scope: variable-scope name.
        reuse: forwarded to tf.variable_scope.

    Populates results["present"] with this call's per-layer key/value stacks,
    stacked along axis 1.
    """
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        wpe = tf.get_variable(
            "wpe",
            [hparams.n_ctx, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.01),
        )
        wte = tf.get_variable(
            "wte",
            [hparams.n_vocab, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02),
        )
        # Number of cached steps: the sequence axis of the past cache.
        past_length = 0 if past is None else tf.shape(past)[-2]
        # NOTE(review): nothing here checks past_length + sequence against
        # hparams.n_ctx (the number of rows in wpe).
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = (
            tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        )
        assert len(pasts) == hparams.n_layer
        # Rebinds `past` to the current layer's slice of the cache.
        for layer, past in enumerate(pasts):
            h, present = block(h, "h%d" % layer, past=past, hparams=hparams)
            presents.append(present)
        results["present"] = tf.stack(presents, axis=1)
        h = norm(h, "ln_f")

        # Language model loss. Do tokens