import genesis as gs
import numpy as np
from gymnasium import spaces
import random
import torch

joints_name = (
    "joint1",
    "joint2",
    "joint3",
    "joint4",
    "joint5",
    "joint6",
    "joint7",
    "finger_joint1",
    "finger_joint2",
)
AGENT_DIM = len(joints_name)
ENV_DIM = 14
color_dict = {
    "red":   (1.0, 0.0, 0.0, 1.0),
    "green": (0.0, 1.0, 0.0, 1.0),
    "blue":  (0.0, 0.5, 1.0, 1.0),
    "yellow": (1.0, 1.0, 0.0, 1.0),
}

class CubeStack:
    def __init__(self, enable_pixels, observation_height, observation_width, num_envs, env_spacing, camera_capture_mode, strip_environment_state):
        self.enable_pixels = enable_pixels
        self.observation_height = observation_height
        self.observation_width = observation_width
        self.num_envs = num_envs
        self._random = np.random.RandomState()
        self._build_scene(num_envs, env_spacing)
        self.observation_space = self._make_obs_space()
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(AGENT_DIM,), dtype=np.float32)
        self.camera_capture_mode = camera_capture_mode
        self.strip_environment_state = strip_environment_state

    def _build_scene(self, num_envs, env_spacing):
        if not gs._initialized:
            gs.init(backend=gs.gpu, precision="32")

        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(dt=0.01),
            rigid_options=gs.options.RigidOptions(box_box_detection=True),
            show_viewer=False,
        )

        self.plane = self.scene.add_entity(gs.morphs.Plane())

        # === Main task cubes ===
        self.cube_1 = self.scene.add_entity(
            gs.morphs.Box(
                size=(0.04, 0.04, 0.04),
                pos=(0.6, -0.1, 0.02),
            ),
            surface=gs.surfaces.Plastic(color=(1, 0, 0)),
        )

        self.cube_2 = self.scene.add_entity(
            gs.morphs.Box(
                size=(0.04, 0.04, 0.04),
                pos=(0.45, 0.15, 0.02),
            ),
            surface=gs.surfaces.Plastic(color=(0, 1, 0)),
        )

        # === Distractor cubes ===
        self.distractor_cubes = []
        for _ in range(3):  # add 3 distractors (shared across batched envs)
            xy = np.random.uniform(low=[0.3, -0.3], high=[0.7, 0.3])
            cube = self.scene.add_entity(
                gs.morphs.Box(
                    size=(0.04, 0.04, 0.04),
                    pos=(xy[0], xy[1], 0.02),  # dummy, randomized in reset()
                ),
                surface=gs.surfaces.Plastic(color=(0.5, 0.5, 0.5)),  # gray
            )
            self.distractor_cubes.append(cube)

        # === Franka arm ===
        self.franka = self.scene.add_entity(
            gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml"),
            vis_mode="collision",
        )
        
        if self.enable_pixels:
            self.cam = self.scene.add_camera(
                res=(self.observation_width, self.observation_height),
                pos=(3.5, 0.0, 2.5),
                lookat=(0, 0, 0.5),
                fov=30,
                GUI=False
            )

        self.scene.build(n_envs=num_envs, env_spacing=env_spacing)
        self.motors_dof = np.arange(7)
        self.fingers_dof = np.arange(7, 9)
        self.eef = self.franka.get_link("hand")

    def _make_obs_space(self):
        #TODO: see if we should add text obs
        if self.enable_pixels:
            # we explicity remove the need of environment_state
            return spaces.Dict({
                "agent_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(AGENT_DIM,), dtype=np.float32),
                "pixels": spaces.Box(low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8),
            })
        else:
            return spaces.Dict({
                "agent_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(AGENT_DIM,), dtype=np.float32),
                "environment_state": spaces.Box(low=-np.inf, high=np.inf, shape=(ENV_DIM,), dtype=np.float32),
            })
        
    def reset(self):
        B = self.num_envs
        z = 0.02
        quat = torch.tensor([0, 0, 0, 1], dtype=torch.float32, device=gs.device).repeat(B, 1)

        # === Reset cube_1 (to be picked) ===
        x1 = self._random.uniform(0.45, 0.75, size=(B,))
        y1 = self._random.uniform(-0.2, 0.2, size=(B,))
        pos1 = torch.tensor(np.stack([x1, y1, np.full(B, z)], axis=1), dtype=torch.float32, device=gs.device)
        self.cube_1.set_pos(pos1)
        self.cube_1.set_quat(quat)

        # === Reset cube_2 (target) ===
        x2 = self._random.uniform(0.3, 0.7, size=(B,))
        y2 = self._random.uniform(-0.3, 0.3, size=(B,))
        pos2 = torch.tensor(np.stack([x2, y2, np.full(B, z)], axis=1), dtype=torch.float32, device=gs.device)
        self.cube_2.set_pos(pos2)
        self.cube_2.set_quat(quat)

        # === Distractor cubes ===
        if hasattr(self, "distractor_cubes"):
            for cube in self.distractor_cubes:
                xd = self._random.uniform(0.3, 0.7, size=(B,))
                yd = self._random.uniform(-0.3, 0.3, size=(B,))
                pos_d = torch.tensor(np.stack([xd, yd, np.full(B, z)], axis=1), dtype=torch.float32, device=gs.device)
                cube.set_pos(pos_d)
                cube.set_quat(quat)

        # === Reset robot to home pose ===
        qpos = np.array([0.0, -0.4, 0.0, -2.2, 0.0, 2.0, 0.8, 0.04, 0.04])
        qpos_tensor = torch.tensor(qpos, dtype=torch.float32, device=gs.device).repeat(B, 1)
        self.franka.set_qpos(qpos_tensor, zero_velocity=True)
        self.franka.control_dofs_position(qpos_tensor[:, :7], self.motors_dof)
        self.franka.control_dofs_position(qpos_tensor[:, 7:], self.fingers_dof)

        # === Optional control stability tweaks ===
        self.franka.set_dofs_kp(np.array([4500, 4500, 3500, 3500, 2000, 2000, 2000, 100, 100]))
        self.franka.set_dofs_kv(np.array([450, 450, 350, 350, 200, 200, 200, 10, 10]))
        self.franka.set_dofs_force_range(
            np.array([-87] * 7 + [-100, -100]),
            np.array([87] * 7 + [100, 100]),
        )

        self.scene.step()

        if self.enable_pixels:
            self.cam.start_recording()

        return self.get_obs()

        
    def seed(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        self._random = np.random.RandomState(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        self.action_space.seed(seed)

    def step(self, action):
        self.franka.control_dofs_position(action[:, :7], self.motors_dof)
        self.franka.control_dofs_position(action[:, 7:], self.fingers_dof)
        self.scene.step()
        reward = self.compute_reward()
        obs = self.get_obs()
        return None, reward, None, obs

    
    def compute_reward(self):
        pos_1 = self.cube_1.get_pos()  # (B, 3)
        pos_2 = self.cube_2.get_pos()  # (B, 3)

        xy_dist = torch.norm(pos_1[:, :2] - pos_2[:, :2], dim=1)  # (B,)
        z_diff = pos_1[:, 2] - pos_2[:, 2]  # (B,)

        reward = ((xy_dist < 0.05) & (z_diff > 0.03)).float()  # (B,)
        return reward.cpu().numpy()
    
    def get_obs(self):
        eef_pos = self.eef.get_pos()          # (B, 3)
        eef_rot = self.eef.get_quat()         # (B, 4)
        gripper = self.franka.get_dofs_position()[:, 7:9]  # (B, 2)

        cube1_pos = self.cube_1.get_pos()     # (B, 3)
        cube1_rot = self.cube_1.get_quat()    # (B, 4)
        cube2_pos = self.cube_2.get_pos()    # (B, 3)

        diff = eef_pos - cube1_pos                          # (B, 3)
        dist = torch.norm(diff, dim=1, keepdim=True) # (B, 1) (privileged)

        agent_pos = torch.cat([eef_pos, eef_rot, gripper], dim=1).float()  # (B, 9)
        environment_state = torch.cat([cube1_pos, cube1_rot, diff, dist, cube2_pos], dim=1).float()  # (B, 14)

        obs = {
            "agent_pos": agent_pos,
            "environment_state": environment_state,
        }

        if self.enable_pixels:
            #TODO (jadechoghari): it's hacky but keep it for the sake of saving time
            if self.strip_environment_state is True:
                del obs["environment_state"]
            if self.camera_capture_mode == "per_env":
                # Capture a separate image for each environment
                batch_imgs = []
                for i in range(self.num_envs):
                    pos_i = self.scene.envs_offset[i] + np.array([3.5, 0.0, 2.5])
                    lookat_i = self.scene.envs_offset[i] + np.array([0, 0, 0.5])
                    self.cam.set_pose(pos=pos_i, lookat=lookat_i)
                    img = self.cam.render()[0]
                    batch_imgs.append(img)
                pixels = np.stack(batch_imgs, axis=0)  # shape: (B, H, W, 3)
                assert pixels.ndim == 4, f"pixels shape {pixels.shape} is not 4D (B, H, W, 3)"
            elif self.camera_capture_mode == "global":
                # Capture a single global/overview image
                pixels = self.cam.render()[0]  # shape: (H, W, 3)
                assert pixels.ndim == 3, f"pixels shape {pixels.shape} is not 3D (H, W, 3)"
            else:
                raise ValueError(f"Unknown camera_capture_mode: {self.camera_capture_mode}")
            obs["pixels"] = pixels
        return obs
