
I trained a chatbot model with gpt-2-simple, but I am unable to save it. Being able to download the trained model from Colab is important to me, because otherwise I have to re-download the 355M base model every time (see the code below).

I have tried various methods to save the trained model (such as gpt2.saveload.save_gpt2()), but none of them worked, and I have run out of ideas.

My training code:

%tensorflow_version 2.x
!pip install gpt-2-simple

import gpt_2_simple as gpt2
import json

gpt2.download_gpt2(model_name="355M")

raw_data = '/content/drive/My Drive/data.json'

with open(raw_data, 'r') as f:
    df = json.load(f)

data = []

# Build alternating [YOU]/[BOT] lines from consecutive dialog turns
for x in df:
    for y in range(len(x['dialog'])-1):
        q = '[YOU] : ' + x['dialog'][y]['text']
        a = '[BOT] : ' + x['dialog'][y+1]['text']
        data.append(q)
        data.append(a)

with open('chatbot.txt', 'w') as f:
    for line in data:
        try:
            f.write(line)
            f.write('\n')
        except:
            pass

file_name = "/content/chatbot.txt"

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=100,
              save_every=100
              )

while True:
    ques = input("Question : ")
    inp = '[YOU] : ' + ques + '\n' + '[BOT] :'
    x = gpt2.generate(sess,
                      length=20,
                      temperature=0.6,
                      include_prefix=False,
                      prefix=inp,
                      nsamples=1,
                      )

1 Answer

The gpt-2-simple repository's README.md links to an example Colab notebook, which states the following:

Other optional-but-helpful parameters for gpt2.finetune:

  • restore_from: Set to fresh to start training from the base GPT-2, or set to latest to restart training from an existing checkpoint.
  • ...
  • run_name: subfolder within checkpoint to save the model. This is useful if you want to work with multiple models (will also need to specify run_name when loading the model)
  • overwrite: Set to True if you want to continue finetuning an existing model (w/ restore_from='latest') without creating duplicate copies.

The README.md also states that model checkpoints are stored in /checkpoint/run1 by default, and that you can pass a run_name parameter to finetune and load_gpt2 if you want to store/load multiple models in the checkpoint folder.
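Since the Colab VM's local disk is wiped when the runtime resets, that checkpoint folder can also be copied to Google Drive so a finetuned model survives between sessions. A minimal sketch, assuming the Google Drive helpers (mount_gdrive, copy_checkpoint_to_gdrive, copy_checkpoint_from_gdrive) used in the same example notebook:

import gpt_2_simple as gpt2

gpt2.mount_gdrive()  # mounts Google Drive at /content/drive (Colab only)

# After finetuning: archive checkpoint/run1 and copy it into Drive
gpt2.copy_checkpoint_to_gdrive(run_name='run1')

# In a later session: restore checkpoint/run1 from Drive, then load it
gpt2.copy_checkpoint_from_gdrive(run_name='run1')
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')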

Putting this all together, you should be able to do the following to work from saved models instead of re-downloading each time:

import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()

# To load existing model in default checkpoint dir from "run1"
gpt2.load_gpt2(sess)

# Or, to finetune existing model in default checkpoint dir from "run1"
gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='latest',
              run_name='run1',
              overwrite=True,
              print_every=10,
              sample_every=100,
              save_every=500
)

See the source code for the load_gpt2() and finetune() functions for more specifics.
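Once the checkpoint has been loaded, you can generate against it directly, without the finetune step. Note that gpt2.generate() prints its samples rather than returning them unless return_as_list=True is passed, so the x = gpt2.generate(...) assignment in your loop will just be None. A rough sketch of the interactive loop against the saved model (generation parameters copied from your code and purely illustrative):

import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')  # load the saved checkpoint once

while True:
    ques = input("Question : ")
    prefix = '[YOU] : ' + ques + '\n' + '[BOT] : '
    # return_as_list=True makes generate() return the samples instead of printing them
    answer = gpt2.generate(sess,
                           run_name='run1',
                           prefix=prefix,
                           length=20,
                           temperature=0.6,
                           include_prefix=False,
                           nsamples=1,
                           return_as_list=True)[0]
    print(answer)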
