in scripts/run_model.py [0:0]
def run_single_example(args, model):
    dtype = torch.FloatTensor
    if args.use_gpu == 1:
        dtype = torch.cuda.FloatTensor
    # Build the CNN to use for feature extraction
    print('Loading CNN for feature extraction')
    cnn = build_cnn(args, dtype)
    # Load and preprocess the image
    img_size = (args.image_height, args.image_width)
    img = imread(args.image, mode='RGB')
    img = imresize(img, img_size, interp='bicubic')
    img = img.transpose(2, 0, 1)[None]
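    # Normalize with the ImageNet channel statistics the feature CNN was trained on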
    mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
    img = (img.astype(np.float32) / 255.0 - mean) / std
    # Use the CNN to extract features for the image
    # (volatile=True disables autograd bookkeeping for inference)
    img_var = Variable(torch.FloatTensor(img).type(dtype), volatile=True)
    feats_var = cnn(img_var)
    # Tokenize and encode the question
    vocab = load_vocab(args)
    question_tokens = tokenize(args.question,
                               punct_to_keep=[';', ','],
                               punct_to_remove=['?', '.'])
    question_encoded = encode(question_tokens,
                              vocab['question_token_to_idx'],
                              allow_unk=True)
    question_encoded = torch.LongTensor(question_encoded).view(1, -1)
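    # Casting through dtype moves the tensor to the GPU when use_gpu == 1;
    # .long() restores integer indices afterwards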
    question_encoded = question_encoded.type(dtype).long()
    question_var = Variable(question_encoded, volatile=True)
    # Run the model
    print('Running the model\n')
    scores = None
    predicted_program = None
    if type(model) is tuple:
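        # A tuple holds a (program generator, execution engine) pair: sample a
        # program from the question, then execute it on the image features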
        program_generator, execution_engine = model
        program_generator.type(dtype)
        execution_engine.type(dtype)
        predicted_program = program_generator.reinforce_sample(
            question_var,
            temperature=args.temperature,
            argmax=(args.sample_argmax == 1))
        scores = execution_engine(feats_var, predicted_program)
    else:
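        # Otherwise the model maps (question, image features) directly to answer scores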
        model.type(dtype)
        scores = model(question_var, feats_var)
    # Print results
    _, predicted_answer_idx = scores.data.cpu()[0].max(dim=0)
    predicted_answer = vocab['answer_idx_to_token'][predicted_answer_idx[0]]
    print('Question: "%s"' % args.question)
    print('Predicted answer: ', predicted_answer)
    if predicted_program is not None:
        print()
        print('Predicted program:')
        program = predicted_program.data.cpu()[0]
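        # Walk the predicted function sequence, tracking how many inputs are still
        # needed; stop printing once the program is complete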
        num_inputs = 1
        for fn_idx in program:
            fn_str = vocab['program_idx_to_token'][fn_idx]
            num_inputs += iep.programs.get_num_inputs(fn_str) - 1
            print(fn_str)
            if num_inputs == 0:
                break
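
# Illustrative invocation (flag names are assumed to match the argparse attributes
# used above and are defined elsewhere in this script):
#   python scripts/run_model.py --image path/to/image.png \
#       --question "What color is the large sphere?" --use_gpu 1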