in low_rank_comparisons/src/model.py [0:0]
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, len_past=None):
    if past is None:
        past_length = 0
        past = [None] * len(self.h)
    elif len_past is None:
        # All cached layers share the same sequence length; read it from layer 0's cache.
        past_length = past[0][0].size(-2)

    if position_ids is None and len_past is None:
        position_ids = torch.arange(
            past_length, input_ids.size(-1) + past_length,
            dtype=torch.long, device=input_ids.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    elif len_past is not None:
        # Incremental decoding: each sequence's current length is used directly as its position id.
        position_ids = len_past.unsqueeze(1)

    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_ids.size(-1))
    position_ids = position_ids.view(-1, position_ids.size(-1))

    ##### parse inputs_embeds: split token ids into context / prefix / infix ranges
    if self.prefix_len > 0 or self.infix_len > 0:
        _context_token_msk = (input_ids < self.config.prefix_cursor) * 1
        _prefix_token_msk = (input_ids >= self.config.prefix_cursor) * (input_ids < self.config.infix_cursor) * 1
        _infix_token_msk = (input_ids >= self.config.infix_cursor) * 1

        if self.prefix_len > 0:
            # Map prefix ids back to [0, prefix_len) and embed them with the prefix adapter table.
            _prefix_tokens = (input_ids - self.config.prefix_cursor) * _prefix_token_msk
            _prefix_embeds = self.adapter_pe(_prefix_tokens)

        if self.infix_len > 0:
            # Map infix ids back to [0, infix_len) and embed them with the infix adapter table.
            _infix_tokens = (input_ids - self.config.infix_cursor) * _infix_token_msk
            _infix_embeds = self.adapter_ie(_infix_tokens)

        # Clamp so prefix/infix ids stay inside the word-embedding table.
        input_ids = input_ids.clamp(max=self.n_vocab - 1)

    inputs_embeds = self.wte(input_ids)

    if self.prefix_len > 0 or self.infix_len > 0:
        # Zero out the (clamped) word embeddings at prefix/infix positions,
        # then add the adapter embeddings in their place.
        inputs_embeds = inputs_embeds * _context_token_msk.unsqueeze(-1)

        if self.prefix_len > 0:
            inputs_embeds = inputs_embeds + _prefix_embeds * _prefix_token_msk.unsqueeze(-1)

        if self.infix_len > 0:
            inputs_embeds = inputs_embeds + _infix_embeds * _infix_token_msk.unsqueeze(-1)

    position_embeds = self.wpe(position_ids)

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        token_type_embeds = self.wte(token_type_ids)
    else:
        token_type_embeds = 0

    hidden_states = inputs_embeds + position_embeds + token_type_embeds

    presents = []
    for block, layer_past in zip(self.h, past):
        hidden_states, present = block(hidden_states, layer_past=layer_past, len_past=len_past)
        presents.append(present)

    hidden_states = self.ln_f(hidden_states)
    output_shape = input_shape + (hidden_states.size(-1),)
    return hidden_states.view(*output_shape), presents
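A minimal usage sketch of this forward pass (not part of the source file): `model` is assumed to be an already-constructed instance of the surrounding GPT-2-style module with `prefix_len == infix_len == 0`; the vocabulary size and tensor shapes are placeholders, and the layout of the cached `presents` is assumed to match what the blocks return.

import torch

# 1) Full pass over a prompt: `past` is None, so position ids start at 0.
prompt_ids = torch.randint(0, 50257, (1, 8))   # (batch=1, seq_len=8); hypothetical vocab size
hidden, presents = model(prompt_ids)           # hidden: (1, 8, hidden_dim)

# 2) One decoding step via the standard cache path (`len_past=None`):
#    past_length is read from the cache, so the new token gets position 8.
next_ids = torch.randint(0, 50257, (1, 1))     # one new token per sequence
hidden, presents = model(next_ids, past=presents)

When `len_past` is given instead, it is forwarded to each block and used directly as the per-sequence position ids (the `elif len_past is not None` branch above), bypassing the cache-length lookup.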