in timm/models/eva.py [0:0]
def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Add position information and prefix tokens to the patch tokens.

    Steps, in order:
      1. Select (or, with ``dynamic_img_size``, resample) the absolute
         position embedding and fetch the rotary embedding from ``self.rope``.
      2. Prepend the class token (if any) and add the absolute position
         embedding (if any).
      3. Insert the register tokens (if any) directly after the class token,
         or at the front when there is no class token.
      4. Apply ``self.pos_drop`` and, when configured, ``self.patch_drop``;
         the rotary embedding is filtered with the kept indices so it stays
         aligned with the surviving tokens.

    Args:
        x: Patch tokens. With ``dynamic_img_size`` the tensor is unpacked as
            ``(B, H, W, C)`` and flattened to ``(B, H*W, C)``; otherwise it is
            used as given (presumably ``(B, N, C)`` -- confirm against caller).

    Returns:
        A tuple ``(tokens, rot_pos_embed)`` where ``rot_pos_embed`` is ``None``
        when ``self.rope`` is ``None``.
    """
    if self.dynamic_img_size:
        B, H, W, C = x.shape
        if self.pos_embed is not None:
            # Resample the stored embedding from the patch-embed grid to the
            # grid of the current input.
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                new_size=(H, W),
                old_size=self.patch_embed.grid_size,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            pos_embed = None
        x = x.view(B, -1, C)
        rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
    else:
        pos_embed = self.pos_embed
        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

    if self.cls_token is not None:
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

    if pos_embed is not None:
        x = x + pos_embed

    if self.reg_token is not None:
        reg_tokens = self.reg_token.expand(x.shape[0], -1, -1)
        if self.cls_token is not None:
            # The class token is already at the front of ``x`` (with the
            # position embedding applied). Insert the register tokens right
            # after it instead of concatenating a second class token, which
            # would leave a duplicate class token among the patch tokens.
            num_cls = self.cls_token.shape[1]
            x = torch.cat((x[:, :num_cls], reg_tokens, x[:, num_cls:]), dim=1)
        else:
            x = torch.cat((reg_tokens, x), dim=1)

    x = self.pos_drop(x)

    # Apply patch dropout and keep the shared rotary position embedding
    # aligned with the tokens that survive it.
    if self.patch_drop is not None:
        x, keep_indices = self.patch_drop(x)
        if rot_pos_embed is not None and keep_indices is not None:
            rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
    return x, rot_pos_embed