in tasks/custom_tasks.py [0:0]
def prepend_prompt(dataset: tf.data.Dataset,
                   output_features: seqio.preprocessors.OutputFeaturesType,
                   sequence_length: Optional[
                       seqio.preprocessors.SequenceLengthType] = None,
                   prompt_mode: str = "",
                   key: str = "inputs",
                   mode: str = "") -> tf.data.Dataset:
  """Prepends an encoded text prompt to the `key` feature of every example.

  The prompt is only added when both `prompt_mode` and `mode` are non-empty;
  otherwise `dataset` is returned unchanged.

  Args:
    dataset: examples to transform; each element is indexed by `key` and
      has that entry reassigned.
    output_features: feature specs; `output_features[key].vocabulary` is used
      to encode the prompt text.
    sequence_length: unused; accepted so the function matches the seqio
      preprocessor signature.
    prompt_mode: prompt text to encode and prepend. Empty disables prompting.
    key: name of the feature that receives the prompt.
    mode: gating flag; empty disables prompting. Its value is not otherwise
      used here.

  Returns:
    The dataset with the encoded prompt concatenated in front of `key`
    (along axis 0), or the input dataset if prompting is disabled.
  """
  del sequence_length
  if not (prompt_mode and mode):
    return dataset

  logging.info("Add prompt")
  # Encode once, outside the per-element map function.
  prompt_tokens = output_features[key].vocabulary.encode_tf(prompt_mode)
  logging.info(prompt_tokens)
  logging.info(dataset)

  def add_to_inputs(x):
    # NOTE(review): assumes x[key] already holds token ids with the same rank
    # as prompt_tokens (i.e. this runs after tokenization) -- TODO confirm.
    tokens = x[key]
    # tf.concat requires identical dtypes; the vocabulary's encode_tf output
    # may differ from the stored feature dtype (e.g. int32 vs int64).
    x[key] = tf.concat([tf.cast(prompt_tokens, tokens.dtype), tokens], axis=0)
    return x

  return dataset.map(add_to_inputs)