in fastchat/train/train_baichuan.py [0:0]
def mask_targets(conversations, targets, tokenizer, conv):
    # Mask everything except the assistant replies so the loss is computed only
    # on the model's responses. Relies on the module-level IGNORE_TOKEN_ID,
    # rank0_print, and torch used elsewhere in train_baichuan.py.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Each completed turn ends with conv.sep2.
        turns = conversation.split(conv.sep2)
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn + conv.sep2).input_ids)

            # Split the turn into the user instruction and the assistant reply.
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # The "-1" offsets the extra special token the tokenizer prepends
            # when encoding the instruction prefix on its own.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            # Ignore the user instruction; keep the reply as the training label.
            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        # Mask everything after the last processed turn (padding or truncation).
        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # If the recomputed length disagrees with the unpadded length,
                # the masking offsets are unreliable, so drop the whole example
                # from the loss.
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return targets
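
# Below is a minimal, hypothetical usage sketch (the helper name and the
# "vicuna_v1.1" template choice are assumptions, not part of train_baichuan.py):
# it shows one way `targets` could be built from the rendered conversations and
# then masked with mask_targets.
def _mask_targets_usage_example(conversations, tokenizer):
    from fastchat.conversation import get_conv_template

    # Assumes the tokenizer has a pad token and model_max_length configured.
    conv = get_conv_template("vicuna_v1.1")
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    # Labels start as a copy of input_ids; mask_targets blanks the instructions.
    return mask_targets(conversations, input_ids.clone(), tokenizer, conv)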