in distilvit/train.py [0:0]
from collections.abc import Mapping

import numpy as np
import torch

# MAX_LENGTH is expected to be defined elsewhere in train.py.


def data_collator(tokenizer, features):
    """Collate a list of feature dicts into a batch of stacked tensors,
    padding or truncating label sequences to MAX_LENGTH."""
# XXX change this so it also works with flickr's labels_0, labels_1 etc
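    # Features may arrive as dataclass-like objects rather than dicts;
    # normalize them with vars() so the rest of the collator can use key access.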
if not isinstance(features[0], Mapping):
features = [vars(f) for f in features]
first = features[0]
batch = {}
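    # Scalar labels (classification-style datasets): one value per example.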
if "label" in first and first["label"] is not None:
label = (
first["label"].item()
if isinstance(first["label"], torch.Tensor)
else first["label"]
)
dtype = torch.long if isinstance(label, int) else torch.float
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], torch.Tensor):
batch["labels"] = torch.stack([f["label_ids"] for f in features])
else:
dtype = (
torch.long if isinstance(first["label_ids"][0], int) else torch.float
)
batch["labels"] = torch.tensor(
[f["label_ids"] for f in features], dtype=dtype
)
    # Handle all other keys. As above, the first element is used to decide
    # which keys/values are not None for this model.
for k, v in first.items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
else:
                # Plain Python lists: label sequences must be padded or
                # truncated to MAX_LENGTH before they can become a tensor.
if k == "labels":
truncated_features = []
for f in features:
item = f[k]
if len(item) != MAX_LENGTH:
print(
f"Found item of size {len(item)}), truncating or padding"
)
if len(item) > MAX_LENGTH:
item = item[:MAX_LENGTH]
else:
item = item + [tokenizer.pad_token_id] * (
MAX_LENGTH - len(item)
)
assert len(item) == MAX_LENGTH
truncated_features.append(item)
batch[k] = torch.tensor(truncated_features)
else:
batch[k] = torch.tensor([f[k] for f in features])
return batch
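A minimal usage sketch (not from train.py): the collator takes the tokenizer as
its first positional argument, so it has to be bound before it can be used as a
standard collate function, e.g. as the data_collator argument of a transformers
Trainer. The functools.partial call, the dummy feature dicts, and the assumption
that tokenizer is an already-loaded tokenizer with a pad token are all
illustrative, not taken from this file.

from functools import partial

import torch

# Bind the tokenizer so the result has the usual collate_fn(features) signature.
collate_fn = partial(data_collator, tokenizer)

# Two fake image-captioning examples; the keys and shapes are assumed.
features = [
    {"pixel_values": torch.zeros(3, 224, 224), "labels": [101, 2023, 102]},
    {"pixel_values": torch.zeros(3, 224, 224), "labels": [101, 102]},
]
batch = collate_fn(features)
# batch["pixel_values"] has shape (2, 3, 224, 224); batch["labels"] has shape
# (2, MAX_LENGTH), with the short sequence padded using tokenizer.pad_token_id.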