in lib/utils/vis.py [0:0]
def save_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
                                 file_name, nrow=8, padding=2):
    '''Tile a batch of images into one grid, draw visible joints, save it.

    batch_image: [batch_size, channel, height, width]
    batch_joints: [batch_size, num_joints, 3],
    batch_joints_vis: [batch_size, num_joints, 1],
    file_name: output path passed to ``cv2.imwrite``.
    nrow: number of images per grid row (forwarded to ``make_grid``).
    padding: pixel padding between tiles (forwarded to ``make_grid``).

    Joint coordinates are per-image pixel positions: column 0 is x and
    column 1 is y. They are shifted into the tiled grid only for drawing.
    ``batch_joints`` itself is not modified. Only joints whose
    ``batch_joints_vis[k][j][0]`` is truthy are drawn.
    '''
    # normalize=True makes make_grid rescale the grid to [0, 1] before the
    # conversion to 0..255 bytes below.
    grid = torchvision.utils.make_grid(
        batch_image, nrow=nrow, padding=padding, normalize=True)
    # CHW float -> HWC uint8 numpy array for OpenCV.
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    # The permuted view is not C-contiguous. Copy it so cv2 drawing calls
    # get a contiguous, writable buffer.
    ndarr = ndarr.copy()

    nmaps = batch_image.size(0)
    xmaps = min(nrow, nmaps)
    # Tile pitch matches make_grid: image size plus one padding strip.
    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)

    for k in range(nmaps):
        # Same row-major tile order that make_grid uses.
        y, x = divmod(k, xmaps)
        x_offset = x * width + padding
        y_offset = y * height + padding
        for joint, joint_vis in zip(batch_joints[k], batch_joints_vis[k]):
            if not joint_vis[0]:
                continue
            # Offset into the grid locally instead of writing back into
            # batch_joints (the previous version mutated the caller's data).
            center = (int(x_offset + joint[0]), int(y_offset + joint[1]))
            # NOTE(review): ndarr keeps batch_image's channel order, but
            # cv2.imwrite assumes BGR. If the input is RGB, the saved colors
            # (including this marker) are channel-swapped. Confirm against
            # the data pipeline.
            cv2.circle(ndarr, center, 2, [255, 0, 0], 2)

    cv2.imwrite(file_name, ndarr)