in src/models.py [0:0]
def __init__(self, D, W, skip, n_in, n_out, config, net_idx):
    """Build an MLP of ``D`` linear layers of width ``W`` with optional
    skip (input re-injection) connections described by ``skip``.

    Args:
        D: total number of linear layers.
        W: hidden width of each layer.
        skip: skip-connection spec. Falsy/"" means no skips. "auto" (or
            "autoN", where N overrides the assumed depth of 7) derives the
            layout from the positional-encoding config. Otherwise a
            '-'-separated list of segments:
              "L"     -> feed the full input into layer L
              "L:F"   -> feed the single feature F into layer L
              "L:A:B" -> feed features [A, B) into layer L; A and B are
                         optional and default to 0 and n_in.
        n_in: number of input features.
        n_out: number of output features.
        config: experiment config; ``posEnc``, ``posEncArgs`` and
            ``inFeatures`` (indexed by ``net_idx``) are consulted for the
            "auto" skip derivation.
        net_idx: index of this network within the experiment.

    Raises:
        Exception: if a skip segment cannot be parsed.
    """
    super(BaseNet, self).__init__()
    self.net_idx = net_idx
    if "auto" in skip:
        # "autoN" overrides the assumed canonical depth (default 7 of 8 layers).
        skip_layer = int(skip[4:]) if len(skip) > 4 else 7
        # BUGFIX: skip_layer was previously read below even when skip was not
        # an "auto" spec, raising UnboundLocalError; the posEnc-derived
        # override now only applies to "auto" specs, matching the fallback
        # warning path which also re-checked for "auto".
        if (config.posEnc and config.posEnc[net_idx]
                and "RayMarch" in config.inFeatures[net_idx]
                and config.posEnc[net_idx] == "nerf"):
            freq = config.posEncArgs[net_idx].split("-")
            posInputs = (int(freq[0]))*6 + 3
            dirInputs = (int(freq[1]))*6 + 2  # NOTE(review): currently unused
            # NOTE: this assumes 8 layers I guess
            skip = f"0::{posInputs}-{D*skip_layer//8}:{posInputs}:"
        else:
            # Could not derive the layout -> fall back to no skip connections.
            print("Warning auto skip setup not detectable, using no skip connections")
            skip = ""
    self.name = f"relu{self.net_idx}({W}x{D}{skip.replace(':','.') if skip else ''})"
    self.D = D
    self.W = W
    # Maps layer index -> (start, end) half-open slice of the raw input that
    # is (re-)injected at that layer. Layer 0 always has an entry.
    self.inputLocations = {0: (0, n_in)}
    if skip:
        self.inputLocations = dict()
        for seg in skip.split('-'):
            match = re.search('^([0-9]+)(:?)([0-9]*)(:?)([0-9]*)$', seg)
            if not match:
                raise Exception("could not decode skip info")
            loc = int(match.group(1))
            has_first = match.group(2)
            start_feat = match.group(3)
            has_inbetween = match.group(4)
            end_feat = match.group(5)
            if has_first == '' and has_inbetween == '':
                # "L": the full input.
                self.inputLocations[loc] = (0, n_in)
            elif has_first == ':' and has_inbetween == '':
                # "L:F": a single feature. Exactly one of the two digit
                # groups is non-empty, so concatenating them yields F.
                single = int(start_feat + end_feat)
                self.inputLocations[loc] = (single, single + 1)
            else:
                # "L:A:B": a feature range with optional bounds.
                istart = int(start_feat) if start_feat != '' else 0
                iend = int(end_feat) if end_feat != '' else n_in
                self.inputLocations[loc] = (istart, iend)
        # Layer 0 must always receive an input slice; default to everything.
        if 0 not in self.inputLocations:
            self.inputLocations[0] = (0, n_in)
    self.n_in = n_in
    self.n_out = n_out
    first_in = self.inputLocations[0][1] - self.inputLocations[0][0]
    # BUGFIX: when D == 1 the single layer must already map to n_out
    # (previously it produced W features and n_out was ignored).
    layers = [nn.Linear(first_in, self.W if self.D > 1 else self.n_out)]
    for i in range(1, self.D):
        # Widen a layer's fan-in by the size of any input slice re-injected there.
        extra = (self.inputLocations[i][1] - self.inputLocations[i][0]
                 if i in self.inputLocations else 0)
        out_features = self.W if i != self.D - 1 else self.n_out
        layers.append(nn.Linear(extra + self.W, out_features))
    self.layers = nn.ModuleList(layers)
    self.activation = F.relu
    for layer in self.layers:
        nn.init.kaiming_normal_(layer.weight)
    # Project-specific (re-)initialization defined elsewhere on the class;
    # presumably may overwrite the kaiming init above — TODO confirm.
    self.init_weights()