# -*- coding: utf-8 -*-
"""1.5B OpenGPT-2.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1b61Jowb390aZwpE51qcRADQ_SWyhdl5e
"""
# Commented out IPython magic to ensure Python compatibility.
!git clone https://github.com/rowanz/grover.git
# %cd /content/grover
!python3 -m pip install regex jsonlines
!python3 -m pip install -U tqdm
import io
import os

from google.colab import auth
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload  # same class the legacy apiclient alias exposes
from tqdm.auto import tqdm
auth.authenticate_user()
drive_service = build('drive', 'v3')
model_type = 'mega'
model_path = '/content/grover/models'
model_dir = os.path.join(model_path, model_type)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
NUMBER = 800000
MODEL = "grover_mega_owt_fix"
local_file_ids = ['1t1B5JfjolytwSSUAGOZCnstVlaL2Bg25', '1kojGap2kXzkJBWtwIZtOXgeV3IWnNMtO', '1FITtxBwJsagKfaLz9tgsguHbbdMCyeQw']
local_filenames = ['.data-00000-of-00001', '.index', '.meta']
# Download each checkpoint file from Google Drive in 100 MB chunks, showing a progress bar.
for ext, id_ in zip(local_filenames, local_file_ids):
    ext = str(NUMBER) + ext
    filename = '%s/%s/model.ckpt.%s' % (model_path, model_type, ext)
    request = drive_service.files().get_media(fileId=id_)
    with open(filename, 'wb') as f:
        downloader = MediaIoBaseDownload(f, request, chunksize=100 * 1024 * 1024)
        done = False
        pbar = tqdm(total=100, desc='%s' % ext)
        progress = 0
        while not done:
            status, done = downloader.next_chunk()
            new_progress = int(status.progress() * 100)
            pbar.update(new_progress - progress)
            progress = new_progress
        pbar.close()
    print('Downloaded %s' % filename)
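# Optional sanity check, a minimal sketch built only from the variables above: make sure
# all three checkpoint files were actually written and are non-empty before generating.
for ext in local_filenames:
    ckpt_file = os.path.join(model_dir, 'model.ckpt.%s%s' % (NUMBER, ext))
    assert os.path.isfile(ckpt_file) and os.path.getsize(ckpt_file) > 0, 'Missing or empty: %s' % ckpt_file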
# Commented out IPython magic to ensure Python compatibility.
# %cd /content/grover
with open('sample/encoder.py', 'w') as f:
f.write("""\n\"\"\"Byte pair encoding utilities\n\nSome functions are adapted from OpenAI but with modifications\n\nhttps://github.com/openai/gpt-2\n\"\"\"\n\nimport os\nimport json\nimport regex as re\nfrom functools import lru_cache\nimport tensorflow as tf\nimport random\nimport numpy as np\n\n\n@lru_cache()\ndef bytes_to_unicode():\n \"\"\"\n Returns list of utf-8 byte and a corresponding list of unicode strings.\n The reversible bpe codes work on unicode strings.\n This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n This is a signficant percentage of your normal, say, 32K bpe vocab.\n To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n And avoids mapping to whitespace/control characters the bpe code barfs on.\n \"\"\"\n bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))\n cs = bs[:]\n n = 0\n for b in range(2 ** 8):\n if b not in bs:\n bs.append(b)\n cs.append(2 ** 8 + n)\n n += 1\n cs = [chr(n) for n in cs]\n return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n \"\"\"Return set of symbol pairs in a word.\n\n Word is represented as tuple of symbols (symbols being variable-length strings).\n \"\"\"\n pairs = set()\n prev_char = word[0]\n for char in word[1:]:\n pairs.add((prev_char, char))\n prev_char = char\n return pairs\n\n\nclass Encoder:\n def __init__(self, encoder, bpe_merges, errors='replace'):\n # self.encoder = {k: v + 1 for k, v in encoder.items()}\n # self.encoder['<|padding|>'] = 0\n # self.padding = 0\n #\n # del self.encoder['<|endoftext|>']\n #\n # for special_token_type in ['domain', 'date', 'authors', 'title', 'article', 'summary']:\n # setattr(self, f'begin_{special_token_type}', len(self.encoder))\n # self.encoder[f'<|begin{special_token_type}|>'] = len(self.encoder)\n #\n # setattr(self, f'end_{special_token_type}', len(self.encoder))\n # self.encoder[f'<|endof{special_token_type}|>'] = len(self.encoder)\n #\n # # This will be used if we want to combine short articles.\n # self.reset_context = len(self.encoder)\n # self.encoder['<|resetcontext|>'] = len(self.encoder)\n\n self.encoder = encoder\n self.endoftext = self.encoder['<|endoftext|>']\n\n ################################## END OF SPECIAL TOKENS TO ADD\n\n self.decoder = {v: k for k, v in self.encoder.items()}\n self.errors = errors # how to handle errors in decoding\n self.byte_encoder = bytes_to_unicode()\n self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))\n self.cache = {}\n\n # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions\n self.pat = re.compile(r\"\"\"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\")\n\n def bpe(self, token):\n if token in self.cache:\n return self.cache[token]\n word = tuple(token)\n pairs = get_pairs(word)\n\n if not pairs:\n return token\n\n while True:\n bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))\n if bigram not in self.bpe_ranks:\n break\n first, second = bigram\n new_word = []\n i = 0\n while i < len(word):\n try:\n j = word.index(first, i)\n new_word.extend(word[i:j])\n i = j\n except:\n new_word.extend(word[i:])\n break\n\n if word[i] == first and i < len(word) - 1 and word[i + 1] == second:\n new_word.append(first + second)\n i += 2\n 
else:\n new_word.append(word[i])\n i += 1\n new_word = tuple(new_word)\n word = new_word\n if len(word) == 1:\n break\n else:\n pairs = get_pairs(word)\n word = ' '.join(word)\n self.cache[token] = word\n return word\n\n def encode(self, text):\n bpe_tokens = []\n for token in re.findall(self.pat, text):\n token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n return bpe_tokens\n\n def decode(self, tokens):\n text = ''.join([self.decoder[token] for token in tokens])\n text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)\n return text\n\n def __len__(self):\n return len(self.encoder)\n\n @property\n def special_tokens_onehot(self):\n \"\"\" Return the IDs of all special tokens\"\"\"\n return [(self.decoder[i].startswith('<|') and self.decoder[i].endswith('|>')) for i in range(len(self))]\n\n\ndef get_encoder():\n directory_name = os.path.dirname(__file__)\n with open(os.path.join(directory_name, 'encoder.json'), 'r') as f:\n encoder = json.load(f)\n with open(os.path.join(directory_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:\n bpe_data = f.read()\n bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\\n')[1:-1]]\n return Encoder(\n encoder=encoder,\n bpe_merges=bpe_merges,\n )\n\n\n##############################################################\n# TURN SOMETHING INTO THE RIGHT FORMAT FOR AN EXAMPLE\n##############################################################\ndef _tokenize_article_pieces(encoder, item):\n \"\"\"\n Turn the article into tokens\n NOTE: in hindsight I kinda messed up here because the first token is always represented as a BPE continuation\n rather than an initial token in its own right. 
whoops....\n\n :param item: Contains things that need to be tokenized\n\n\n fields are ['domain', 'date', 'authors', 'title', 'article', 'summary']\n :return: dict\n \"\"\"\n # article_pieces = {\n # 'article': [encoder.begin_article] + encoder.encode(item['text']) + [encoder.end_article],\n # 'domain': [encoder.begin_domain] + encoder.encode(item['domain']) + [encoder.end_domain],\n # 'title': [encoder.begin_title] + encoder.encode(item['title']) + [encoder.end_title],\n # }\n # # 4/6: Attach the summary too, why the hell not\n # if item['summary'] and len(item['summary']) > 50:\n # article_pieces['summary'] = [encoder.begin_summary] + encoder.encode(item['summary']) + [encoder.end_summary]\n #\n # # 5/6: date\n # date_split = item['publish_date'].split('-')\n # assert len(date_split) == 3\n # assert date_split[0].isdigit()\n #\n # date_txt = ['January', 'February', 'March', 'April', 'May', 'June', 'July',\n # 'August', 'September', 'October', 'November', 'December'][int(date_split[0]) - 1] + ' {}, {}'.format(\n # date_split[1], date_split[2])\n # article_pieces['date'] = [encoder.begin_date] + encoder.encode(date_txt) + [encoder.end_date]\n #\n # # 6/6: authors\n # authors = ', '.join(item['authors'])\n # if len(authors) > 5:\n # article_pieces['authors'] = [encoder.begin_authors] + encoder.encode(authors) + [encoder.end_authors]\n return encoder.encode(item) + [encoder.endoftext]\n\n\ndef _cut_tokens_to_add_stuff(tokens, stuff_to_add, desired_size, padding_token):\n \"\"\"\n The idea behind this function is to take away tokens from `tokens' such that tokens[:LENGTH] + stuff_to_add becomes\n exactly at the right size (desired_size).\n\n :param tokens:\n :param stuff_to_add:\n :param desired_size:\n :return:\n \"\"\"\n if len(tokens) >= desired_size:\n return tokens\n\n # no way we can add this stuff\n if len(stuff_to_add) >= desired_size:\n return tokens\n\n if (len(tokens) + len(stuff_to_add)) <= desired_size:\n return tokens + stuff_to_add\n\n # Otherwise we'll have to actually cut\n tokens = tokens[:(desired_size - len(stuff_to_add) - 1)]\n tokens.append(padding_token)\n tokens.extend(stuff_to_add)\n return tokens\n\n\ndef tokenize_for_grover_training(encoder, item, desired_size=1024, unconditional_prob=0.35, metadata_dropout_prob=0.1,\n cut_prob=0.2):\n \"\"\"\n Not only will we tokenize an item with a BPE encoder, but we'll also put it in a nice format for language modeling.\n The goal is to MINIMIZE PADDING. 
If we don't fill up the desired size of 1024 tokens then we're wasting compute.\n\n The canonical order is\n\n DOMAIN DATE AUTHORS TITLE ARTICLE SUMMARY\n\n\n :param encoder:\n :param item: Contains things like\n {"url": "https://www.advocate.com/node/1010911",\n "timestamp": "20180118211607",\n "url_used": "https://web.archive.org/web/20180118211607id_/https://www.advocate.com/node/1010911",\n "domain": "advocate.com",\n "title": "Report: One-Third of Trump's Judicial Picks Are Anti-LGBT",\n "text": ....\n "summary": ....\n "authors": list\n "publish_date": ...\n }\n :param desired_size: the goal for how long the span will be\n :param unconditional_prob: The probability that we will generate JUST THE TEXT first.\n :param metadata_dropout_prob: The probability that we will drop out each item of metadata\n :param cut_prob: The probability that, if we're already over the desired size, we'll cut the article and start\n predicting metadata before the desired_size window ends.\n :return:\n \"\"\"\n # Get all the bits and pieces\n tokens = _tokenize_article_pieces(encoder, item)\n # canonical_metadata_order = ['domain', 'date', 'authors', 'title']\n #\n # # unconditional_prob is probability we only generate the text first, without any metadata\n # switch = random.random()\n # if switch < unconditional_prob:\n # assignments = {'article': 'a'}\n # chunk_a = article_pieces.pop('article')\n # chunk_b = []\n # for x in canonical_metadata_order + ['summary']:\n # if random.random() > metadata_dropout_prob:\n # chunk_b.extend(article_pieces.pop(x, []))\n # assignments[x] = 'b'\n # elif switch < 0.5:\n # # Put everything in chunk_a, without dropout\n # assignments = {}\n # chunk_a = []\n # chunk_b = []\n # for x in canonical_metadata_order + ['article', 'summary']:\n # chunk_a.extend(article_pieces.pop(x, []))\n # assignments[x] = 'a'\n # else:\n # assignments = {}\n # chunk_a = []\n # chunk_b = []\n # for k in canonical_metadata_order + ['article', 'summary']:\n # if random.random() < metadata_dropout_prob and k not in ('article', 'title'):\n # pass\n # elif random.random() < 0.5:\n # if k != 'summary':\n # chunk_a.extend(article_pieces.pop(k, []))\n # assignments[k] = 'a'\n # else:\n # chunk_b.extend(article_pieces.pop(k, []))\n # assignments[k] = 'b'\n #\n # if (len(chunk_a) + len(chunk_b)) <= desired_size:\n # return chunk_a + chunk_b\n #\n # if (assignments.get('article', '') == 'a') and (len(chunk_b) > 0) and (random.random() < cut_prob):\n # return _cut_tokens_to_add_stuff(chunk_a, chunk_b, desired_size, encoder.padding)\n #\n # tokens = chunk_a + chunk_b\n\n return tokens\n\n\ndef detokenize(encoder, tokens):\n return encoder.decode(tokens)\n\n\n#######################################\n\ndef create_int_feature(values):\n feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))\n return feature\n\n\ndef sliding_window(article, max_seq_length):\n \"\"\"\n Randomly sample some spans. 
It's a simple approximation of sliding window\n :param tokens:\n :param max_seq_length:\n :return:\n \"\"\"\n # if it's shorter, no need for this\n if len(article['input_ids']) <= max_seq_length:\n amount_to_pad = max_seq_length - len(article['input_ids'])\n yield article\n return\n\n num_spans = len(article['input_ids']) - max_seq_length + 1\n weights = np.ones(num_spans, dtype=np.float32)\n # weights[0] = max_seq_length\n weights /= weights.sum()\n\n num_to_yield = int(0.5 + len(article['input_ids']) / max_seq_length)\n starts = np.random.choice(num_spans, size=num_to_yield, replace=False, p=weights)\n\n input_ids = article.pop('input_ids')\n for i in starts.tolist():\n article['input_ids'] = input_ids[i:(i + max_seq_length)]\n yield article\n\n# def sliding_window(article, max_seq_length, pad_token):\n# \"\"\"\n# Randomly sample some spans. It's a simple approximation of sliding window\n# :param tokens:\n# :param max_seq_length:\n# :return:\n# \"\"\"\n# # if it's shorter, no need for this\n# if len(article['input_ids']) <= max_seq_length:\n# amount_to_pad = max_seq_length - len(article['input_ids'])\n# article['input_ids'].extend([pad_token] * amount_to_pad)\n# yield article\n# return\n#\n# num_spans = len(article['input_ids']) - max_seq_length + 1\n# weights = np.ones(num_spans, dtype=np.float32)\n# # weights[0] = max_seq_length\n# weights /= weights.sum()\n#\n# num_to_yield = int(0.5 + len(article['input_ids']) / max_seq_length)\n# starts = np.random.choice(num_spans, size=num_to_yield, replace=False, p=weights)\n#\n# input_ids = article.pop('input_ids')\n# for i in starts.tolist():\n# article['input_ids'] = input_ids[i:(i + max_seq_length)]\n# yield article\n\n\ndef format_context(encoder, news_article, target):\n \"\"\"\n Generates a news article given some partial information\n :param news_article: Contains context\n :param target: What we want to get an answer for.\n :return:\n \"\"\"\n canonical_metadata_order = ['domain', 'date', 'authors', 'title', 'article']\n tokens = []\n for metadata_category in canonical_metadata_order:\n metadata = news_article.get(metadata_category, '').strip()\n\n # This MIGHT BE needed because I think during training time we never saw empty articles\n # if metadata or ((metadata_category == 'article') and target != 'article'):\n if (metadata_category == 'article') and (target != 'article'):\n metadata = news_article.get('title', '') # Just copy from the title maybe?\n\n if metadata:\n tokens.append(encoder.__dict__[f'begin_{metadata_category}'])\n tokens.extend(encoder.encode(metadata))\n tokens.append(encoder.__dict__[f'end_{metadata_category}'])\n\n assert target in (canonical_metadata_order + ['summary'])\n tokens.append(encoder.__dict__[f'begin_{target}'])\n return tokens\n\ndef extract_generated_target(output_tokens, encoder, target):\n \"\"\"\n Given some tokens that were generated, extract the target\n :param output_tokens: [num_tokens] thing that was generated\n :param encoder: how they were encoded\n :param target: the piece of metadata we wanted to generate!\n :return:\n \"\"\"\n # Filter out first instance of start token\n assert output_tokens.ndim == 1\n\n start_ind = 0\n end_ind = output_tokens.shape[0]\n\n return {\n 'extraction': encoder.decode(output_tokens[start_ind:end_ind]),\n 'start_ind': start_ind,\n 'end_ind': end_ind,\n }\n\n# def extract_generated_target(output_tokens, encoder, target):\n# \"\"\"\n# Given some tokens that were generated, extract the target\n# :param output_tokens: [num_tokens] thing that was generated\n# :param 
encoder: how they were encoded\n# :param target: the piece of metadata we wanted to generate!\n# :return:\n# \"\"\"\n# # Filter out first instance of start token\n# assert output_tokens.ndim == 1\n#\n# start_tokens = output_tokens == encoder.__dict__[f'begin_{target}']\n# if np.any(start_tokens):\n# start_ind = np.argmax(start_tokens) + 1\n# else:\n# start_ind = 0\n#\n# end_tokens = output_tokens == encoder.__dict__[f'end_{target}']\n# if np.any(end_tokens):\n# end_ind = np.argmax(end_tokens)\n# else:\n# end_ind = output_tokens.shape[0]\n#\n# return {\n# 'extraction': encoder.decode(output_tokens[start_ind:end_ind]),\n# 'start_ind': start_ind,\n# 'end_ind': end_ind,\n# }\n\n\nif __name__ == '__main__':\n encoder = get_encoder()\n print("VOCAB SIZE IS {}".format(len(encoder.encoder)))\n""")
# Supports repetitions with -samples <number>
with open('sample/contextual_generate_cli.py', 'w') as f:
f.write("""\nimport tensorflow as tf\nimport numpy as np\nimport sys\nimport json\nimport sys\n\nsys.path.append('../')\nfrom lm.modeling import GroverModel, GroverConfig, _top_p_sample, sample\nfrom sample.encoder import get_encoder, format_context, _tokenize_article_pieces, extract_generated_target\nfrom tqdm import tqdm\n\nimport argparse\n\nparser = argparse.ArgumentParser(description='Contextual generation (aka given some metadata we will generate articles')\nparser.add_argument(\n '-metadata_fn',\n dest='metadata_fn',\n type=str,\n help='Path to a JSONL containing metadata',\n)\nparser.add_argument(\n '-out_fn',\n dest='out_fn',\n type=str,\n help='Out jsonl, which will contain the completed jsons',\n)\nparser.add_argument(\n '-input',\n dest='input',\n type=str,\n help='Text to complete',\n)\nparser.add_argument(\n '-model_config_fn',\n dest='model_config_fn',\n default='../lm/configs/base.json',\n type=str,\n help='Configuration JSON for the model',\n)\nparser.add_argument(\n '-model_ckpt',\n dest='model_ckpt',\n default='../models/base/model.ckpt',\n type=str,\n help='checkpoint file for the model',\n)\nparser.add_argument(\n '-target',\n dest='target',\n default='article',\n type=str,\n help='What to generate for each item in metadata_fn. can be article (body), title, etc.',\n)\nparser.add_argument(\n '-batch_size',\n dest='batch_size',\n default=1,\n type=int,\n help='How many things to generate per context. will split into chunks if need be',\n)\nparser.add_argument(\n '-num_folds',\n dest='num_folds',\n default=1,\n type=int,\n help='Number of folds. useful if we want to split up a big file into multiple jobs.',\n)\nparser.add_argument(\n '-fold',\n dest='fold',\n default=0,\n type=int,\n help='which fold we are on. useful if we want to split up a big file into multiple jobs.'\n)\nparser.add_argument(\n '-max_batch_size',\n dest='max_batch_size',\n default=None,\n type=int,\n help='max batch size. You can leave this out and we will infer one based on the number of hidden layers',\n)\nparser.add_argument(\n '-top_p',\n dest='top_p',\n default=0.95,\n type=float,\n help='p to use for top p sampling. if this isn\\'t none, use this for everthing'\n)\nparser.add_argument(\n '-samples',\n dest='samples',\n default=1,\n type=int,\n help='num_samples',\n)\n\nargs = parser.parse_args()\n\nencoder = get_encoder()\nnews_config = GroverConfig.from_json_file(args.model_config_fn)\n\n# We might have to split the batch into multiple chunks if the batch size is too large\ndefault_mbs = {12: 32, 24: 16, 48: 3}\nmax_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers]\n\n# factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. 
batch_size_per_chunk < max_batch_size\nnum_chunks = int(np.ceil(args.batch_size / max_batch_size))\nbatch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))\nprint("\\n~~\\nbatch size={}, max batch size={}, num chunks={}, batch size per chunk={}\\n~~\\n".format(\n args.batch_size, max_batch_size, num_chunks, batch_size_per_chunk), flush=True)\n\n# This controls the top p for each generation.\ntop_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p\n\n# with open(args.metadata_fn, 'r') as f:\n# articles = [json.loads(l) for i, l in enumerate(f) if i % args.num_folds == args.fold]\n\ntf_config = tf.ConfigProto(allow_soft_placement=True)\n\nwith tf.Session(config=tf_config, graph=tf.Graph()) as sess:\n initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None])\n p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])\n eos_token = tf.placeholder(tf.int32, [])\n tokens, probs = sample(news_config=news_config, initial_context=initial_context,\n eos_token=eos_token, ignore_ids=None, p_for_topp=p_for_topp,\n do_topk=False)\n\n saver = tf.train.Saver()\n saver.restore(sess, args.model_ckpt)\n print('Loaded model.')\n text = input()\n while text != "":\n for i in range(args.samples):\n print("Sample,", i + 1, " of ", args.samples)\n # Let's go!\n encoded = _tokenize_article_pieces(encoder, text)\n context_formatted = []\n # for key in ['domain', 'date', 'authors', 'title', 'article']:\n # if key != args.target:\n # context_formatted.extend(article_pieces.pop(key, []))\n # context_formatted.append(encoder.__dict__['begin_{}'.format(args.target)])\n context_formatted.extend(encoded[:-1])\n # Format context end\n\n # Indices we definitely DONT WANT TO PREDICT\n ignore_ids_np = np.array(encoder.special_tokens_onehot)\n ignore_ids_np[encoder.endoftext] = 0\n\n gens = []\n gens_raw = []\n gen_probs = []\n\n # article['top_ps'] = top_p.reshape(-1).tolist()\n for chunk_i in range(num_chunks):\n tokens_out, probs_out = sess.run([tokens, probs],\n feed_dict={initial_context: [context_formatted] * batch_size_per_chunk,\n eos_token: 60000,\n p_for_topp: top_p[chunk_i]})\n\n for t_i, p_i in zip(tokens_out, probs_out):\n extraction = extract_generated_target(output_tokens=t_i, encoder=encoder, target=args.target)\n gens.append(extraction['extraction'])\n\n # article['gens_{}'.format(args.target)] = gens\n # article['gensraw_{}'.format(args.target)] = gens_raw\n # article['probs_{}'.format(args.target)] = gen_probs\n\n # these were in there for whatever reason...\n # article.pop('input_ids_conditional', None)\n # article.pop('input_ids_unconditional', None)\n # f_out.write(json.dumps(article) + '\\n')\n print(gens[0])\n text = input()\n""")
# Supports repetitions with -samples <number>
# This version supports multi-line input, but Colab doesn't support sending EOF from the stdin interface, so there is no way to actually end the input other than echoing/piping it in.
with open('sample/contextual_generate_cli_multiline.py', 'w') as f:
f.write("""\n\n\nimport tensorflow as tf\nimport numpy as np\nimport sys\nimport json\nimport sys\n\nsys.path.append('../')\nfrom lm.modeling import GroverModel, GroverConfig, _top_p_sample, sample\nfrom sample.encoder import get_encoder, format_context, _tokenize_article_pieces, extract_generated_target\nfrom tqdm import tqdm\n\nimport argparse\n\nparser = argparse.ArgumentParser(description='Contextual generation (aka given some metadata we will generate articles')\nparser.add_argument(\n '-metadata_fn',\n dest='metadata_fn',\n type=str,\n help='Path to a JSONL containing metadata',\n)\nparser.add_argument(\n '-out_fn',\n dest='out_fn',\n type=str,\n help='Out jsonl, which will contain the completed jsons',\n)\nparser.add_argument(\n '-input',\n dest='input',\n type=str,\n help='Text to complete',\n)\nparser.add_argument(\n '-model_config_fn',\n dest='model_config_fn',\n default='../lm/configs/base.json',\n type=str,\n help='Configuration JSON for the model',\n)\nparser.add_argument(\n '-model_ckpt',\n dest='model_ckpt',\n default='../models/base/model.ckpt',\n type=str,\n help='checkpoint file for the model',\n)\nparser.add_argument(\n '-target',\n dest='target',\n default='article',\n type=str,\n help='What to generate for each item in metadata_fn. can be article (body), title, etc.',\n)\nparser.add_argument(\n '-batch_size',\n dest='batch_size',\n default=1,\n type=int,\n help='How many things to generate per context. will split into chunks if need be',\n)\nparser.add_argument(\n '-num_folds',\n dest='num_folds',\n default=1,\n type=int,\n help='Number of folds. useful if we want to split up a big file into multiple jobs.',\n)\nparser.add_argument(\n '-fold',\n dest='fold',\n default=0,\n type=int,\n help='which fold we are on. useful if we want to split up a big file into multiple jobs.'\n)\nparser.add_argument(\n '-max_batch_size',\n dest='max_batch_size',\n default=None,\n type=int,\n help='max batch size. You can leave this out and we will infer one based on the number of hidden layers',\n)\nparser.add_argument(\n '-top_p',\n dest='top_p',\n default=0.95,\n type=float,\n help='p to use for top p sampling. if this isn\\'t none, use this for everthing'\n)\nparser.add_argument(\n '-samples',\n dest='samples',\n default=1,\n type=int,\n help='num_samples',\n)\n\nargs = parser.parse_args()\n\nencoder = get_encoder()\nnews_config = GroverConfig.from_json_file(args.model_config_fn)\n\n# We might have to split the batch into multiple chunks if the batch size is too large\ndefault_mbs = {12: 32, 24: 16, 48: 3}\nmax_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers]\n\n# factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. 
batch_size_per_chunk < max_batch_size\nnum_chunks = int(np.ceil(args.batch_size / max_batch_size))\nbatch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))\nprint("\\n~~\\nbatch size={}, max batch size={}, num chunks={}, batch size per chunk={}\\n~~\\n".format(\n args.batch_size, max_batch_size, num_chunks, batch_size_per_chunk), flush=True)\n\n# This controls the top p for each generation.\ntop_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p\n\n# with open(args.metadata_fn, 'r') as f:\n# articles = [json.loads(l) for i, l in enumerate(f) if i % args.num_folds == args.fold]\n\ntf_config = tf.ConfigProto(allow_soft_placement=True)\n\nwith tf.Session(config=tf_config, graph=tf.Graph()) as sess:\n initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None])\n p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])\n eos_token = tf.placeholder(tf.int32, [])\n tokens, probs = sample(news_config=news_config, initial_context=initial_context,\n eos_token=eos_token, ignore_ids=None, p_for_topp=p_for_topp,\n do_topk=False)\n\n saver = tf.train.Saver()\n saver.restore(sess, args.model_ckpt)\n print('Loaded model.')\n text = "".join(sys.stdin.readlines())\n for i in range(args.samples):\n print("Sample,", i + 1, " of ", args.samples)\n # Let's go!\n encoded = _tokenize_article_pieces(encoder, text)\n context_formatted = []\n \n context_formatted.extend(encoded[:-1])\n # Format context end\n\n # Indices we definitely DONT WANT TO PREDICT\n ignore_ids_np = np.array(encoder.special_tokens_onehot)\n ignore_ids_np[encoder.endoftext] = 0\n\n gens = []\n gens_raw = []\n gen_probs = []\n\n # article['top_ps'] = top_p.reshape(-1).tolist()\n for chunk_i in range(num_chunks):\n tokens_out, probs_out = sess.run([tokens, probs],\n feed_dict={initial_context: [context_formatted] * batch_size_per_chunk,\n eos_token: 60000,\n p_for_topp: top_p[chunk_i]})\n\n for t_i, p_i in zip(tokens_out, probs_out):\n extraction = extract_generated_target(output_tokens=t_i, encoder=encoder, target=args.target)\n gens.append(extraction['extraction'])\n print(gens[0])\n""")
# This option lets you type your prompt from Colab, but you can't enter new lines.
!PYTHONPATH=$(pwd) python3 sample/contextual_generate_cli.py -model_config_fn lm/configs/mega.json -model_ckpt models/mega/model.ckpt.800000
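# An illustrative, non-interactive alternative (uncomment to run): pipe a single-line prompt
# followed by a blank line; the blank line makes the input() loop in the script above exit.
# !printf "Recycling is good for the world.\n\n" | PYTHONPATH=$(pwd) python3 sample/contextual_generate_cli.py -model_config_fn lm/configs/mega.json -model_ckpt models/mega/model.ckpt.800000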
# Newlines are not supported well by the Google Colab input interface, so please use echo or printf to pipe your input in if you want to use new lines.
# Just modify the text below to be what you want
!printf "Recycling is good for the world.\n\nNO! YOU COULD NOT BE MORE WRONG!!\n\n" | PYTHONPATH=$(pwd) python3 sample/contextual_generate_cli_multiline.py -model_config_fn lm/configs/mega.json -samples 10 -model_ckpt models/mega/model.ckpt.800000