# models/feedback.py
import torch
import torch.nn as nn

def __init__(self, args):
    super(FeedbackTransformer, self).__init__(args)
    # One mixing weight per layer output, plus one for the input embeddings.
    merged_layer_count = args.nlayers + 1
    # Learnable logits weighting each layer when their outputs are merged
    # into a single shared memory representation.
    self.single_memory_attn_params = nn.Parameter(
        torch.zeros(1, merged_layer_count)
    )
    # Same-shaped buffer: stored in the state dict and moved across devices
    # with the module, but never updated by the optimizer.
    self.register_buffer(
        "single_memory_attn_buf", torch.zeros(1, merged_layer_count)
    )
    if self.args.share_proj_kv:
        # Tie every layer's key/value projections to those of layer 0, so
        # all layers attend over the shared memory with identical projections.
        first = self.get_layer(0)
        for l in range(1, len(self.layers)):
            layer = self.get_layer(l)
            layer.attn.proj_key.weight = first.attn.proj_key.weight
            layer.attn.proj_val.weight = first.attn.proj_val.weight
            if self.args.pre_norm:
                # Make sure keys and values are normalized in the same way
                # by tying the pre-attention LayerNorm as well.
                layer.norm1.weight = first.norm1.weight
                layer.norm1.bias = first.norm1.bias
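
Note that assigning `first.attn.proj_key.weight` to another module ties the tensors: both modules then hold the same `nn.Parameter` object, so a single set of weights is trained and gradients from every layer accumulate into it.

For context, here is a minimal sketch of how mixing logits like `single_memory_attn_params` are typically used in a Feedback Transformer: the shared memory at each position is a softmax-weighted sum over all layer outputs, including the input embedding (hence `nlayers + 1` weights). The function name `merge_layers` and the tensor shapes are illustrative assumptions, not taken from this file.

# Sketch only, not the original implementation. Assumes `layer_outputs`
# stacks the input embedding plus every layer's output for one position:
# shape (merged_layer_count, batch, hidden).
def merge_layers(layer_outputs, single_memory_attn_params):
    # Softmax over the layer axis turns the logits into mixing weights.
    weights = torch.softmax(single_memory_attn_params, dim=-1)  # (1, L + 1)
    # Weighted sum over layers gives one memory vector per batch element.
    return torch.einsum("l,lbh->bh", weights.squeeze(0), layer_outputs)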