def forward()

in pretrain/PyTorch/distributed_apex.py
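
This forward pass runs the wrapped module and, unless allreduce is delayed until backward, checks whether the set of trainable parameters has changed since the previous iteration, rebuilding the gradient-bucketing state when it has and reusing the recorded layout when it has not. A runnable sketch of the staleness check follows the listing.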


    def forward(self, *inputs, **kwargs):
        # Run the wrapped module's forward pass.
        result = self.module(*inputs, **kwargs)

        if not self.delay_allreduce:
            # Only parameters that require gradients take part in allreduce.
            param_list = [
                param for param in self.module.parameters() if param.requires_grad]

            # Conditions under which to refresh self.record
            # Forward has the authority to set needs_refresh to True, but only allreduce_params
            # in backward has the authority to set needs_refresh to False.
            # Parentheses are not necessary for correct order of operations, but make the intent clearer.
            if ((not self.active_params) or
                    (len(param_list) != len(self.active_params)) or
                    any(param1 is not param2 for param1, param2 in zip(param_list, self.active_params))):
                self.needs_refresh = True
            if self.needs_refresh:
                # Rebuild all bucketing metadata from scratch.
                self.active_i_buckets = []
                self.buckets = []
                # Running buckets, one per dtype: half, float, double.
                self.tmp_buckets = [[], [], []]
                self.tmp_numels = [0, 0, 0]
                self.bucket_sizes = []
                # Map each parameter's id() to its position in param_list.
                self.param_id_to_active_i = {
                    id(param): i for i, param in enumerate(param_list)}
                self.param_id_to_bucket = {}
            else:
                # Reuse the bucket layout recorded by the previous backward pass.
                self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                                for i in range(self.num_buckets)]
                self.buckets_ready_size = [0 for _ in range(self.num_buckets)]
                if self.retain_allreduce_buffers:
                    self.allreduce_buffers = [
                        None for _ in range(self.num_buckets)]
                self.next_bucket = 0
                self.ready_buckets_not_reduced = set()

            # Remember the current parameter set for the next staleness check.
            self.active_params = param_list

        # Reset so that backward can queue its allreduce callback for this iteration.
        self.callback_queued = False

        return result
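
Below is a minimal, runnable sketch of the staleness check performed above, taken outside the wrapper. The toy torch.nn.Linear model is illustrative only and not part of the source; it shows how comparing the trainable parameter list by length and by object identity catches a change such as freezing a parameter between iterations.

    import torch

    net = torch.nn.Linear(4, 2)
    active_params = [p for p in net.parameters() if p.requires_grad]

    # Same parameter objects in the same order: no refresh needed.
    fresh = [p for p in net.parameters() if p.requires_grad]
    assert not ((len(fresh) != len(active_params)) or
                any(p1 is not p2 for p1, p2 in zip(fresh, active_params)))

    # Freezing a parameter shrinks the trainable set, so the same test
    # that forward() applies now flags a refresh.
    net.bias.requires_grad_(False)
    fresh = [p for p in net.parameters() if p.requires_grad]
    needs_refresh = ((len(fresh) != len(active_params)) or
                     any(p1 is not p2 for p1, p2 in zip(fresh, active_params)))
    assert needs_refresh

Note the division of authority spelled out in the comments: forward() may only set needs_refresh to True, and only allreduce_params in backward may clear it, which prevents a single forward pass from prematurely declaring a rebuilt bucket layout valid.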