# generate_continually() — from modules/SwissArmyTransformer/sat/generation/utils.py

def generate_continually(func, input_source='interactive'):
    """Repeatedly feed text queries to ``func``, interactively or from a file.

    Args:
        func: Callable taking a single stripped query string; invoked once per
            non-empty query. Its return value is ignored.
        input_source: ``'interactive'`` to read queries from stdin on rank 0
            (broadcast to all other ranks), or a path to a UTF-8 text file with
            one query per line, sharded across data-parallel groups.

    Returns:
        None. In interactive mode, returns when the user types ``stop``.

    Raises:
        Propagates any exception from ``func`` other than ``ValueError`` and
        ``FileNotFoundError`` (those two are printed and skipped in
        interactive mode).
    """
    if input_source == 'interactive':
        while True:
            raw_text, is_stop = "", False
            if torch.distributed.get_rank() == 0:
                raw_text = input("\nPlease Input Query (stop to exit) >>> ")
                raw_text = raw_text.strip()
                if not raw_text:
                    # NOTE: rank 0 loops again WITHOUT broadcasting; other
                    # ranks simply stay blocked in broadcast_object_list until
                    # the next non-empty query, so the collectives stay paired.
                    print('Query should not be empty!')
                    continue
                if raw_text == "stop":
                    is_stop = True
                torch.distributed.broadcast_object_list([raw_text, is_stop])
            else:
                # Non-zero ranks receive the query/stop flag in place
                # (broadcast_object_list overwrites the list's elements).
                info = [raw_text, is_stop]
                torch.distributed.broadcast_object_list(info)
                raw_text, is_stop = info
            if is_stop:
                return
            try:
                start_time = time.time()
                func(raw_text)
                if torch.distributed.get_rank() == 0:
                    print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
            except (ValueError, FileNotFoundError) as e:
                # Best-effort: report and move on to the next query.
                print(e)
                continue
    else:
        with open(input_source, 'r', encoding="utf-8") as fin:
            inputs = fin.readlines()
        for line_no, raw_text in enumerate(inputs):
            # Shard lines round-robin across data-parallel groups so each
            # group works on a disjoint subset of the file.
            # (get_data_parallel_* / get_model_parallel_rank are project
            # helpers defined elsewhere — presumably sat's mpu module.)
            if line_no % get_data_parallel_world_size() != get_data_parallel_rank():
                continue
            # Use torch.distributed directly for consistency with the
            # interactive branch (was `dist.get_rank()`).
            rk = torch.distributed.get_rank()
            if get_model_parallel_rank() == 0:
                print(f'Working on No. {line_no} on model group {rk}... ')
            raw_text = raw_text.strip()
            if len(raw_text) == 0:
                continue
            start_time = time.time()
            func(raw_text)
            if get_model_parallel_rank() == 0:
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)