def forward()

in fairscale/experimental/nn/mevo.py [0:0]


    def forward(self, input: torch.Tensor, target: Optional[torch.Tensor]) -> torch.Tensor:  # type: ignore
        """Tiled forward pass: per-token max and sum over the vocab, then target scoring.

        When the module is not training and ``target`` is ``None``, the call is
        delegated to ``self.eval_forward(input)`` and nothing else happens.

        Otherwise, ``input`` and ``target`` go through ``_reshape_inputs``, after
        which ``input`` is unpacked as a 2-D ``(tokens, d_model)`` tensor and
        ``self.proj_weight`` as ``(vocab, d_model)``. The input and the weight
        returned by ``self.trigger()`` are each split along dim 0 into tiles.

        * For every input tile, ``GetMaxFunction`` is applied against every
          weight tile. The partial results are combined element-wise with
          ``torch.max``.
        * For every input tile, ``GetSumFunction`` is then applied against every
          weight tile, given that tile's max. The partial results are added
          together.

        The per-tile maxima and sums are concatenated back to shape ``(tokens,)``.
        They are handed to ``self.get_target_nlprob`` along with
        ``self.proj_weight`` and ``target``. Presumably this is a numerically
        stable log-softmax/NLL; confirm against ``get_target_nlprob``.

        Args:
            input: Activations fed to the output projection. They must require
                grad whenever grad mode is enabled.
            target: Target tensor. It may be ``None`` only on the eval path
                described above.

        Returns:
            The tensor produced by ``self.get_target_nlprob``, divided in place
            by the token count when ``self.reduction == "mean"``.
        """
        # Inference without targets takes a separate, non-tiled path.
        if not self.training and target is None:
            return self.eval_forward(input)

        # Optional memory report (MiB), printed by rank 0 only.
        if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
            current_mib = round(torch.cuda.memory_allocated() / 1024 / 1024)
            peak_mib = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
            print("DEBUG cur, peak", current_mib, peak_mib)

        assert isinstance(input, torch.Tensor)
        assert isinstance(target, torch.Tensor)
        # Under grad mode the input must be part of the autograd graph.
        assert not torch.is_grad_enabled() or input.requires_grad
        input, target = _reshape_inputs(input, target)

        num_tokens, d_model = input.shape
        vocab, weight_dim = self.proj_weight.shape
        assert d_model == weight_dim

        # Tile sizes come from the tiling factors tf_in / tf_w, with the full
        # extent passed as the upper bound.
        split_dim = 0
        input_split_size = _next_power_of_2_or_max(num_tokens // self.tf_in, num_tokens)
        weight_split_size = _next_power_of_2_or_max(vocab // self.tf_w, vocab)
        input_tiles = torch.split(input, input_split_size, split_dim)
        weight_tiles = torch.split(self.trigger(), weight_split_size, split_dim)

        # Pass 1: running element-wise max over all weight tiles, per input tile.
        tile_maxes = []
        for in_tile in input_tiles:
            tile_max = None  # shape (tokens_in_tile,)
            for w_idx, w_tile in enumerate(weight_tiles):
                partial_max = GetMaxFunction.apply(in_tile, w_tile, self, w_idx, weight_split_size, split_dim)
                tile_max = partial_max if tile_max is None else torch.max(tile_max, partial_max)
            assert tile_max is not None
            tile_maxes.append(tile_max)
        maxs_tensor = torch.cat(tile_maxes)  # shape (tokens,)
        assert maxs_tensor.shape == (num_tokens,)

        # Pass 2: running sum over all weight tiles, given each tile's max.
        tile_sums = []
        for in_tile, tile_max in zip(input_tiles, tile_maxes):
            tile_sum = None  # shape (tokens_in_tile,)
            for w_idx, w_tile in enumerate(weight_tiles):
                partial_sum = GetSumFunction.apply(
                    in_tile, w_tile, tile_max, self, w_idx, weight_split_size, split_dim
                )
                if tile_sum is None:
                    tile_sum = partial_sum
                else:
                    # Accumulate in place into the first partial's tensor.
                    tile_sum += partial_sum
            assert tile_sum is not None
            tile_sums.append(tile_sum)
        sums_tensor = torch.cat(tile_sums)  # shape (tokens,)
        assert sums_tensor.shape == (num_tokens,)

        # Score the target entries using the precomputed per-token max and sum.
        result = self.get_target_nlprob(input, self.proj_weight, target, maxs_tensor, sums_tensor)
        if self.reduction == "mean":
            result /= num_tokens
        return result