# pretrain_reconstruction.py
#
# Trains the reconstruction decoder and pose encoder on oracle trajectories,
# with a frozen feature network providing cluster-similarity targets.
import os
import sys
import time
import logging

import h5py
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch.utils.tensorboard import SummaryWriter  # or tensorboardX, depending on the repo

# Project-local helpers referenced below -- env constructors
# (make_vec_envs_habitat, make_vec_envs_avd), models
# (FeatureReconstructionModule, FeatureNetwork, PoseEncoder), training
# utilities (SupervisedReconstruction, RolloutStorageReconstruction,
# rec_loss_fn_classify, evaluate_reconstruction_oracle), and tensor helpers
# (process_image, process_odometer, flatten_two, unflatten_two) -- are assumed
# to be imported from this repository; `args` and `num_updates` are likewise
# assumed to be defined at module scope.
def main():
torch.set_num_threads(1)
device = torch.device("cuda:0" if args.cuda else "cpu")
ndevices = torch.cuda.device_count()
# Setup loggers
tbwriter = SummaryWriter(log_dir=args.log_dir)
logging.basicConfig(filename=f"{args.log_dir}/train_log.txt", level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.getLogger().setLevel(logging.INFO)
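    # Build the vectorized training environments. Habitat environments are
    # sharded across the GPUs listed in CUDA_VISIBLE_DEVICES, while AVD
    # environments are vectorized envs wrapped with episode monitors.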
if "habitat" in args.env_name:
        # Devices need to be re-indexed between 0 and N-1 for Habitat
        num_devices = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
        devices = list(range(num_devices))
envs = make_vec_envs_habitat(
args.habitat_config_file, device, devices, seed=args.seed
)
else:
train_log_dir = os.path.join(args.log_dir, "train_monitor")
        os.makedirs(train_log_dir, exist_ok=True)
envs = make_vec_envs_avd(
args.env_name,
args.seed,
args.num_processes,
train_log_dir,
device,
True,
num_frame_stack=1,
split="train",
nRef=args.num_pose_refs,
ref_dist_thresh=args.ref_dist_thresh,
)
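    # The observation dict returned by these envs is assumed to contain:
    #   "im"            - current RGB frame
    #   "delta"         - odometer reading since the previous step
    #   "pose_regress"  - ground-truth poses of the reference views
    #   "pose_refs"     - reference images to reconstruct
    #   "valid_masks"   - 1 for valid reference views, 0 for padding
    #   "oracle_action" - action from the oracle exploration policy
    # (keys inferred from their usage below)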
args.feat_shape_sim = (512,)
args.obs_shape = envs.observation_space.spaces["im"].shape
# =================== Load concept clusters =================
clusters_h5 = h5py.File(args.clusters_path, "r")
cluster_centroids = torch.Tensor(np.array(clusters_h5["cluster_centroids"])).to(
device
)
args.nclusters = cluster_centroids.shape[0]
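    # Each cluster stores example images alongside its centroid. The images are
    # only used for visualizing reconstructions during evaluation (see
    # clusters2images in eval_config below); the centroids define the
    # nclusters-dimensional score space the decoder operates in.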
clusters2images = {}
    for i in range(args.nclusters):
        cluster_images = np.array(
            clusters_h5[f"cluster_{i}/images"]
        )  # (K, C, H, W), float values in [0, 1]
        cluster_images = rearrange(cluster_images, "k c h w -> k h w c")
        cluster_images = (cluster_images * 255.0).astype(np.uint8)
        clusters2images[i] = cluster_images  # (K, H, W, C)
clusters_h5.close()
# =================== Create models ====================
decoder = FeatureReconstructionModule(
args.nclusters, args.nclusters, nlayers=args.n_transformer_layers,
)
feature_network = FeatureNetwork()
feature_network = nn.DataParallel(feature_network, dim=0)
pose_encoder = PoseEncoder()
if args.use_multi_gpu:
decoder = nn.DataParallel(decoder, dim=1)
pose_encoder = nn.DataParallel(pose_encoder, dim=0)
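    # NOTE: the decoder is parallelized along dim=1, which suggests it consumes
    # sequence-first (T, N, ...) inputs; the batch-first feature network and
    # pose encoder are split along dim=0 as usual. This reading is inferred
    # from the DataParallel arguments, not from the model definitions.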
# =================== Load models ====================
save_path = os.path.join(args.save_dir, "checkpoints")
checkpoint_path = os.path.join(save_path, "ckpt.latest.pth")
if os.path.isfile(checkpoint_path):
logging.info("Resuming from old model!")
decoder_state, pose_encoder_state, j_start = torch.load(checkpoint_path)
decoder.load_state_dict(decoder_state)
pose_encoder.load_state_dict(pose_encoder_state)
else:
j_start = -1
decoder.to(device)
pose_encoder.to(device)
feature_network.to(device)
decoder.eval()
pose_encoder.eval()
feature_network.eval() # Feature network is frozen
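    # Only the decoder and pose encoder are optimized; the feature network
    # stays in eval mode throughout and is only called under torch.no_grad().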
# =================== Define decoder training algorithm ====================
algo_config = {}
algo_config["lr"] = args.lr
algo_config["eps"] = args.eps
algo_config["rec_loss_fn"] = rec_loss_fn_classify
algo_config["rec_loss_fn_J"] = args.rec_loss_fn_J
algo_config["max_grad_norm"] = args.max_grad_norm
algo_config["cluster_centroids"] = cluster_centroids
algo_config["prediction_interval"] = 20 if "avd" in args.env_name else 100
algo_config["decoder"] = decoder
algo_config["pose_encoder"] = pose_encoder
reconstruction_algo = SupervisedReconstruction(algo_config)
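    # reconstruction_algo.update(rollouts) is expected to return a dict of
    # scalar losses (see the logging loop below), optimizing the decoder and
    # pose encoder against the cluster-score targets via rec_loss_fn_classify.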
# =================== Define rollouts ====================
odometer_shape = (4,)
rollouts = RolloutStorageReconstruction(
args.num_rl_steps,
args.num_processes,
(args.nclusters,),
odometer_shape,
args.num_pose_refs,
)
rollouts.to(device)
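    # The rollout storage holds, per step, the nclusters-dim observation scores
    # and the 4-dim odometer reading, plus the per-episode target poses, target
    # features, and validity masks copied in below. Shapes are inferred from
    # the copy_ calls that follow.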
def get_obs(obs):
obs_im = process_image(obs["im"])
return obs_im
start = time.time()
NPROC = args.num_processes
NREF = args.num_pose_refs
for j in range(j_start + 1, num_updates):
# =================== Start a new episode ====================
obs = envs.reset()
# Processing environment inputs
obs_im = get_obs(obs) # (num_processes, 3, 84, 84)
obs_odometer = process_odometer(obs["delta"]) # (num_processes, 4)
# Convert mm to m for AVD
if "avd" in args.env_name:
obs_odometer[:, :2] /= 1000.0
# ============== Target poses and corresponding images ================
# NOTE - these are constant throughout the episode.
# (num_processes * num_pose_refs, 3) --- (y, x, t)
tgt_poses = process_odometer(flatten_two(obs["pose_regress"]))[:, :3]
tgt_poses = unflatten_two(tgt_poses, NPROC, NREF) # (N, nRef, 3)
tgt_masks = obs["valid_masks"].unsqueeze(2) # (N, nRef, 1)
# Convert mm to m for AVD
if "avd" in args.env_name:
tgt_poses[:, :, :2] /= 1000.0
tgt_ims = process_image(flatten_two(obs["pose_refs"])) # (N*nRef, C, H, W)
# Initialize the memory of rollouts
rollouts.reset()
with torch.no_grad():
obs_feat = feature_network(obs_im) # (N, 2048)
tgt_feat = feature_network(tgt_ims) # (N*nRef, 2048)
# Compute similarity scores with all other clusters
obs_feat = torch.matmul(obs_feat, cluster_centroids.t()) # (N, nclusters)
tgt_feat = torch.matmul(
tgt_feat, cluster_centroids.t()
) # (N*nRef, nclusters)
tgt_feat = unflatten_two(tgt_feat, NPROC, NREF) # (N, nRef, nclusters)
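            # Projecting the frozen 2048-d features onto the cluster centroids
            # turns each view into an nclusters-dim similarity vector; these
            # vectors serve both as decoder inputs (obs_feat) and as the
            # reconstruction targets (tgt_feat).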
rollouts.obs_feats[0].copy_(obs_feat)
rollouts.obs_odometer[0].copy_(obs_odometer)
rollouts.tgt_poses.copy_(tgt_poses)
rollouts.tgt_feats.copy_(tgt_feat)
rollouts.tgt_masks.copy_(tgt_masks)
# =============== Update over a full batch of episodes ================
# num_steps must be total number of steps in each episode
for step in range(args.num_steps):
pstep = rollouts.step
action = obs["oracle_action"]
# Act, get reward and next obs
obs, reward, done, infos = envs.step(action)
# Processing environment inputs
obs_im = get_obs(obs) # (num_processes, 3, 84, 84)
obs_odometer = process_odometer(obs["delta"]) # (num_processes, 4)
if "avd" in args.env_name:
obs_odometer[:, :2] /= 1000.0
with torch.no_grad():
obs_feat = feature_network(obs_im)
# Compute similarity scores with all other clusters
obs_feat = torch.matmul(
obs_feat, cluster_centroids.t()
) # (N, nclusters)
# Always set masks to 1 (since this loop happens within one episode)
            masks = torch.ones(NPROC, 1, device=device)
# Accumulate odometer readings to give relative pose
# from the starting point
obs_odometer = rollouts.obs_odometer[pstep] * masks + obs_odometer
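            # (masks would zero the accumulated pose at an episode boundary;
            # within this single-episode loop they are always 1, so the sum
            # keeps a running pose relative to the episode start.)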
# Update rollouts
rollouts.insert(obs_feat, obs_odometer)
if (step + 1) % args.num_rl_steps == 0:
decoder.train()
pose_encoder.train()
# Update decoder
losses = reconstruction_algo.update(rollouts)
# Refresh rollouts
rollouts.after_update()
decoder.eval()
pose_encoder.eval()
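                # A supervised update fires every num_rl_steps environment
                # steps, so num_steps is assumed to be a multiple of
                # num_rl_steps; after_update() presumably carries the final
                # observation over as the start of the next rollout window.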
# =================== Save model ====================
if (j + 1) % args.save_interval == 0 and args.save_dir != "":
save_path = os.path.join(args.save_dir, "checkpoints")
            os.makedirs(save_path, exist_ok=True)
decoder_state = decoder.state_dict()
pose_encoder_state = pose_encoder.state_dict()
torch.save(
[decoder_state, pose_encoder_state, j],
os.path.join(save_path, "ckpt.latest.pth"),
)
if args.save_unique:
torch.save(
[decoder_state, pose_encoder_state, j],
                    os.path.join(save_path, f"ckpt.{(j+1):07d}.pth"),
)
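            # Checkpoints are plain lists [decoder_state, pose_encoder_state, j],
            # matching the unpacking used when resuming at the top of main().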
# =================== Logging data ====================
total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
if j % args.log_interval == 0:
end = time.time()
fps = int(total_num_steps / (end - start))
logging.info(f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
train_metrics = losses
            for k, v in train_metrics.items():
                logging.info(f"{k}: {v:.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", v, j)
# =================== Evaluate models ====================
if args.eval_interval is not None and (j + 1) % args.eval_interval == 0:
if "habitat" in args.env_name:
                # Devices need to be re-indexed between 0 and N-1 for Habitat
                num_devices = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
                devices = list(range(num_devices))
eval_envs = make_vec_envs_habitat(
args.eval_habitat_config_file, device, devices
)
            else:
                # eval_log_dir is not defined in this excerpt; it is assumed
                # here to mirror the train monitor directory.
                eval_log_dir = os.path.join(args.log_dir, "eval_monitor")
                os.makedirs(eval_log_dir, exist_ok=True)
                eval_envs = make_vec_envs_avd(
args.env_name,
args.seed + 12,
12,
eval_log_dir,
device,
True,
split="val",
nRef=NREF,
ref_dist_thresh=args.ref_dist_thresh,
set_return_topdown_map=True,
)
num_eval_episodes = 16 if "habitat" in args.env_name else 30
eval_config = {}
eval_config["num_steps"] = args.num_steps
eval_config["num_processes"] = 1 if "habitat" in args.env_name else 12
eval_config["num_eval_episodes"] = num_eval_episodes
eval_config["num_pose_refs"] = NREF
eval_config["cluster_centroids"] = cluster_centroids
eval_config["clusters2images"] = clusters2images
eval_config["odometer_shape"] = odometer_shape
eval_config["rec_loss_fn"] = rec_loss_fn_classify
eval_config["rec_loss_fn_J"] = args.rec_loss_fn_J
eval_config["vis_save_dir"] = os.path.join(
args.save_dir, "policy_vis", "update_{:05d}".format(j + 1)
)
eval_config["env_name"] = args.env_name
models = {}
models["decoder"] = decoder
models["pose_encoder"] = pose_encoder
models["feature_network"] = feature_network
val_metrics = evaluate_reconstruction_oracle(
models, eval_envs, eval_config, device
)
for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)
tbwriter.close()
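
# Minimal entry point, assuming argument parsing happens at import time in
# this module (args and num_updates are referenced as globals above):
if __name__ == "__main__":
    main()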