in crypten/nn/module.py [0:0]
def forward(self, x):
    """Apply a 1-D or 2-D convolution described by the module's settings.

    Args:
        x: A sequence holding either ``(input, weight)`` or
            ``(input, weight, bias)``. If ``input`` is a plain ``torch``
            tensor, ``torch.nn.functional.conv{1,2}d`` is used and applies the
            bias itself. Otherwise ``input.conv{1,2}d(weight, ...)`` is called
            (presumably an encrypted tensor type -- confirm against callers),
            and the bias is added to its result afterwards.

    Returns:
        The output of the selected convolution function, with bias applied if
        one was given.

    Raises:
        ValueError: If ``x`` does not hold 2 or 3 items, or if ``weight`` does
            not describe a 1-D or 2-D convolution.
        TypeError: If ``input`` provides no ``conv{dim}d`` method.
    """
    # unpack inputs:
    if len(x) == 2:
        x, weight = x
        bias = None
    elif len(x) == 3:
        x, weight, bias = x
    else:
        raise ValueError(f"Conv module must have 2 or 3 inputs, not {len(x)}")

    # number of spatial dims = weight rank minus the two channel dims
    # (torch weight layout is (out_channels, in_channels / groups, *kernel)):
    dim = weight.dim() - 2
    if dim < 1 or dim > 2:
        raise ValueError(
            f"Convolution on {dim}-dimensional input is not supported."
        )
    kwargs = {
        "stride": self.stride,
        "padding": self.padding,
        "dilation": self.dilation,
        "groups": self.groups,
    }

    # identify correct convolution function to use. Record up front whether the
    # input is a plain tensor: torch adds the bias itself, other backends don't.
    plaintext = torch.is_tensor(x)
    if plaintext:
        func = getattr(torch.nn.functional, f"conv{dim}d", None)
        # torch's functional API takes (input, weight, bias); bias may be None.
        args = [x, weight, bias]
    else:
        func = getattr(x, f"conv{dim}d", None)
        args = [weight]
    if func is None:
        raise TypeError(f"{type(x).__name__} does not support conv{dim}d")

    # perform the convolution:
    x = func(*args, **kwargs)

    # add the bias term if it is specified and wasn't already added by torch:
    if not plaintext and bias is not None:
        # reshape e.g. (C,) -> (1, C, 1[, 1]) so it broadcasts over the output
        bias = bias.unsqueeze(0)
        while bias.dim() < x.dim():
            bias = bias.unsqueeze(-1)
        x = x.add(bias)
    return x