in modules/SwissArmyTransformer/sat/generation/utils.py [0:0]
def generate_continually(func, input_source='interactive'):
    """Repeatedly feed queries to ``func``, from the console or from a file.

    Args:
        func: Callable invoked as ``func(raw_text)`` with one stripped,
            non-empty query string. Its return value is ignored.
        input_source: ``'interactive'`` to prompt for queries on rank 0 and
            broadcast each one to the other ranks; any other value is opened
            as a UTF-8 text file containing one query per line.

    Interactive mode returns once the user types ``stop`` or standard input
    is exhausted. ``ValueError`` and ``FileNotFoundError`` raised by ``func``
    in interactive mode are printed and the loop goes on to the next query.
    File mode returns after the last line; exceptions raised by ``func``
    propagate to the caller.
    """
    if input_source == 'interactive':
        while True:
            raw_text, is_stop = "", False
            if torch.distributed.get_rank() == 0:
                try:
                    raw_text = input("\nPlease Input Query (stop to exit) >>> ")
                except EOFError:
                    # stdin is closed or exhausted. Treat it as an explicit
                    # stop so the other ranks, which are blocked in the
                    # broadcast below, are released instead of hanging.
                    raw_text = "stop"
                raw_text = raw_text.strip()
                if not raw_text:
                    # Nothing has been broadcast yet, so re-prompting here
                    # keeps the other ranks waiting on the same collective.
                    print('Query should not be empty!')
                    continue
                is_stop = raw_text == "stop"
                torch.distributed.broadcast_object_list([raw_text, is_stop])
            else:
                # Placeholder values are overwritten in place by the broadcast.
                info = [raw_text, is_stop]
                torch.distributed.broadcast_object_list(info)
                raw_text, is_stop = info
            if is_stop:
                return
            try:
                start_time = time.time()
                func(raw_text)
                if torch.distributed.get_rank() == 0:
                    print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
            except (ValueError, FileNotFoundError) as e:
                # Report the failed query and keep the session alive.
                print(e)
    else:
        with open(input_source, 'r', encoding="utf-8") as fin:
            inputs = fin.readlines()
        for line_no, raw_text in enumerate(inputs):
            # Lines are dealt out round-robin by line number across the
            # data-parallel ranks; blank lines still occupy a slot.
            if line_no % get_data_parallel_world_size() != get_data_parallel_rank():
                continue
            raw_text = raw_text.strip()
            if not raw_text:
                # Skip blank lines before announcing any work on them.
                continue
            if get_model_parallel_rank() == 0:
                print(f'Working on No. {line_no} on model group {dist.get_rank()}... ')
            start_time = time.time()
            func(raw_text)
            if get_model_parallel_rank() == 0:
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)