in fairscale/experimental/nn/mevo.py [0:0]
def forward(self, input: torch.Tensor, target: Optional[torch.Tensor]) -> torch.Tensor:  # type: ignore
    if not self.training and target is None:
        return self.eval_forward(input)
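
    # Optionally report current and peak CUDA memory (in MiB) on rank 0 of an
    # initialized process group when DEBUG is set.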
    if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
        cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
        mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
        print("DEBUG cur, peak", cur_mem, mem)
    assert isinstance(input, torch.Tensor)
    assert isinstance(target, torch.Tensor)
    if torch.is_grad_enabled():
        assert input.requires_grad
    input, target = _reshape_inputs(input, target)

    tokens, d_model = input.shape
    vocab, d2 = self.proj_weight.shape
    assert d_model == d2
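
    # Tile the (tokens, d_model) input and the (vocab, d_model) projection
    # weight along dim 0 so that only one tile of the logits matrix needs to
    # be materialized at a time.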
    split_dim = 0
    input_split_size = _next_power_of_2_or_max(tokens // self.tf_in, tokens)
    weight_split_size = _next_power_of_2_or_max(vocab // self.tf_w, vocab)
    inputs = torch.split(input, input_split_size, split_dim)
    weight = self.trigger()
    weights = torch.split(weight, weight_split_size, split_dim)
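
    # The softmax normalizer is built in two streaming passes over the weight
    # tiles: first a per-token max (for numerical stability), then a sum of
    # max-shifted exponentials; the full (tokens, vocab) logits are never kept.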
    # Get maxs
    maxs = []
    for i in inputs:
        m = None  # max with (tokens_tile,) shape
        for w_idx, w in enumerate(weights):
            _m = GetMaxFunction.apply(i, w, self, w_idx, weight_split_size, split_dim)
            if m is None:
                m = _m
            else:
                m = torch.max(m, _m)
        assert m is not None
        maxs.append(m)  # (tokens_tile,)
    maxs_tensor = torch.cat(maxs)  # (tokens,)
    assert maxs_tensor.shape == (tokens,)
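
    # Second pass: accumulate each token's softmax denominator, one weight
    # tile at a time, using the per-token max from the first pass.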
    # Get sums.
    sums = []
    for idx, i in enumerate(inputs):
        s = None  # sum with (tokens_tile,) shape
        for w_idx, w in enumerate(weights):
            _s = GetSumFunction.apply(i, w, maxs[idx], self, w_idx, weight_split_size, split_dim)
            if s is None:
                s = _s
            else:
                s += _s
        assert s is not None
        sums.append(s)  # (tokens_tile,)
    sums_tensor = torch.cat(sums)  # (tokens,)
    assert sums_tensor.shape == (tokens,)

    # Select the target rows of the weight and compute the negative log-prob
    # of each target token from its logit and the streamed max/sum statistics.
    result = self.get_target_nlprob(input, self.proj_weight, target, maxs_tensor, sums_tensor)
    if self.reduction == "mean":
        result /= tokens
    return result
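
# Reference only, not part of the module: a dense, untiled sketch of what the two
# passes above compute, written with the same local names (input, weight, target,
# tokens). It is the standard numerically stable log-softmax; MEVO obtains the same
# max/sum statistics without ever materializing the full `logits` matrix.
#
#   logits = input @ weight.t()                        # (tokens, vocab)
#   m = logits.max(dim=1).values                       # pass 1: per-token max
#   s = torch.exp(logits - m.unsqueeze(1)).sum(dim=1)  # pass 2: shifted exp sums
#   nll = -(logits[torch.arange(tokens), target] - m - s.log())
#   loss = nll.sum() / tokens  # assuming "mean" reduction divides the summed nll by tokens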