modules/SwissArmyTransformer/sat/generation/magnify.py

# -*- encoding: utf-8 -*-
'''
@File    :   magnify.py
@Time    :   2021/01/14 00:41:40
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

import math

import torch

from .sampling import filling_sequence


def magnify(model, tokenizer, tokens_list, text_token_list, args):
    # Upsample a 32x32 grid of image tokens to 64x64: each overlapping
    # 16x16 patch of the base code conditions the generation of a 32x32
    # patch of the magnified code.
    s = int(math.sqrt(len(tokens_list) + 1e-6))
    assert s == 32
    code = tokens_list.view(s, s)
    # Separator tokens placed between the base-image code and the
    # magnified-image code in the conditioning sequence.
    midfix = torch.tensor(
        [tokenizer['[EOI1]'], tokenizer['[ROI2]'], tokenizer['[POS0]'],
         tokenizer['[BASE]'], tokenizer['[BOI2]']],
        device=code.device)
    # -1 marks the positions that filling_sequence still has to generate.
    magnified_code = code.new_zeros((s * 2, s * 2), dtype=torch.long) - 1
    # Each (i, j, line) window regenerates rows [16*i, 16*i + line) and
    # columns [16*j, 16*(j+2)) of the magnified grid; a `line` below 32
    # leaves the bottom rows of the strip to a later, overlapping window.
    windows = [(0, 0, 18), (0, 1, 30), (0, 2, 30), (1, 1, 30), (1, 0, 30),
               (1, 2, 30), (2, 0, 32), (2, 1, 32), (2, 2, 32)]
    for i, j, line in windows:
        # The 16x16 base patch this window is conditioned on.
        code_part = code[8 * i: 8 * (i + 2), 8 * j: 8 * (j + 2)].reshape(-1)
        # Target region; cells already filled by earlier windows are no
        # longer -1 and therefore act as fixed context during sampling.
        magnified_code_part = magnified_code[
            16 * i: 16 * i + line, 16 * j: 16 * (j + 2)].reshape(-1)
        context_tokens_tensor = torch.cat(
            [text_token_list, code_part, midfix], dim=0)
        context_len = len(context_tokens_tensor)
        seq = torch.cat([context_tokens_tensor, magnified_code_part], dim=0)
        # Restrict sampling to image tokens only.
        magnified_code_part_completed = filling_sequence(
            model, seq, args,
            invalid_slices=[slice(tokenizer.img_tokenizer.num_tokens, None)])
        magnified_code[16 * i: 16 * i + line, 16 * j: 16 * (j + 2)] = \
            magnified_code_part_completed[0, context_len:].view(line, 32)
    return magnified_code.view(1, s * s * 4)
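

# ---------------------------------------------------------------------------
# Illustrative sanity check (an added sketch, not part of the original
# module): the nine (i, j, line) windows used by magnify() should jointly
# cover every cell of the 64x64 magnified grid, otherwise some output
# positions would stay at -1. This only exercises the window geometry and
# needs no model, tokenizer, or sampling arguments.
if __name__ == '__main__':
    windows = [(0, 0, 18), (0, 1, 30), (0, 2, 30), (1, 1, 30), (1, 0, 30),
               (1, 2, 30), (2, 0, 32), (2, 1, 32), (2, 2, 32)]
    covered = torch.zeros(64, 64, dtype=torch.bool)
    for i, j, line in windows:
        covered[16 * i: 16 * i + line, 16 * j: 16 * (j + 2)] = True
    assert bool(covered.all()), 'windows do not tile the 64x64 grid'
    print('window coverage OK: every magnified position is generated '
          'at least once')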