lib/mlp.py (26 lines of code) (raw):
import torch as th
from torch import nn
from lib import misc
from lib import torch_util as tu
class MLP(nn.Module):
def __init__(self, insize, nhidlayer, outsize, hidsize, hidactiv, dtype=th.float32):
super().__init__()
self.insize = insize
self.nhidlayer = nhidlayer
self.outsize = outsize
in_sizes = [insize] + [hidsize] * nhidlayer
out_sizes = [hidsize] * nhidlayer + [outsize]
self.layers = nn.ModuleList(
[tu.NormedLinear(insize, outsize, dtype=dtype) for (insize, outsize) in misc.safezip(in_sizes, out_sizes)]
)
self.hidactiv = hidactiv
def forward(self, x):
*hidlayers, finallayer = self.layers
for layer in hidlayers:
x = layer(x)
x = self.hidactiv(x)
x = finallayer(x)
return x
@property
def output_shape(self):
return (self.outsize,)