in crypten/cuda/cuda_tensor.py [0:0]
def matmul(x, y, *args, **kwargs):
    """Compute the matrix product of ``x`` and ``y`` via block-encoded operands.

    Both operands are passed through ``CUDALongTensor.__encode_as_fp64``,
    multiplied with a single ``torch.matmul`` call, and the product is handed
    to ``CUDALongTensor.__decode_as_int64``.

    NOTE(review): judging by the helper names, the encoding splits integer
    values into blocks stored as float64 and the decoding recombines them
    into int64 -- confirm against the helpers' definitions.

    Args:
        x: Left operand. A 1-dimensional ``x`` is treated as a single row,
            and the added row dimension is dropped from the result.
        y: Right operand. A 1-dimensional ``y`` is treated as a single
            column, and the added column dimension is dropped from the result.
        *args: Extra positional arguments forwarded to ``torch.matmul``.
        **kwargs: Extra keyword arguments forwarded to ``torch.matmul``.

    Returns:
        Whatever ``CUDALongTensor.__decode_as_int64`` returns for the encoded
        product and the block count used.
    """
    # Dot products over 256 or more elements get 4 blocks instead of 3 so the
    # accumulated sum does not overflow.
    num_blocks = 4 if x.size(-1) >= 256 else 3

    # Promote 1-dimensional operands to matrices (row vector for x, column
    # vector for y) and remember to strip the extra dimension afterwards.
    x_is_vector = x.dim() == 1
    y_is_vector = y.dim() == 1
    if x_is_vector:
        x = x.view(1, x.shape[0])
    if y_is_vector:
        y = y.view(y.shape[0], 1)

    lhs = CUDALongTensor.__encode_as_fp64(x, num_blocks).data
    rhs = CUDALongTensor.__encode_as_fp64(y, num_blocks).data

    # Tile lhs as a whole along dim 0 while repeating each dim-0 slice of rhs
    # consecutively, so that every lhs slice meets every rhs slice
    # (cross multiplication). Presumably dim 0 indexes the encoded blocks --
    # verify against __encode_as_fp64.
    keep_trailing = [1] * (lhs.dim() - 1)
    lhs_span = lhs.repeat(num_blocks, *keep_trailing)
    rhs_span = torch.repeat_interleave(rhs, repeats=num_blocks, dim=0)

    # Broadcasting: pad the lower-rank operand with singleton dimensions right
    # after dim 0 until both operands have the same rank.
    rank_gap = lhs_span.ndim - rhs_span.ndim
    lower_rank = rhs_span if rank_gap > 0 else lhs_span
    for _ in range(abs(rank_gap)):
        lower_rank.unsqueeze_(1)

    product = torch.matmul(lhs_span, rhs_span, *args, **kwargs)

    # Undo the vector promotion: rows came from x, columns came from y.
    if x_is_vector:
        product.squeeze_(-2)
    if y_is_vector:
        product.squeeze_(-1)

    return CUDALongTensor.__decode_as_int64(product, num_blocks)