modules/SwissArmyTransformer/sat/model/finetune/mlp_head.py (29 lines of code) (raw):

# -*- encoding: utf-8 -*-
'''
@File    :   mlp_head.py
@Time    :   2021/12/12 20:44:09
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random
import torch

from sat.model.base_model import BaseModel, BaseMixin, non_conflict


class MLPHeadMixin(BaseMixin):
    # Mixin that adds a small MLP classification head on top of the transformer's hidden states.
    def __init__(self, hidden_size, *output_sizes, bias=True,
                 activation_func=torch.nn.functional.relu, init_mean=0, init_std=0.005):
        super().__init__()
        self.activation_func = activation_func
        # Build a stack of Linear layers: hidden_size -> output_sizes[0] -> ... -> output_sizes[-1]
        last_size = hidden_size
        self.layers = torch.nn.ModuleList()
        for sz in output_sizes:
            this_layer = torch.nn.Linear(last_size, sz, bias=bias)
            last_size = sz
            torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std)
            self.layers.append(this_layer)

    def final_forward(self, logits, **kw_args):
        # Run the hidden states through the MLP, applying the activation between
        # consecutive layers (but not before the first layer or after the last).
        for i, layer in enumerate(self.layers):
            if i > 0:
                logits = self.activation_func(logits)
            logits = layer(logits)
        return logits
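
A minimal usage sketch (not part of the repo) showing the head on its own. It assumes the import path sat.model.finetune.mlp_head matches the file path above, that the sat package is installed, and that the hidden size and output sizes are purely illustrative:

# Illustrative sketch: build a two-layer head 768 -> 256 -> 2 and apply it
# to dummy hidden states; shapes and sizes are assumptions, not repo defaults.
import torch
from sat.model.finetune.mlp_head import MLPHeadMixin

head = MLPHeadMixin(768, 256, 2)        # Linear(768, 256) then Linear(256, 2), ReLU in between
hidden = torch.randn(4, 16, 768)        # dummy hidden states: (batch, seq_len, hidden_size)
logits = head.final_forward(hidden)     # -> shape (4, 16, 2)
print(logits.shape)

In sat finetuning code this mixin is normally attached to a BaseModel (e.g. via add_mixin) so that final_forward replaces the model's default output projection; the standalone call above is only meant to show the layer stacking and output shape.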