TensorFlow Addons provides a beam-search facility, tfa.seq2seq.BeamSearchDecoder.
From the official documentation, the class signature is:
    tfa.seq2seq.BeamSearchDecoder(
        cell: tf.keras.layers.Layer,
        beam_width: int,
        embedding_fn: Optional[Callable] = None,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        length_penalty_weight: tfa.types.FloatTensorLike = 0.0,
        coverage_penalty_weight: tfa.types.FloatTensorLike = 0.0,
        reorder_tensor_arrays: bool = True,
        **kwargs
    )
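To make the roles of beam_width and length_penalty_weight concrete, here is a minimal pure-Python sketch of beam search. It is not the library's implementation: the toy model (toy_step_log_probs), the token ids, and the helper names are all made up for illustration, and the GNMT-style length normalization is an assumption about the kind of penalty that length_penalty_weight controls.

```python
import math

END = 2  # hypothetical end-of-sequence token id

def toy_step_log_probs(prefix):
    # Hypothetical toy model: log-probabilities over a 3-token vocabulary.
    table = {
        (): [0.55, 0.40, 0.05],
        (0,): [0.40, 0.30, 0.30],
        (1,): [0.05, 0.05, 0.90],
    }
    probs = table.get(tuple(prefix), [0.05, 0.05, 0.90])
    return [math.log(p) for p in probs]

def length_penalty(length, weight):
    # Assumed GNMT-style normalization; weight 0.0 gives 1.0 (no penalty).
    return ((5.0 + length) / 6.0) ** weight

def beam_search(step_log_probs, beam_width, max_len, length_penalty_weight=0.0):
    beams = [([], 0.0)]  # (tokens, summed log-probability)
    finished = []
    for _ in range(max_len):
        # Expand every live hypothesis by every possible next token.
        candidates = [
            (tokens + [tok], score + lp)
            for tokens, score in beams
            for tok, lp in enumerate(step_log_probs(tokens))
        ]
        # Keep only the beam_width best candidates.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            (finished if tokens[-1] == END else beams).append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # include hypotheses cut off at max_len
    return max(
        finished,
        key=lambda c: c[1] / length_penalty(len(c[0]), length_penalty_weight),
    )

greedy_tokens, _ = beam_search(toy_step_log_probs, beam_width=1, max_len=3)
beam_tokens, _ = beam_search(toy_step_log_probs, beam_width=2, max_len=3)
print(greedy_tokens)  # [0, 0, 2]
print(beam_tokens)    # [1, 2]
```

With beam_width=1 the search is greedy and commits to token 0 at the first step, ending with a sequence of probability 0.198. With beam_width=2 the second-best first token stays alive long enough to produce [1, 2], whose probability is 0.36.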
And here is an example of preparing the tiled inputs and the initial state for a decoder cell with attention:
    # Repeat every batch entry beam_width times so each beam has its own copy.
    tiled_encoder_outputs = tfa.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tfa.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tfa.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)

    # The attention memory must be the tiled encoder outputs.
    # MyFavoriteAttentionMechanism is a placeholder for an attention class.
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism, ...)

    # The effective batch size seen by the decoder is true_batch_size * beam_width.
    decoder_initial_state = attention_cell.get_initial_state(
        batch_size=true_batch_size * beam_width, dtype=dtype)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
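The reason the initial state uses a batch size of true_batch_size * beam_width is that tile_batch repeats each batch entry multiplier times, keeping the copies adjacent. A rough pure-Python sketch of that ordering, using nested lists instead of tensors (tile_batch_sketch and the sample values are hypothetical, for illustration only):

```python
def tile_batch_sketch(batch, multiplier):
    # Repeat each entry `multiplier` times: t[0], t[0], ..., t[1], t[1], ...
    return [entry for entry in batch for _ in range(multiplier)]

# Hypothetical encoder final state for a batch of 2 sequences.
encoder_final_state = [[0.1, 0.2], [0.3, 0.4]]
beam_width = 3

tiled = tile_batch_sketch(encoder_final_state, beam_width)
print(len(tiled))  # 6, i.e. true_batch_size * beam_width
print(tiled[:3])   # three copies of the first entry
```

Rows 0 to 2 belong to the beams of the first sequence and rows 3 to 5 to the beams of the second, which is why the encoder outputs, final state, and sequence lengths all have to be tiled the same way.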