import json
import os
import warnings

import numpy as np
import tensorflow as tf

from generator.gpt2.src import encoder, model, sample
from story.utils import *

warnings.filterwarnings("ignore")
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


class GPT2Generator:
    """Text generator backed by a GPT-2 checkpoint loaded into a TF session.

    The constructor builds the sampling graph once and restores the latest
    checkpoint found under ``generator/gpt2/models/model_v5``; ``generate``
    then feeds encoded prompts through that graph.
    """

    # Upper bound on sampling attempts in ``generate`` when the cleaned-up
    # result keeps coming back empty (previously this recursed without limit).
    MAX_GENERATE_ATTEMPTS = 20

    def __init__(
        self,
        generate_num=60,
        temperature=0.4,
        top_k=40,
        top_p=0.9,
        censor=True,
        force_cpu=False,
    ):
        """Build the sampling graph and restore the model weights.

        Args:
            generate_num: Passed to ``sample.sample_sequence`` as ``length``.
            temperature: Passed through to ``sample.sample_sequence``.
            top_k: Passed through to ``sample.sample_sequence``.
            top_p: Passed through to ``sample.sample_sequence``.
            censor: When true, ``result_replace`` runs ``remove_profanity``
                on generated text.
            force_cpu: When true, the session is configured with a GPU
                device count of 0; otherwise GPU memory growth is enabled.
        """
        self.generate_num = generate_num
        self.temp = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.censor = censor

        self.model_name = "model_v5"
        self.model_dir = "generator/gpt2/models"
        self.checkpoint_path = os.path.join(self.model_dir, self.model_name)

        models_dir = os.path.expanduser(os.path.expandvars(self.model_dir))
        self.batch_size = 1
        self.samples = 1

        self.enc = encoder.get_encoder(self.model_name, models_dir)
        hparams = model.default_hparams()
        with open(os.path.join(models_dir, self.model_name, "hparams.json")) as f:
            hparams.override_from_dict(json.load(f))

        if force_cpu:
            config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
        else:
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True

        self.sess = tf.compat.v1.Session(config=config)

        # Use the compat.v1 entry points consistently with ConfigProto/Session
        # above; the bare ``tf.placeholder`` / ``tf.train.Saver`` names are
        # not available under TensorFlow 2.x.
        self.context = tf.compat.v1.placeholder(tf.int32, [self.batch_size, None])
        self.output = sample.sample_sequence(
            hparams=hparams,
            length=self.generate_num,
            context=self.context,
            batch_size=self.batch_size,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        saver = tf.compat.v1.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, self.model_name))
        saver.restore(self.sess, ckpt)

    def prompt_replace(self, prompt):
        """Return ``prompt`` with a single trailing space removed, if present."""
        # NOTE: second-to-first-person conversion of the prompt is currently
        # disabled (it was commented out in the original code).
        if prompt.endswith(" "):
            prompt = prompt[:-1]
        return prompt

    def result_replace(self, result):
        """Clean up raw generated text before it is returned to the caller.

        Trims via ``cut_trailing_sentence``, moves a closing quote in front of
        a period (``."`` -> ``".``), strips ``#`` and ``*``, collapses blank
        lines, optionally runs ``remove_profanity``, and restores a lower-case
        first letter if the trimmed text started with a non-upper-case
        character.

        Returns:
            The cleaned text, or ``""`` when nothing is left.
        """
        # cut_trailing_sentence / remove_profanity are not defined in this
        # file; presumably they come from ``story.utils`` -- confirm.
        result = cut_trailing_sentence(result)
        if not result:
            return ""
        first_letter_capitalized = result[0].isupper()

        result = result.replace('."', '".')
        result = result.replace("#", "")
        result = result.replace("*", "")
        result = result.replace("\n\n", "\n")
        # NOTE: first-to-second-person conversion of the result is currently
        # disabled (it was commented out in the original code).

        if self.censor:
            result = remove_profanity(result)

        # The replacements above (or the profanity filter) can leave nothing
        # behind, e.g. for a result made only of '#'/'*'; indexing result[0]
        # would then raise IndexError.
        if not result:
            return ""

        if not first_letter_capitalized:
            result = result[0].lower() + result[1:]

        return result

    def generate_raw(self, prompt):
        """Run the model on ``prompt`` and return the decoded continuation.

        The prompt tokens are sliced off the model output, so only newly
        generated text is returned. If several samples are produced, the last
        one decoded wins (same as the original behaviour).
        """
        context_tokens = self.enc.encode(prompt)
        # Initialised so the method returns "" instead of raising
        # UnboundLocalError when samples // batch_size is 0.
        text = ""
        for _ in range(self.samples // self.batch_size):
            out = self.sess.run(
                self.output,
                feed_dict={
                    self.context: [context_tokens for _ in range(self.batch_size)]
                },
            )[:, len(context_tokens):]
            for i in range(self.batch_size):
                text = self.enc.decode(out[i])
        return text

    def generate(self, prompt, options=None, seed=1):
        """Generate cleaned-up text for ``prompt``.

        Sampling is retried while the cleaned result is empty, up to
        ``MAX_GENERATE_ATTEMPTS`` times; this replaces the previous unbounded
        recursion, which could end in ``RecursionError``.

        Args:
            prompt: Text to continue.
            options: Unused; kept for interface compatibility.
            seed: Unused; kept for interface compatibility.

        Returns:
            The cleaned generated text, or ``""`` if every attempt produced
            an empty result.
        """
        debug_print = False
        result = ""
        for _ in range(self.MAX_GENERATE_ATTEMPTS):
            # Applied on every attempt, matching the original recursive call
            # which re-ran prompt_replace on the already-trimmed prompt.
            prompt = self.prompt_replace(prompt)

            if debug_print:
                print("******DEBUG******")
                print("Prompt is: ", repr(prompt))

            text = self.generate_raw(prompt)

            if debug_print:
                print("Generated result is: ", repr(text))
                print("******END DEBUG******")

            result = self.result_replace(text)
            if result:
                break

        return result