# run_inverse_dynamics_model.py
# NOTE: this is _not_ the original code of IDM!
# As such, while it is close and seems to function well,
# its performance might be a bit off from what is reported
# in the paper.
from argparse import ArgumentParser
import pickle
import cv2
import numpy as np
import json
import torch as th
from agent import ENV_KWARGS
from inverse_dynamics_model import IDMAgent
KEYBOARD_BUTTON_MAPPING = {
    "key.keyboard.escape": "ESC",
    "key.keyboard.s": "back",
    "key.keyboard.q": "drop",
    "key.keyboard.w": "forward",
    "key.keyboard.1": "hotbar.1",
    "key.keyboard.2": "hotbar.2",
    "key.keyboard.3": "hotbar.3",
    "key.keyboard.4": "hotbar.4",
    "key.keyboard.5": "hotbar.5",
    "key.keyboard.6": "hotbar.6",
    "key.keyboard.7": "hotbar.7",
    "key.keyboard.8": "hotbar.8",
    "key.keyboard.9": "hotbar.9",
    "key.keyboard.e": "inventory",
    "key.keyboard.space": "jump",
    "key.keyboard.a": "left",
    "key.keyboard.d": "right",
    "key.keyboard.left.shift": "sneak",
    "key.keyboard.left.control": "sprint",
    "key.keyboard.f": "swapHands",
}
# Template action
NOOP_ACTION = {
"ESC": 0,
"back": 0,
"drop": 0,
"forward": 0,
"hotbar.1": 0,
"hotbar.2": 0,
"hotbar.3": 0,
"hotbar.4": 0,
"hotbar.5": 0,
"hotbar.6": 0,
"hotbar.7": 0,
"hotbar.8": 0,
"hotbar.9": 0,
"inventory": 0,
"jump": 0,
"left": 0,
"right": 0,
"sneak": 0,
"sprint": 0,
"swapHands": 0,
"camera": np.array([0, 0]),
"attack": 0,
"use": 0,
"pickItem": 0,
}
MESSAGE = """
This script will take a video, predict actions for its frames and
and show them with a cv2 window.
Press any button the window to proceed to the next frame.
"""
# Matches a number in the MineRL Java code regarding sensitivity
# This is for mapping from recorded sensitivity to the one used in the model
CAMERA_SCALER = 360.0 / 2400.0
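# Example: a recorded mouse delta of dx = 100 maps to
# 100 * CAMERA_SCALER = 15 degrees of turning in the env action.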
def json_action_to_env_action(json_action):
"""
Converts a json action into a MineRL action.
Returns (minerl_action, is_null_action)
"""
# This might be slow...
env_action = NOOP_ACTION.copy()
    # As a safeguard, recreate the camera array so we do not mutate the
    # shared one in NOOP_ACTION (floats, so deltas are not truncated)
    env_action["camera"] = np.array([0.0, 0.0])
is_null_action = True
keyboard_keys = json_action["keyboard"]["keys"]
for key in keyboard_keys:
# You can have keys that we do not use, so just skip them
# NOTE in original training code, ESC was removed and replaced with
# "inventory" action if GUI was open.
# Not doing it here, as BASALT uses ESC to quit the game.
if key in KEYBOARD_BUTTON_MAPPING:
env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1
is_null_action = False
mouse = json_action["mouse"]
camera_action = env_action["camera"]
camera_action[0] = mouse["dy"] * CAMERA_SCALER
camera_action[1] = mouse["dx"] * CAMERA_SCALER
if mouse["dx"] != 0 or mouse["dy"] != 0:
is_null_action = False
else:
if abs(camera_action[0]) > 180:
camera_action[0] = 0
if abs(camera_action[1]) > 180:
camera_action[1] = 0
mouse_buttons = mouse["buttons"]
if 0 in mouse_buttons:
env_action["attack"] = 1
is_null_action = False
if 1 in mouse_buttons:
env_action["use"] = 1
is_null_action = False
if 2 in mouse_buttons:
env_action["pickItem"] = 1
is_null_action = False
return env_action, is_null_action
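# Minimal usage sketch (the dict below is an illustrative, hand-written stand-in
# for one parsed .jsonl line; real recordings carry more fields, but only
# "keyboard.keys" and the mouse "dx"/"dy"/"buttons" entries are read here):
#
#   sample = {
#       "keyboard": {"keys": ["key.keyboard.w", "key.keyboard.space"]},
#       "mouse": {"dx": 10.0, "dy": -4.0, "buttons": [0]},
#   }
#   action, is_null = json_action_to_env_action(sample)
#   # action["forward"] == 1, action["jump"] == 1, action["attack"] == 1,
#   # action["camera"] == [-4.0 * CAMERA_SCALER, 10.0 * CAMERA_SCALER],
#   # is_null is False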
def main(model, weights, video_path, json_path, n_batches, n_frames):
print(MESSAGE)
    with open(model, "rb") as f:
        agent_parameters = pickle.load(f)
net_kwargs = agent_parameters["model"]["args"]["net"]["args"]
pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
agent = IDMAgent(idm_net_kwargs=net_kwargs, pi_head_kwargs=pi_head_kwargs)
agent.load_weights(weights)
required_resolution = ENV_KWARGS["resolution"]
cap = cv2.VideoCapture(video_path)
json_index = 0
    with open(json_path) as json_file:
        # .jsonl: one JSON object per line, so parse the file line by line
        json_data = [json.loads(line) for line in json_file if line.strip()]
for _ in range(n_batches):
th.cuda.empty_cache()
print("=== Loading up frames ===")
frames = []
recorded_actions = []
        for _ in range(n_frames):
            # Stop at the end of the video or of the action recording
            ret, frame = cap.read()
            if not ret or json_index >= len(json_data):
                break
            assert frame.shape[0] == required_resolution[1] and frame.shape[1] == required_resolution[0], \
                "Video must be of resolution {}".format(required_resolution)
            # BGR -> RGB
            frames.append(frame[..., ::-1])
            env_action, _ = json_action_to_env_action(json_data[json_index])
            recorded_actions.append(env_action)
            json_index += 1
        if not frames:
            break
        frames = np.stack(frames)
print("=== Predicting actions ===")
predicted_actions = agent.predict_actions(frames)
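        # predict_actions is expected to return a dict mapping each action name
        # to an array indexed as [batch, frame]; the batch size is 1 here,
        # hence the action_array[0, i] indexing below.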
        # Only iterate over frames we actually read (the last batch may be
        # shorter than n_frames if the video or recording ran out)
        for i in range(len(frames)):
frame = frames[i]
recorded_action = recorded_actions[i]
cv2.putText(
frame,
f"name: prediction (true)",
(10, 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.4,
(255, 255, 255),
1
)
for y, (action_name, action_array) in enumerate(predicted_actions.items()):
current_prediction = action_array[0, i]
cv2.putText(
frame,
f"{action_name}: {current_prediction} ({recorded_action[action_name]})",
(10, 25 + y * 12),
cv2.FONT_HERSHEY_SIMPLEX,
0.35,
(255, 255, 255),
1
)
# RGB -> BGR again...
cv2.imshow("MineRL IDM model predictions", frame[..., ::-1])
cv2.waitKey(0)
cv2.destroyAllWindows()
if __name__ == "__main__":
parser = ArgumentParser("Run IDM on MineRL recordings.")
parser.add_argument("--weights", type=str, required=True, help="Path to the '.weights' file to be loaded.")
parser.add_argument("--model", type=str, required=True, help="Path to the '.model' file to be loaded.")
parser.add_argument("--video-path", type=str, required=True, help="Path to a .mp4 file (Minecraft recording).")
parser.add_argument("--jsonl-path", type=str, required=True, help="Path to a .jsonl file (Minecraft recording).")
parser.add_argument("--n-frames", type=int, default=128, help="Number of frames to process at a time.")
parser.add_argument("--n-batches", type=int, default=10, help="Number of batches (n-frames) to process for visualization.")
args = parser.parse_args()
main(args.model, args.weights, args.video_path, args.jsonl_path, args.n_batches, args.n_frames)
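
# Example invocation (file names are illustrative; substitute whichever IDM
# .model/.weights pair and contractor recording you actually have):
#
#   python run_inverse_dynamics_model.py \
#       --model 4x_idm.model --weights 4x_idm.weights \
#       --video-path recording.mp4 --jsonl-path recording.jsonl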