I have created a word-level text generator using an LSTM model. In my case, however, not every word is a valid choice: the selected words must meet two additional conditions:
- Each word has a map: each character is written as 1 if it is a vowel and as 0 if it is not (for instance, overflow would be `10100010`). The generated sentence then needs to match a given structure, for instance `01001100` (hi `01` and friend `001100`).
- The last vowel of the last word must be the one provided. Let's say it is e (friend will do the job, then).
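The two conditions above can be sketched in Python as follows. This is a minimal sketch, assuming that only a, e, i, o and u count as vowels; the helper names `word_map` and `last_vowel` are illustrative and simply mirror the column names used in the dataframe.

```python
from typing import Optional

# Assumption: only these five letters count as vowels.
VOWELS = set("aeiou")


def word_map(word: str) -> str:
    """Return the vowel map: '1' for a vowel, '0' for any other character."""
    return "".join("1" if ch in VOWELS else "0" for ch in word.lower())


def last_vowel(word: str) -> Optional[str]:
    """Return the last vowel of the word, or None if it contains no vowel."""
    for ch in reversed(word.lower()):
        if ch in VOWELS:
            return ch
    return None


print(word_map("overflow"))                  # 10100010
print(word_map("hi") + word_map("friend"))   # 01001100
print(last_vowel("friend"))                  # e
```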
Thus, to handle this scenario, I've created a pandas dataframe with the following structure:
word   last_vowel  word_map
-----  ----------  --------
hello  o           01001
stack  a           00100
jhon   o           0010
This is my current workflow:
1. Given the sentence structure, I choose a random word from the dataframe whose map matches the start of the pattern. For instance, if the sentence structure is `0100100100100`, we can choose the word hello, as its vowel structure is `01001`.
2. I subtract the selected word from the remaining structure: `0100100100100` becomes `00100100`, as we've removed the initial `01001` (hello).
3. I retrieve all the words from the dataframe whose map matches the start of the remaining structure, in this case stack (`00100`) and jhon (`0010`).
4. I pass the current sentence content (just hello for now) to the LSTM model, and it returns a weight for each word.
5. I don't just want to select the best option overall; I want the best option among the words retrieved in point 3. So I choose the word with the highest weight within that list, in this case stack.
6. Repeat from point 2 until the remaining sentence structure is empty.
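The workflow above can be sketched roughly as below. This is only an illustration under stated assumptions: `score_next_words` is a hypothetical stand-in for the LSTM call (it returns fixed weights), the vocabulary is a plain dict holding the three words from the table instead of a pandas dataframe, and "matches" is taken to mean that the word map is a prefix of the remaining structure.

```python
import random
from typing import Dict, List, Optional

# Toy vocabulary copied from the table in the post: word -> vowel map.
WORD_MAPS: Dict[str, str] = {"hello": "01001", "stack": "00100", "jhon": "0010"}


def candidates(remaining: str) -> List[str]:
    """Words whose map is a prefix of the remaining structure."""
    return [w for w, m in WORD_MAPS.items() if remaining.startswith(m)]


def score_next_words(sentence: List[str]) -> Dict[str, float]:
    """Hypothetical stand-in for the LSTM: one weight per vocabulary word."""
    return {"hello": 0.2, "stack": 0.5, "jhon": 0.3}


def generate(structure: str, rng: random.Random) -> Optional[List[str]]:
    """Greedy generation; returns None if no word fits at some step."""
    options = candidates(structure)
    if not options:
        return None
    sentence = [rng.choice(options)]                      # point 1: random first word
    remaining = structure[len(WORD_MAPS[sentence[0]]):]   # point 2: subtract its map
    while remaining:                                      # point 6: loop until empty
        options = candidates(remaining)                   # point 3: words that still fit
        if not options:
            return None                                   # dead end, nothing fits
        weights = score_next_words(sentence)              # point 4: model weights
        best = max(options, key=lambda w: weights[w])     # point 5: best allowed word
        sentence.append(best)
        remaining = remaining[len(WORD_MAPS[best]):]      # point 2 again
    return sentence


print(generate("0100100100", random.Random(0)))     # ['hello', 'stack']
print(generate("0100100100100", random.Random(0)))  # None
```

With only the three words from the table, the structure `0100100100100` runs into a dead end after hello and stack (the leftover `100` matches no word), which is why the sketch returns `None` in that case rather than a sentence.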
That works like a charm, but there is one remaining condition to handle: the last vowel of the sentence.
My way to deal with this issue is the following:
- Generate 1000 sentences, forcing the last vowel to be the one specified.
- Compute the RMSE of the weights returned by the LSTM model for each sentence. The better the output, the higher the weights will be.
- Select the sentence with the highest rank.
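The ranking step can be sketched as below. Two assumptions are made here: "RMSE of the weights" is read as the root mean square of the weights of the words actually chosen in a sentence, and the last-vowel condition is applied as a filter on already generated candidates, whereas the post forces it during generation. The candidate sentences and their weights are made-up toy values, not model output.

```python
import math
from typing import Dict, List, Optional, Tuple

# Last vowels copied from the table in the post.
LAST_VOWEL: Dict[str, str] = {"hello": "o", "stack": "a", "jhon": "o"}

Candidate = Tuple[List[str], List[float]]  # (sentence, weight of each chosen word)


def root_mean_square(weights: List[float]) -> float:
    """Root mean square of the per-word weights (assumed meaning of 'RMSE')."""
    return math.sqrt(sum(w * w for w in weights) / len(weights))


def pick_best(cands: List[Candidate], vowel: str) -> Optional[List[str]]:
    """Keep sentences whose last word ends on `vowel`, return the top-scoring one."""
    valid = [c for c in cands if LAST_VOWEL[c[0][-1]] == vowel]
    if not valid:
        return None
    best_sentence, _ = max(valid, key=lambda c: root_mean_square(c[1]))
    return best_sentence


# Toy candidates with hypothetical weights.
cands: List[Candidate] = [
    (["stack", "jhon"], [0.5, 0.3]),
    (["stack", "hello"], [0.5, 0.2]),
    (["hello", "stack"], [0.2, 0.5]),
]

print(pick_best(cands, "o"))  # ['stack', 'jhon']
```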
Do you think there is a better approach? Maybe a GAN or reinforcement learning?
EDIT: I think another approach would be adding a WFST (weighted finite-state transducer). I've heard about the pynini library, but I don't know how to apply it to my specific context.