# Excerpt: forward_with_model() from train.py


def forward_with_model(model, inputs, labels, weight_dtype=torch.float16):
    """Run one forward pass of ``model`` on a batch of text + image inputs.

    Args:
        model: Callable accepting ``input_ids``, ``pixel_values`` and
            ``labels`` keyword arguments.
        inputs: Batch object exposing ``.input_ids`` and ``.pixel_values``
            attributes (presumably a processor/tokenizer batch — confirm
            against caller).
        labels: Target labels, forwarded to ``model`` unchanged.
        weight_dtype: dtype that ``pixel_values`` is converted to via
            ``.to(weight_dtype)`` before the call; defaults to
            ``torch.float16``.

    Returns:
        Whatever ``model(...)`` returns.
    """
    # Only the image tensor is cast; input_ids are passed through as-is.
    pixel_values = inputs.pixel_values.to(weight_dtype)
    return model(
        input_ids=inputs.input_ids,
        pixel_values=pixel_values,
        labels=labels,
    )