modules/SwissArmyTransformer/sat/model/finetune/mlp_head.py
# -*- encoding: utf-8 -*-
'''
@File : mlp_head.py
@Time : 2021/12/12 20:44:09
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
from sat.model.base_model import BaseModel, BaseMixin, non_conflict


class MLPHeadMixin(BaseMixin):
    """Mixin that replaces the model's final_forward with a small MLP head."""
    def __init__(self, hidden_size, *output_sizes, bias=True,
                 activation_func=torch.nn.functional.relu,
                 init_mean=0, init_std=0.005):
        super().__init__()
        self.activation_func = activation_func
        # Build a chain of Linear layers:
        # hidden_size -> output_sizes[0] -> ... -> output_sizes[-1]
        last_size = hidden_size
        self.layers = torch.nn.ModuleList()
        for sz in output_sizes:
            this_layer = torch.nn.Linear(last_size, sz, bias=bias)
            last_size = sz
            # Weights are drawn from N(init_mean, init_std); biases keep
            # torch's default initialization.
            torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std)
            self.layers.append(this_layer)

    def final_forward(self, logits, **kw_args):
        # Hook applied to the transformer's final hidden states: run the MLP,
        # inserting the activation between consecutive layers (not before the
        # first layer and not after the last one).
        for i, layer in enumerate(self.layers):
            if i > 0:
                logits = self.activation_func(logits)
            logits = layer(logits)
        return logits
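

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how this mixin is typically attached to a sat model
# for classification finetuning: BaseModel.add_mixin registers it so that
# its final_forward hook runs on the output hidden states. The subclass
# name, the mixin key 'classification_head', and the 2048 -> 1 head sizes
# are illustrative assumptions, not values taken from this module; in
# practice one would usually subclass a concrete model such as BertModel.
class ClassificationModel(BaseModel):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(args, transformer=transformer, **kwargs)
        # Head shape: hidden_size -> 2048 -> 1, with ReLU between the
        # two Linear layers.
        self.add_mixin('classification_head',
                       MLPHeadMixin(args.hidden_size, 2048, 1))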