models/vision_transformer.py
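# This method relies on module-level imports from the top of the file:
# logging, math, and torch (torch.nn.functional is used for the interpolation).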
def load_state_dict(self, state, strict=True):
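    """Load ``state`` into the model, resizing the position embeddings if needed.

    When the checkpoint was trained at a different image resolution, its position
    embeddings cover a different patch grid, so they are bicubically interpolated
    to this model's grid. For example, a 224x224 checkpoint with patch_size=16 has
    a 14x14 grid (196 patches, 197 tokens with the class token), while a 384x384
    model needs a 24x24 grid.
    """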
    # shape of pos_embedding is (seq_length, 1, hidden_dim)
    pos_embedding = state["encoder.pos_embedding"]
    seq_length, n, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(
            f"Unexpected position embedding shape: {pos_embedding.shape}"
        )
    if hidden_dim != self.hidden_dim:
        raise ValueError(
            f"Position embedding hidden_dim incorrect: {hidden_dim}"
            f", expected: {self.hidden_dim}"
        )
    new_seq_length = self.seq_length
    if new_seq_length != seq_length:
        # Need to interpolate the weights for the position embedding.
        # We do this by reshaping the position embeddings to a 2d grid, performing
        # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
        if self.classifier == "token":
            # The class token embedding shouldn't be interpolated, so we split it off.
            seq_length -= 1
            new_seq_length -= 1
            pos_embedding_token = pos_embedding[:1, :, :]
            pos_embedding_img = pos_embedding[1:, :, :]
        else:
            pos_embedding_token = pos_embedding[:0, :, :]  # empty data
            pos_embedding_img = pos_embedding
        # (seq_length, 1, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(1, 2, 0)
        seq_length_1d = int(math.sqrt(seq_length))
        assert (
            seq_length_1d * seq_length_1d == seq_length
        ), "seq_length is not a perfect square"
        logging.info(
            "Interpolating the position embeddings from image size "
            f"{seq_length_1d * self.patch_size} to {self.image_size}"
        )
        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(
            1, hidden_dim, seq_length_1d, seq_length_1d
        )
        new_seq_length_1d = self.image_size // self.patch_size
        # Use bicubic interpolation - it gives significantly better results in
        # the test `test_resolution_change`.
        new_pos_embedding_img = torch.nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode="bicubic",
            align_corners=True,
        )
        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_l)
        new_pos_embedding_img = new_pos_embedding_img.reshape(
            1, hidden_dim, new_seq_length
        )
        # (1, hidden_dim, new_seq_length) -> (new_seq_length, 1, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(2, 0, 1)
        new_pos_embedding = torch.cat(
            [pos_embedding_token, new_pos_embedding_img], dim=0
        )
        state["encoder.pos_embedding"] = new_pos_embedding
    # Defer to the regular nn.Module loading, whether or not the embeddings were resized.
    super().load_state_dict(state, strict=strict)
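
Below is a minimal standalone sketch of the same resizing sequence on a dummy tensor, assuming a ViT-B/16-style hidden_dim of 768, a checkpoint resolution of 224 and a target resolution of 384 with patch_size 16 (all illustrative values); it mirrors the split / permute / reshape / interpolate / concat steps above.

import torch

# Checkpoint embedding: 224 // 16 = 14, so a 14x14 grid plus the class token.
pos = torch.randn(14 * 14 + 1, 1, 768)
tok, img = pos[:1], pos[1:]                        # split off the class token
img = img.permute(1, 2, 0).reshape(1, 768, 14, 14)
# Target model: 384 // 16 = 24, so resize the patch grid to 24x24.
img = torch.nn.functional.interpolate(img, size=24, mode="bicubic", align_corners=True)
img = img.reshape(1, 768, 24 * 24).permute(2, 0, 1)
new_pos = torch.cat([tok, img], dim=0)
print(new_pos.shape)                               # torch.Size([577, 1, 768])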