in rag-end2end-retriever/finetune_rag.py [0:0]
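# Note: this method relies on module-level imports in finetune_rag.py -- among them json, copy,
# random, multiprocessing, glob, the pynvml helpers (nvmlInit, nvmlDeviceGetCount,
# nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo), datasets.load_from_disk /
# concatenate_datasets, and RagTokenizer -- as well as the embed_update and add_index workers
# used below.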
def training_step(self, batch, batch_idx) -> Dict:
    global isEmUpdateBusy  # used to check whether the entire embedding update process has finished
    global isAddIndexBusy  # used to check whether the entire indexing process has finished
    global processes  # used to keep the embedding update worker processes
    global threadHandle_index  # used to keep the index-building process handle
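
    # End2end retriever training: rank 0 periodically re-encodes the passages with the current
    # context encoder in background processes, rebuilds the index from the resulting shards, and
    # hot-swaps it into the retriever while training continues.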
    if (self.trainer.global_rank == 0) and (self.custom_config.end2end):
        if (not batch_idx == 0) and (batch_idx % self.custom_config.indexing_freq == 0):
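            # Probe every GPU with pynvml; a device using less than ~15 MB of memory is treated
            # as free and mapped back to its position in the configured gpu_order list.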
            free_gpu_list = []
            nvmlInit()
            deviceCount = nvmlDeviceGetCount()

            my_list = json.loads(self.custom_config.gpu_order)

            for i in range(deviceCount):
                handle = nvmlDeviceGetHandleByIndex(i)
                info = nvmlDeviceGetMemoryInfo(handle)

                if info.used / 1e6 < 15:
                    position = my_list.index(i)
                    free_gpu_list.append("cuda:" + str(position))

            if len(free_gpu_list) >= self.custom_config.index_gpus:
                has_free_gpus = True
            else:
                has_free_gpus = False
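
            # If no embedding update is already running and enough GPUs are idle, snapshot the
            # current context encoder on the CPU and launch one embed_update worker per selected
            # GPU to re-encode the passages.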
            if (not isEmUpdateBusy) and has_free_gpus:
                model_copy = type(self.model.rag.ctx_encoder)(
                    self.config_dpr
                )  # get a new instance; this copy will be loaded on the CPU
                model_copy.load_state_dict(self.model.rag.ctx_encoder.state_dict())  # copy weights

                processes = []

                if len(free_gpu_list) > self.custom_config.index_gpus:
                    cuda_devices = random.sample(free_gpu_list, self.custom_config.index_gpus)
                else:
                    cuda_devices = free_gpu_list
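
                # Spawn one embed_update worker per selected device; each worker re-encodes its
                # share of the passages with the copied context encoder and writes the resulting
                # embedding shards to shard_dir.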
                num_processes = len(cuda_devices)

                for rank in range(num_processes):
                    logger.info("Initializing embedding calculation process rank {}".format(rank))
                    device = cuda_devices[rank]
                    p = multiprocessing.Process(
                        target=embed_update,
                        args=(
                            copy.deepcopy(model_copy),
                            num_processes,
                            device,
                            rank,
                            self.custom_config.shard_dir,
                            self.custom_config.csv_path,
                        ),
                    )
                    processes.append(p)

                for p in processes:
                    p.start()

                isEmUpdateBusy = True
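
        # Once every embedding worker has exited, start a single background process that builds
        # the new index from the freshly written shards.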
        if isEmUpdateBusy and (not isAddIndexBusy):
            index_process_list = [processes[k].is_alive() for k in range(self.custom_config.index_gpus)]
            if (
                sum(index_process_list) == 0
            ):  # if the entire list is False, all embedding calculation processes have finished
                logger.info("Start adding the index")
                threadHandle_index = multiprocessing.Process(
                    target=add_index,
                    args=(
                        self.custom_config.shard_dir,
                        self.config.index_path,
                    ),
                )
                threadHandle_index.start()
                isAddIndexBusy = True

        # check whether index building has started
        if isAddIndexBusy:
            # check whether the index-building process is still running
            if not threadHandle_index.is_alive():
                logger.info("Merging the dataset shards")
                saved_dataset_shards = []

                for address in glob(str(self.custom_config.shard_dir) + "/*/"):
                    saved_dataset_shards.append(load_from_disk(address))

                concat = concatenate_datasets(saved_dataset_shards)
                concat.save_to_disk(self.config.passages_path)  # here we update the main passage file on disk
                logger.info("done updating the dataset")

                # To Do (@Aaron): useful for a future dynamic memory implementation.
                # If you load the index from disk, make sure to update the index file here; otherwise it is ok to update the index file from the worker.
                # logger.info("then updating the index")
                # shutil.copy(self.custom_config.temp_index, self.config.idex_path)

                logger.info("Loading new passages and initializing new index")
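                # Reach through the Lightning/DDP wrappers to the underlying RAG retriever,
                # reload the merged passages, and re-initialize retrieval with the new index.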
                self.trainer.model.module.module.model.rag.retriever.re_load()
                self.trainer.model.module.module.model.rag.retriever.init_retrieval()

                isEmUpdateBusy = False
                isAddIndexBusy = False
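
    # All ranks synchronize here before computing the loss, so the other ranks stay in step
    # while rank 0 merges shards and reloads the retriever.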
    self.trainer.strategy.barrier("barrier")

    loss_tensors = self._step(batch)

    logs = dict(zip(self.loss_names, loss_tensors))
    # tokens per batch
    tgt_pad_token_id = (
        self.tokenizer.generator.pad_token_id
        if isinstance(self.tokenizer, RagTokenizer)
        else self.tokenizer.pad_token_id
    )
    src_pad_token_id = (
        self.tokenizer.question_encoder.pad_token_id
        if isinstance(self.tokenizer, RagTokenizer)
        else self.tokenizer.pad_token_id
    )
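    # "tpb" counts the non-padding tokens in the question-encoder inputs plus the generator
    # decoder inputs, i.e. the effective number of tokens processed in this batch.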
    logs["tpb"] = (
        batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
    )
    self.log("loss", loss_tensors[0])
    return loss_tensors[0]