def create_hooks()

in apex/apex/parallel/distributed.py [0:0]


    def create_hooks(self):
        # Fallback hook that's only called at the end of backward.
        # Used if you deliberately want to delay allreduces to the end, or to refresh the 
        # bucket structure that will be used to overlap communication with computation in later
        # iterations.
        def allreduce_params():
            # Bucket record refresh
            if not self.delay_allreduce:
                if self.needs_refresh:
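                    # Commit the bucket assignments recorded during this backward pass so
                    # that later iterations can overlap allreduces with backprop.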
                    self.sync_bucket_structure()

                    self.needs_refresh = False

            self.allreduce_fallback()


        def overlapping_backward_epilogue():
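            # Make the current stream wait for everything queued on the side reduction
            # stream, so the reduced gradients are not consumed (e.g. by the optimizer)
            # before the communication has actually finished.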
            self.reduction_stream.record_event(self.reduction_event)
            torch.cuda.current_stream().wait_event(self.reduction_event)
     
            # Sanity checks that all the buckets were kicked off
            if self.next_bucket != self.num_buckets:
                raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}).  "
                                   "This probably indicates some buckets were not allreduced.".format(
                                       self.next_bucket, self.num_buckets))

            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
                if actual != expected:
                    raise RuntimeError("Some param buckets were not allreduced.")
           

        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                def wrapper(param):
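                    # param is a leaf tensor with no grad_fn.  Expanding it to its own shape
                    # creates a throwaway view whose grad_fn leads back to the parameter's
                    # AccumulateGrad node; a hook registered on that node fires as soon as
                    # this parameter's gradient has been accumulated during backward.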
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        if not self._disable_allreduce:
                            if self.delay_allreduce or self.needs_refresh:
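                                # Either allreduces are deliberately delayed, or the bucket
                                # structure still needs to be rebuilt; in both cases all
                                # communication is deferred to the end-of-backward callback
                                # (allreduce_params).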
                                # TODO:  How do we want to handle multiple backward passes between
                                # each forward, e.g., backward passes with retain_graph=True?
                                # needs_refresh and callback_queued are both vulnerable states.
                                if not self.delay_allreduce and self.needs_refresh:
                                    # Use the backward pass to build the bucket structure on the fly.
                                    active_i = self.param_id_to_active_i[id(param)]

                                    # Float, half, and double tensors are grouped into buckets separately.
                                    current_type = self.param_type_to_tmp_i[param.type()]
  
                                    self.tmp_buckets[current_type].append(active_i)                          

                                    ship_tmp_bucket = False
                                    if self.custom_allreduce_triggers:
                                        if id(param) in self.allreduce_trigger_params:
                                            ship_tmp_bucket = True
                                    else:
                                        self.tmp_numels[current_type] += param.numel()
                                        if self.tmp_numels[current_type] >= self.message_size:
                                            ship_tmp_bucket = True

                                    # To consider:  If custom_allreduce_triggers are in use, ship all
                                    # tmp_buckets, not just tmp_buckets[current_type].
                                    if ship_tmp_bucket:
                                        self.active_i_buckets.append(self.tmp_buckets[current_type])
                                        self.tmp_buckets[current_type] = []
                                        self.tmp_numels[current_type] = 0
                                
                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(allreduce_params)
                                    self.callback_queued = True
                            else:
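                                # Bucket structure is already known: queue the epilogue once,
                                # then overlap communication with the rest of backward by
                                # kicking off the allreduce for any bucket this param completes.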
                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                    self.callback_queued = True 

                                self.comm_ready_buckets(param)
                        
                    grad_acc.register_hook(allreduce_hook)
                    self.grad_accs.append(grad_acc)

                wrapper(param)
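
For reference, a minimal self-contained sketch of the AccumulateGrad hook pattern that create_hooks relies on; the toy Linear model, hook body, and print statement are illustrative only and not part of apex:

    import torch

    model = torch.nn.Linear(4, 2)
    grad_accs = []  # keep references alive, as create_hooks does via self.grad_accs

    for param in model.parameters():
        if param.requires_grad:
            def wrapper(param):
                # expand_as yields a non-leaf view whose grad_fn points back at the
                # parameter's AccumulateGrad node.
                grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

                def hook(*unused):
                    # Fires once this parameter's gradient has been accumulated.
                    print("grad ready for param of shape", tuple(param.shape))

                grad_acc.register_hook(hook)
                grad_accs.append(grad_acc)

            wrapper(param)

    model(torch.randn(3, 4)).sum().backward()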