def training_step()

in rag-end2end-retriever/finetune_rag.py


    def training_step(self, batch, batch_idx) -> torch.Tensor:
        global isEmUpdateBusy  # used to check whether the entire embedding update process has finished
        global isAddIndexBusy  # used to check whether the entire indexing process has finished
        global processes  # used to keep the embedding update processes
        global threadHandle_index  # used to keep the indexing process handle

        if (self.trainer.global_rank == 0) and (self.custom_config.end2end):
            if (batch_idx != 0) and (batch_idx % self.custom_config.indexing_freq == 0):
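                # scan all visible GPUs with pynvml and collect the ones that look idle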
                free_gpu_list = []
                nvmlInit()
                deviceCount = nvmlDeviceGetCount()

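                # gpu_order (a JSON list) maps NVML device indices to this job's CUDA device ordering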
                my_list = json.loads(self.custom_config.gpu_order)

                for i in range(deviceCount):
                    handle = nvmlDeviceGetHandleByIndex(i)
                    info = nvmlDeviceGetMemoryInfo(handle)

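                    # treat a GPU as free when it is using less than 15 MB of memory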
                    if info.used / 1e6 < 15:
                        position = my_list.index(i)
                        free_gpu_list.append("cuda:" + str(position))

                has_free_gpus = len(free_gpu_list) >= self.custom_config.index_gpus

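                # launch a new embedding update only when none is in flight and enough GPUs are free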
                if (not isEmUpdateBusy) and has_free_gpus:
                    model_copy = type(self.model.rag.ctx_encoder)(
                        self.config_dpr
                    )  # get a new instance; this copy lives on the CPU
                    model_copy.load_state_dict(self.model.rag.ctx_encoder.state_dict())  # copy the current weights

                    processes = []

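                    # use at most index_gpus devices for the embedding workers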
                    if len(free_gpu_list) > self.custom_config.index_gpus:
                        cuda_devices = random.sample(free_gpu_list, self.custom_config.index_gpus)
                    else:
                        cuda_devices = free_gpu_list

                    num_processes = len(cuda_devices)

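                    # spawn one worker per device; each re-encodes its shard of the passage data on its own GPU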
                    for rank in range(num_processes):
                        logger.info("Iniitializing  embedding calculation process rank{}".format(rank))
                        device = cuda_devices[rank]
                        p = multiprocessing.Process(
                            target=embed_update,
                            args=(
                                copy.deepcopy(model_copy),
                                num_processes,
                                device,
                                rank,
                                self.custom_config.shard_dir,
                                self.custom_config.csv_path,
                            ),
                        )
                        processes.append(p)

                    for p in processes:
                        p.start()

                    isEmUpdateBusy = True

            if isEmUpdateBusy and (not isAddIndexBusy):
                index_process_list = [processes[k].is_alive() for k in range(self.custom_config.index_gpus)]
                # if every entry is False, all embedding calculation processes have finished
                if sum(index_process_list) == 0:
                    logger.info("Start adding the index")
                    threadHandle_index = multiprocessing.Process(
                        target=add_index,
                        args=(
                            self.custom_config.shard_dir,
                            self.config.index_path,
                        ),
                    )
                    threadHandle_index.start()
                    isAddIndexBusy = True

            # once index building has started, poll for its completion
            if isAddIndexBusy:
                # the indexing process is finished when its handle is no longer alive
                if not threadHandle_index.is_alive():
                    logger.info("Merging the dataset shards")
                    saved_dataset_shards = []

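                    # every worker saved its re-encoded shard as a dataset directory under shard_dir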
                    for address in glob(str(self.custom_config.shard_dir) + "/*/"):
                        saved_dataset_shards.append(load_from_disk(address))

                    concat = concatenate_datasets(saved_dataset_shards)
                    concat.save_to_disk(self.config.passages_path)  # here we update the main passage file on the disk
                    logger.info("done updating the dataset")

                    # To Do (@Aaron): useful for a future dynamic memory implementation.
                    # If you load the index from the disk, make sure to update the index file here; otherwise it is fine to update it from the worker.
                    # logger.info("then updating the index")
                    # shutil.copy(self.custom_config.temp_index, self.config.index_path)

                    logger.info("Loading new passages and iniitalzing new index")
                    self.trainer.model.module.module.model.rag.retriever.re_load()
                    self.trainer.model.module.module.model.rag.retriever.init_retrieval()

                    isEmUpdateBusy = False
                    isAddIndexBusy = False
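        # all ranks wait here so training resumes only after rank 0 finishes the index maintenance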
        self.trainer.strategy.barrier("barrier")

        loss_tensors = self._step(batch)

        logs = dict(zip(self.loss_names, loss_tensors))
        # tokens per batch
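        # RagTokenizer bundles a question_encoder tokenizer and a generator tokenizer, each with its own pad token id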
        tgt_pad_token_id = (
            self.tokenizer.generator.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        src_pad_token_id = (
            self.tokenizer.question_encoder.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        logs["tpb"] = (
            batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
        )
        self.log("loss", loss_tensors[0])
        return loss_tensors[0]
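
The free-GPU scan at the top of the method is the piece most worth reusing. Below is a minimal, self-contained sketch of that scan, assuming pynvml is installed; find_free_gpus and the used_mb_threshold parameter are illustrative names introduced here, but the NVML calls and the 15 MB heuristic mirror the code above.

    import json

    from pynvml import (
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
        nvmlShutdown,
    )


    def find_free_gpus(gpu_order: str, used_mb_threshold: float = 15.0) -> list:
        """Return "cuda:N" ids for devices whose used memory is below the threshold.

        gpu_order is a JSON list (e.g. "[0, 1, 2, 3]") mapping NVML device
        indices to the CUDA device ordering of the current job; every NVML
        index is assumed to appear in it, as in training_step above.
        """
        order = json.loads(gpu_order)
        free_gpu_list = []
        nvmlInit()
        try:
            for i in range(nvmlDeviceGetCount()):
                info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
                if info.used / 1e6 < used_mb_threshold:  # essentially idle
                    free_gpu_list.append("cuda:" + str(order.index(i)))
        finally:
            nvmlShutdown()
        return free_gpu_list

In training_step the same scan gates the embedding update: workers are only spawned once at least index_gpus devices pass the threshold.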