in sparseconvnet/networkArchitectures.py [0:0]
def SparseResNet(dimension, nInputPlanes, layers):
    """
    pre-activated ResNet

    Build a pre-activated sparse ResNet as an ``scn.Sequential``.

    Each entry of ``layers`` is a 4-tuple ``(blockType, n, reps, stride)``:

    * ``blockType`` -- block kind; only basic blocks (a string starting
      with ``'b'``, e.g. ``'basic'``) are supported.
    * ``n`` -- number of output planes of every block in the stage.
    * ``reps`` -- number of residual blocks in the stage.
    * ``stride`` -- stride of the stage's first block. When ``stride != 1``
      the first convolution of that block is a strided ``scn.Convolution``
      instead of a ``scn.SubmanifoldConvolution``; when ``stride > 1`` the
      shortcut is a strided ``scn.Convolution`` as well.

    e.g. ``layers = [('basic', 16, 2, 1), ('basic', 32, 2, 2)]``

    Args:
        dimension: number of spatial dimensions, passed to every convolution.
        nInputPlanes: number of feature planes of the network input.
        layers: iterable of ``(blockType, n, reps, stride)`` tuples.

    Returns:
        An ``scn.Sequential`` ending in ``scn.BatchNormReLU`` over the
        planes of the last stage (``nInputPlanes`` if no block is built).

    Raises:
        ValueError: if a stage with ``reps > 0`` uses an unsupported
            ``blockType``.
    """
    nPlanes = nInputPlanes
    m = scn.Sequential()

    def residual(nIn, nOut, stride):
        # Shortcut of a stage's first block: it must match the main branch
        # in stride (strided conv) or in plane count (1x1 NetworkInNetwork).
        if stride > 1:
            return scn.Convolution(dimension, nIn, nOut, 3, stride, False)
        elif nIn != nOut:
            return scn.NetworkInNetwork(nIn, nOut, False)
        else:
            return scn.Identity()

    def main_branch(nIn, nOut, stride, preactivate):
        # conv3 -> BatchNormReLU -> submanifold conv3. With preactivate=True
        # the leading BatchNormReLU sits inside the branch, so the parallel
        # shortcut sees the un-activated input; otherwise the caller has
        # already applied it in front of the whole block.
        branch = scn.Sequential()
        if preactivate:
            branch.add(scn.BatchNormReLU(nIn))
        if stride == 1:
            first = scn.SubmanifoldConvolution(dimension, nIn, nOut, 3, False)
        else:
            first = scn.Convolution(dimension, nIn, nOut, 3, stride, False)
        branch.add(first)
        branch.add(scn.BatchNormReLU(nOut))
        branch.add(
            scn.SubmanifoldConvolution(dimension, nOut, nOut, 3, False))
        return branch

    for blockType, n, reps, stride in layers:
        for rep in range(reps):
            # Fail early: an unknown block type would otherwise append an
            # AddTable with no ConcatTable in front of it.
            if not blockType or blockType[0] != 'b':
                raise ValueError(
                    'SparseResNet: unsupported blockType %r '
                    '(only basic blocks are implemented)' % (blockType,))
            block = scn.ConcatTable()
            # The main branch is built before the shortcut, as in the original
            # code, so any random weight initialisation keeps the same order.
            if rep == 0:
                # First block of the stage: the pre-activation is shared, so
                # both the (possibly strided) main branch and the shortcut
                # receive the activated input.
                m.add(scn.BatchNormReLU(nPlanes))
                block.add(main_branch(nPlanes, n, stride, preactivate=False))
                block.add(residual(nPlanes, n, stride))
            else:
                # Later blocks: unstrided, pre-activation inside the main
                # branch, identity shortcut.
                block.add(main_branch(nPlanes, n, 1, preactivate=True))
                block.add(scn.Identity())
            m.add(block)
            nPlanes = n
            # Merge the two branches produced by the ConcatTable.
            m.add(scn.AddTable())
    m.add(scn.BatchNormReLU(nPlanes))
    return m