modules/SwissArmyTransformer/sat/generation/utils.py (77 lines of code) (raw):

# -*- encoding: utf-8 -*-
'''
@File    :   utils.py
@Time    :   2021/10/09 17:18:26
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random
import torch
import time
import stat
from datetime import datetime
from torchvision.utils import save_image
import torch.distributed as dist
from sat.mpu import get_data_parallel_world_size, get_data_parallel_rank, get_model_parallel_rank

# rwx for owner, group and others (== 0o777). Same value the original
# computed with ``S_IRWXO + S_IRWXG + S_IRWXU``; bit flags are combined with |.
_ALL_RWX = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO


def timed_name(prefix, suffix=None, path=None):
    """Build a file path whose name embeds the current local time.

    The result is ``os.path.join(path, f"{prefix}-{MM-DD-HH-MM-SS}{suffix}")``.

    Args:
        prefix: leading part of the file name.
        suffix: text appended verbatim after the timestamp (e.g. ``'.jpg'``);
            ``None`` means no suffix.
        path: directory the name is joined onto; ``None`` means a bare
            (relative) file name.

    Returns:
        The joined path as a string.
    """
    # With the old handling, the None defaults either crashed
    # (os.path.join(None, ...) raises TypeError) or put the literal text
    # "None" into the name. Treat them as empty instead.
    suffix = '' if suffix is None else suffix
    path = '' if path is None else path
    stamp = datetime.now().strftime('%m-%d-%H-%M-%S')
    return os.path.join(path, f"{prefix}-{stamp}{suffix}")


def _save_world_rwx(img, file_path):
    """Save one image tensor (normalized) to *file_path*, then chmod it to 0o777."""
    save_image(img, file_path, normalize=True)
    # NOTE(review): 0o777 makes the file world-writable/executable. Kept
    # unchanged for compatibility; consider tightening if not required.
    os.chmod(file_path, _ALL_RWX)


def save_multiple_images(imgs, path, debug=True):
    """Save a list of image tensors to disk.

    Args:
        imgs: list of image tensors. They are concatenated along dim 0, so
            presumably each is (N, C, H, W) with matching C/H/W — TODO confirm.
        path: in debug mode, the output file for a single concatenated image;
            otherwise an (existing) directory that receives ``0.jpg``,
            ``1.jpg``, ... plus ``concat.jpg``.
        debug: selects between the two layouts described above. Only the
            non-debug layout chmods the written files.
    """
    if debug:
        grid = torch.cat(imgs, dim=0)
        print("\nSave to: ", path, flush=True)
        save_image(grid, path, normalize=True)
        return

    print("\nSave to: ", path, flush=True)
    for i, img in enumerate(imgs):
        _save_world_rwx(img, os.path.join(path, f'{i}.jpg'))
    _save_world_rwx(torch.cat(imgs, dim=0), os.path.join(path, 'concat.jpg'))


def _generate_interactive(func):
    """Prompt on global rank 0, broadcast each query to all ranks, and run ``func``.

    Returns when rank 0 enters ``stop``.
    """
    while True:
        raw_text, is_stop = "", False
        if dist.get_rank() == 0:
            raw_text = input("\nPlease Input Query (stop to exit) >>> ").strip()
            if not raw_text:
                # Re-prompt without broadcasting. The other ranks stay blocked
                # in broadcast_object_list until rank 0 sends a real query.
                print('Query should not be empty!')
                continue
            if raw_text == "stop":
                is_stop = True
            dist.broadcast_object_list([raw_text, is_stop])
        else:
            # Receive buffer: filled in place with rank 0's [raw_text, is_stop].
            info = [raw_text, is_stop]
            dist.broadcast_object_list(info)
            raw_text, is_stop = info
        if is_stop:
            return
        try:
            start_time = time.time()
            func(raw_text)
            if dist.get_rank() == 0:
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        except (ValueError, FileNotFoundError) as e:
            # These errors are reported and the loop keeps accepting queries.
            print(e)
            continue


def _generate_from_file(func, input_source):
    """Run ``func`` on each non-empty line of *input_source*.

    Lines are sharded round-robin across data-parallel ranks.
    """
    with open(input_source, 'r', encoding="utf-8") as fin:
        inputs = fin.readlines()
    if not inputs:
        # Nothing to do. Also avoids querying process-group state, matching
        # the original, which never did so for an empty file.
        return

    # Loop-invariant process-group queries, hoisted out of the loop.
    dp_world_size = get_data_parallel_world_size()
    dp_rank = get_data_parallel_rank()
    is_mp_leader = get_model_parallel_rank() == 0
    # NOTE(review): rk is the *global* rank, despite the "model group"
    # wording in the progress message below.
    rk = dist.get_rank()

    for line_no, raw_text in enumerate(inputs):
        # Line i is handled by data-parallel rank i % dp_world_size.
        if line_no % dp_world_size != dp_rank:
            continue
        if is_mp_leader:
            print(f'Working on No. {line_no} on model group {rk}... ')
        raw_text = raw_text.strip()
        if not raw_text:
            continue
        start_time = time.time()
        func(raw_text)
        if is_mp_leader:
            print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)


def generate_continually(func, input_source='interactive'):
    """Repeatedly feed text queries to ``func`` across distributed ranks.

    Args:
        func: callable invoked as ``func(raw_text)`` for each query.
        input_source: ``'interactive'`` to read queries from stdin on global
            rank 0 and broadcast them to every rank (enter ``stop`` to exit).
            Any other value is treated as the path of a UTF-8 text file with
            one query per line, sharded across data-parallel ranks.

    Requires ``torch.distributed`` to be initialized, because ranks are queried.
    """
    if input_source == 'interactive':
        _generate_interactive(func)
    else:
        _generate_from_file(func, input_source)