def make_env()

in mae_envs/envs/box_locking.py [0:0]


def make_env(n_substeps=15, horizon=80, deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             n_agents=1, fixed_agent_spawn=False,
             lock_box=True, grab_box=True, grab_selective=False,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, grab_exclusive=False, grab_out_of_vision=False,
             lock_out_of_vision=True,
             box_floor_friction=0.2, other_friction=0.01, gravity=[0, 0, -50],
             action_lims=(-0.9, 0.9), polar_obs=True,
             scenario='quadrant', p_door_dropout=0.0,
             n_rooms=4, random_room_number=True,
             n_lidar_per_agent=0, visualize_lidar=False, compress_lidar_scale=None,
             n_boxes=2, box_size=0.5, box_only_z_rot=False,
             boxid_obs=True, boxsize_obs=True, pad_ramp_size=True, additional_obs={},
             # lock-box task
             task_type='all', lock_reward=5.0, unlock_penalty=7.0, shaped_reward_scale=0.25,
             return_threshold=0.1,
             # ramps
             n_ramps=0):
    """Build the box-locking environment and return it fully wrapped.

    The function creates a ``Base`` environment, adds world modules (walls,
    agents, boxes, ramps, lidar sites, manipulation, floor and world
    constants), resets it once, and then stacks observation/action/reward
    wrappers on top. The wrapper order below is significant and must not be
    rearranged.

    Args:
        n_substeps, horizon, deterministic_mode, floor_size, grid_size,
            action_lims: Forwarded unchanged to ``Base``. ``grid_size`` is
            also forwarded to the wall module.
        door_size: Forwarded to the wall module. The ``'empty'`` scenario
            ignores it and always passes ``door_size=2``.
        n_agents: Number of agents. All agents receive team index 0
            (``TeamMembership`` is given an all-zeros array).
        fixed_agent_spawn: If True, agents use ``center_placement`` instead of
            the scenario-specific placement function.
        lock_box: If True (and there are boxes), boxes are wrapped with
            ``LockObjWrapper``.
        grab_box: If True (and there are boxes), ``AgentManipulation`` and
            ``GrabObjWrapper`` are added. Also gates ``MaskActionWrapper`` and
            ``GrabClosestWrapper``.
        grab_selective: If False (and ``grab_box``), ``GrabClosestWrapper`` is
            applied.
        lock_type: Forwarded to every ``LockObjWrapper``.
        lock_grab_radius: Divided by ``box_size`` to obtain the
            ``radius_multiplier`` used by both the lock and grab wrappers.
        grab_exclusive: Forwarded to ``GrabObjWrapper``.
        grab_out_of_vision: If False (and ``grab_box``), ``'action_pull'`` is
            masked with the box (and ramp) visibility masks.
        lock_out_of_vision: If False, ``LockObjWrapper`` receives the matching
            visibility mask key as ``agent_allowed_to_lock_keys``; otherwise
            ``None``.
        box_floor_friction: Friction passed to ``Boxes`` and, when not None,
            to ``FloorAttributes``.
        other_friction: Friction passed to ``Agents``, ``Ramps`` and the
            ``'quadrant'`` wall module.
        gravity: Forwarded to ``WorldConstants`` (as a shallow copy).
        polar_obs: Forwarded to ``Agents``, ``Boxes`` and ``Ramps``.
        scenario: ``'randomwalls'``, ``'quadrant'``, ``'empty'``, or any
            string containing ``'var_tri'``. For ``'var_tri'`` variants, the
            substring ``'uniform'`` selects ``uniform_placement`` for boxes
            and (unless ``fixed_agent_spawn``) agents instead of
            ``rotate_tri_placement``.
        p_door_dropout: Forwarded to ``WallScenarios`` in the ``'quadrant'``
            scenario only.
        n_rooms, random_room_number: Forwarded to ``RandomWalls`` in the
            ``'randomwalls'`` scenario only.
        n_lidar_per_agent: If > 0, the ``Lidar`` wrapper is applied and
            ``'lidar'`` is exposed as an observation.
        visualize_lidar: Forwarded to ``Lidar``; ``LidarSites`` is only added
            when this is True and ``n_lidar_per_agent > 0``.
        compress_lidar_scale: Forwarded to ``Lidar``.
        n_boxes: Number of boxes. NOTE(review): ``np.max`` is applied to it,
            but it is also passed to ``range()``, so this function appears to
            require a plain int here -- confirm before passing a sequence.
        box_size: Only used to compute the radius multipliers; it is not
            forwarded to ``Boxes`` in this function.
        box_only_z_rot, boxid_obs, boxsize_obs: Forwarded to ``Boxes``.
        pad_ramp_size: Forwarded to ``Ramps``.
        additional_obs: Mapping forwarded (as a shallow copy) to
            ``AddConstantObservationsWrapper``. All of its keys are exposed as
            external observations; keys containing ``'mask'`` are also
            registered as external mask keys.
        task_type, lock_reward, unlock_penalty, shaped_reward_scale,
            return_threshold: Forwarded to ``LockObjectsTask``.
        n_ramps: Number of ramps; if > 0, ramps are added and made lockable
            (and grabbable when box grabbing is enabled).

    Returns:
        The environment wrapped in the full wrapper stack; the outermost
        wrapper is ``SelectKeysWrapper``.

    Raises:
        ValueError: If ``scenario`` is not one of the supported values. This
            is raised before any environment object is constructed.
    """
    import copy

    grab_radius_multiplier = lock_grab_radius / box_size
    lock_radius_multiplier = lock_grab_radius / box_size

    # Fail fast: reject an unsupported scenario before constructing anything,
    # instead of after the Base environment has already been built.
    if (scenario not in ('randomwalls', 'quadrant', 'empty')
            and 'var_tri' not in scenario):
        raise ValueError(f"Scenario {scenario} not supported.")

    # `gravity` and `additional_obs` have mutable default values that are
    # shared between calls. Pass shallow copies downstream so that a module or
    # wrapper that stores or mutates its argument cannot leak state into later
    # make_env() calls (or into the caller's own objects).
    gravity = copy.copy(gravity)
    additional_obs = copy.copy(additional_obs)

    env = Base(n_agents=n_agents, n_substeps=n_substeps,
               floor_size=floor_size,
               horizon=horizon, action_lims=action_lims, deterministic_mode=deterministic_mode,
               grid_size=grid_size)

    # Walls and the placement functions that go with each scenario.
    if scenario == 'randomwalls':
        env.add_module(RandomWalls(grid_size=grid_size, num_rooms=n_rooms,
                                   random_room_number=random_room_number,
                                   min_room_size=6, door_size=door_size,
                                   gen_door_obs=False))
        box_placement_fn = uniform_placement
        ramp_placement_fn = uniform_placement
        agent_placement_fn = uniform_placement if not fixed_agent_spawn else center_placement
    elif scenario == 'quadrant':
        env.add_module(WallScenarios(grid_size=grid_size, door_size=door_size,
                                     scenario=scenario, friction=other_friction,
                                     p_door_dropout=p_door_dropout))
        box_placement_fn = uniform_placement
        ramp_placement_fn = uniform_placement
        agent_placement_fn = quadrant_placement if not fixed_agent_spawn else center_placement
    elif scenario == 'empty':
        env.add_module(WallScenarios(grid_size=grid_size, door_size=2, scenario='empty'))
        box_placement_fn = uniform_placement
        ramp_placement_fn = uniform_placement
        agent_placement_fn = center_placement
    else:
        # A 'var_tri' variant -- guaranteed by the scenario check above.
        env.add_module(WallScenarios(grid_size=grid_size, door_size=door_size, scenario='var_tri'))
        # One placement function per ramp, cycling through three tri_placement indices.
        ramp_placement_fn = [tri_placement(i % 3) for i in range(n_ramps)]
        agent_placement_fn = center_placement if fixed_agent_spawn else \
            (uniform_placement if 'uniform' in scenario else rotate_tri_placement)
        box_placement_fn = uniform_placement if 'uniform' in scenario else rotate_tri_placement

    env.add_module(Agents(n_agents,
                          placement_fn=agent_placement_fn,
                          color=[np.array((66., 235., 244., 255.)) / 255] * n_agents,
                          friction=other_friction,
                          polar_obs=polar_obs))

    # Computed once; every box-dependent module and wrapper below keys off it.
    has_boxes = np.max(n_boxes) > 0

    if has_boxes:
        env.add_module(Boxes(n_boxes=n_boxes, placement_fn=box_placement_fn,
                             friction=box_floor_friction, polar_obs=polar_obs,
                             n_elongated_boxes=0,
                             boxid_obs=boxid_obs,
                             box_only_z_rot=box_only_z_rot,
                             boxsize_obs=boxsize_obs))

    if n_ramps > 0:
        env.add_module(Ramps(n_ramps=n_ramps, placement_fn=ramp_placement_fn,
                             friction=other_friction, polar_obs=polar_obs,
                             pad_ramp_size=pad_ramp_size))

    if n_lidar_per_agent > 0 and visualize_lidar:
        env.add_module(LidarSites(n_agents=n_agents, n_lidar_per_agent=n_lidar_per_agent))

    if has_boxes and grab_box:
        env.add_module(AgentManipulation())
    if box_floor_friction is not None:
        env.add_module(FloorAttributes(friction=box_floor_friction))
    env.add_module(WorldConstants(gravity=gravity))
    env.reset()

    # Observation key bookkeeping, extended below as wrappers are added and
    # consumed by SplitObservations / SelectKeysWrapper at the end.
    keys_self = ['agent_qpos_qvel', 'hider', 'prep_obs']
    keys_mask_self = ['mask_aa_obs']
    keys_external = ['agent_qpos_qvel']
    keys_copy = ['you_lock', 'team_lock']
    keys_mask_external = []

    env = SplitMultiAgentActions(env)
    env = TeamMembership(env, np.zeros((n_agents,)))
    env = AgentAgentObsMask2D(env)
    env = DiscretizeActionWrapper(env, 'action_movement')
    env = NumpyArrayRewardWrapper(env)

    if has_boxes:
        env = AgentGeomObsMask2D(env, pos_obs_key='box_pos', mask_obs_key='mask_ab_obs',
                                 geom_idxs_obs_key='box_geom_idxs')
        keys_external += ['mask_ab_obs', 'box_obs']
        keys_mask_external.append('mask_ab_obs')
    if lock_box and has_boxes:
        env = LockObjWrapper(env, body_names=[f'moveable_box{i}' for i in range(n_boxes)],
                             agent_idx_allowed_to_lock=np.arange(n_agents),
                             lock_type=lock_type,
                             radius_multiplier=lock_radius_multiplier,
                             obj_in_game_metadata_keys=["curr_n_boxes"],
                             agent_allowed_to_lock_keys=None if lock_out_of_vision else ["mask_ab_obs"])

    if n_ramps > 0:
        env = AgentGeomObsMask2D(env, pos_obs_key='ramp_pos', mask_obs_key='mask_ar_obs',
                                 geom_idxs_obs_key='ramp_geom_idxs')
        env = LockObjWrapper(env, body_names=[f"ramp{i}:ramp" for i in range(n_ramps)],
                             agent_idx_allowed_to_lock=np.arange(n_agents),
                             lock_type=lock_type, ac_obs_prefix='ramp_',
                             radius_multiplier=lock_radius_multiplier,
                             agent_allowed_to_lock_keys=None if lock_out_of_vision else ["mask_ar_obs"])

        keys_external += ['ramp_obs']
        keys_mask_external += ['mask_ar_obs']
        keys_copy += ['ramp_you_lock', 'ramp_team_lock']

    if grab_box and has_boxes:
        # Boxes first, then ramps; the metadata keys follow the same order.
        body_names = ([f'moveable_box{i}' for i in range(n_boxes)] +
                      [f"ramp{i}:ramp" for i in range(n_ramps)])
        obj_in_game_meta_keys = ['curr_n_boxes'] + (['curr_n_ramps'] if n_ramps > 0 else [])
        env = GrabObjWrapper(env,
                             body_names=body_names,
                             radius_multiplier=grab_radius_multiplier,
                             grab_exclusive=grab_exclusive,
                             obj_in_game_metadata_keys=obj_in_game_meta_keys)

    if n_lidar_per_agent > 0:
        env = Lidar(env, n_lidar_per_agent=n_lidar_per_agent, visualize_lidar=visualize_lidar,
                    compress_lidar_scale=compress_lidar_scale)
        keys_copy += ['lidar']
        keys_external += ['lidar']

    env = AddConstantObservationsWrapper(env, new_obs=additional_obs)
    keys_external += list(additional_obs)
    keys_mask_external += [ob for ob in additional_obs if 'mask' in ob]

    #############################################
    # lock Box Task Reward
    ###
    env = LockObjectsTask(env, n_objs=n_boxes, task=task_type, fixed_order=True,
                          obj_lock_obs_key='obj_lock', obj_pos_obs_key='box_pos',
                          act_lock_key='action_glue', agent_pos_key='agent_pos',
                          lock_reward=lock_reward, unlock_penalty=unlock_penalty,
                          shaped_reward_scale=shaped_reward_scale,
                          return_threshold=return_threshold)
    ###
    #############################################

    env = SplitObservations(env, keys_self + keys_mask_self, keys_copy=keys_copy)
    env = SpoofEntityWrapper(env, n_boxes,
                             ['box_obs', 'you_lock', 'team_lock', 'obj_lock'],
                             ['mask_ab_obs'])
    keys_mask_external += ['mask_ab_obs_spoof']

    if n_agents < 2:
        env = SpoofEntityWrapper(env, 1, ['agent_qpos_qvel', 'hider', 'prep_obs'], ['mask_aa_obs'])

    env = LockAllWrapper(env, remove_object_specific_lock=True)
    if not grab_out_of_vision and grab_box:
        # Can only pull if in vision
        mask_keys = ['mask_ab_obs'] + (['mask_ar_obs'] if n_ramps > 0 else [])
        env = MaskActionWrapper(env, action_key='action_pull', mask_keys=mask_keys)
    if not grab_selective and grab_box:
        env = GrabClosestWrapper(env)
    env = DiscardMujocoExceptionEpisodes(env)
    env = ConcatenateObsWrapper(env, {'agent_qpos_qvel': ['agent_qpos_qvel', 'hider', 'prep_obs'],
                                      'box_obs': ['box_obs', 'you_lock', 'team_lock', 'obj_lock'],
                                      'ramp_obs': ['ramp_obs', 'ramp_you_lock', 'ramp_team_lock',
                                                   'ramp_obj_lock']})
    env = SelectKeysWrapper(env, keys_self=keys_self,
                            keys_other=keys_external + keys_mask_self + keys_mask_external)

    return env