def create_hooks()

in pretrain/PyTorch/distributed_apex.py


    def create_hooks(self):
        """Register a hook on each parameter's AccumulateGrad node and queue
        the end-of-backward callbacks that drive the bucketed allreduces."""
        # Fallback hook that's only called at the end of backward.
        # Used if you deliberately want to delay allreduces to the end, or to refresh the
        # bucket structure that will be used to overlap communication with computation in later
        # iterations.
        def allreduce_params():
            # Bucket record refresh
            if not self.delay_allreduce and self.needs_refresh:
                self.sync_bucket_structure()
                self.needs_refresh = False

            self.allreduce_fallback()

        def overlapping_backward_epilogue():
            self.reduction_stream.record_event(self.reduction_event)
            torch.cuda.current_stream().wait_event(self.reduction_event)

            # Sanity checks that all the buckets were kicked off
            if self.next_bucket != self.num_buckets:
                raise RuntimeError(
                    "In epilogue, next_bucket ({}) != num_buckets ({}). "
                    "This probably indicates some buckets were not allreduced.".format(
                        self.next_bucket, self.num_buckets))

            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
                if actual != expected:
                    raise RuntimeError(
                        "Some param buckets were not allreduced: ready size {}, "
                        "expected {}.".format(actual, expected))

        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                def wrapper(param):
                    # expand_as(param) creates a graph node whose
                    # next_functions[0][0] is this parameter's AccumulateGrad
                    # node; hooks on it fire once the gradient is fully accumulated.
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        # The user must explicitly flag when the allreduce
                        # should run; until then the hook is a no-op.
                        if not self.need_reduction:
                            return
                        if self.delay_allreduce or self.needs_refresh:
                            # TODO:  How do we want to handle multiple backward passes between
                            # each forward, e.g., backward passes with retain_graph=True?
                            # needs_refresh and callback_queued are both vulnerable states.
                            if not self.delay_allreduce and self.needs_refresh:
                                # Use the backward pass to build the bucket structure on the fly.
                                active_i = self.param_id_to_active_i[id(param)]

                                # Float, half, and double tensors are grouped into buckets separately.
                                current_type = self.param_type_to_tmp_i[param.type()]

                                self.tmp_buckets[current_type].append(active_i)

                                ship_tmp_bucket = False
                                if self.custom_allreduce_triggers:
                                    if id(param) in self.allreduce_trigger_params:
                                        ship_tmp_bucket = True
                                else:
                                    self.tmp_numels[current_type] += param.numel()
                                    if self.tmp_numels[current_type] >= self.message_size:
                                        ship_tmp_bucket = True

                                # To consider:  If custom_allreduce_triggers are in use, ship all
                                # tmp_buckets, not just tmp_buckets[current_type].
                                if ship_tmp_bucket:
                                    self.active_i_buckets.append(
                                        self.tmp_buckets[current_type])
                                    self.tmp_buckets[current_type] = []
                                    self.tmp_numels[current_type] = 0

                            if not self.callback_queued:
                                Variable._execution_engine.queue_callback(
                                    allreduce_params)
                                self.callback_queued = True
                        else:
                            if not self.callback_queued:
                                Variable._execution_engine.queue_callback(
                                    overlapping_backward_epilogue)
                                self.callback_queued = True

                            self.comm_ready_buckets(param)

                    grad_acc.register_hook(allreduce_hook)
                    self.grad_accs.append(grad_acc)

                wrapper(param)
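
For context, the two mechanisms above can be exercised in isolation: registering a hook on a parameter's AccumulateGrad node via the expand_as trick, and queueing an end-of-backward callback from inside that hook. A minimal single-process sketch (the model, hook names, and prints are illustrative; `Variable._execution_engine.queue_callback` is the same internal PyTorch API used above):

    import torch
    from torch.autograd import Variable

    model = torch.nn.Linear(4, 2)   # stand-in for self.module

    callback_queued = False
    grad_accs = []  # keep node references alive, or the hooks are dropped

    def epilogue():
        # Runs once at the very end of the current backward pass,
        # analogous to overlapping_backward_epilogue.
        global callback_queued
        callback_queued = False
        print("backward epilogue")

    def wrapper(name, param):
        # expand_as(param) creates a graph node whose next_functions[0][0]
        # is this parameter's AccumulateGrad node; hooks on it fire once
        # the parameter's gradient has been fully accumulated.
        grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

        def hook(*unused):
            global callback_queued
            if not callback_queued:
                # queue_callback only takes effect for the backward pass
                # currently in flight, so queue it from inside the hook.
                Variable._execution_engine.queue_callback(epilogue)
                callback_queued = True
            print("grad ready:", name)

        grad_acc.register_hook(hook)
        grad_accs.append(grad_acc)

    for name, param in model.named_parameters():
        if param.requires_grad:
            wrapper(name, param)

    model(torch.randn(3, 4)).sum().backward()
    # prints "grad ready: ..." for each parameter, then "backward epilogue"

This is also why allreduce_hook, rather than create_hooks itself, queues allreduce_params / overlapping_backward_epilogue: the queued callbacks run (and the queue is cleared) at the end of each backward pass, so they must be re-queued every iteration.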
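
The on-the-fly bucket construction inside allreduce_hook can likewise be sketched as a standalone function. The following is a rough equivalent of the tmp_buckets / tmp_numels bookkeeping; build_buckets is a hypothetical helper, and the real class additionally handles custom_allreduce_triggers and defers leftover params to sync_bucket_structure:

    from collections import defaultdict

    def build_buckets(params, message_size):
        # Greedily pack parameter indices, in gradient-ready order, into
        # buckets of at least message_size elements. Float, half, and
        # double tensors are bucketed separately, as in the hook above.
        tmp_buckets = defaultdict(list)
        tmp_numels = defaultdict(int)
        active_i_buckets = []
        for active_i, param in enumerate(params):
            current_type = param.type()
            tmp_buckets[current_type].append(active_i)
            tmp_numels[current_type] += param.numel()
            if tmp_numels[current_type] >= message_size:
                active_i_buckets.append(tmp_buckets[current_type])
                tmp_buckets[current_type] = []
                tmp_numels[current_type] = 0
        # Any stragglers form final, undersized buckets.
        for bucket in tmp_buckets.values():
            if bucket:
                active_i_buckets.append(bucket)
        return active_i_buckets

Calling it as build_buckets(list(model.parameters()), message_size) with the class's message_size attribute reproduces the shipping rule in the non-custom-trigger branch: a bucket is closed as soon as its element count reaches message_size.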