gym-compete/gym_compete/new_envs/utils.py (108 lines of code) (raw):

import colorsys
import xml.etree.ElementTree as ET

import numpy as np


def list_filter(lambda_fn, iterable):
    """Return the items of ``iterable`` for which ``lambda_fn`` is truthy, as a list."""
    return list(filter(lambda_fn, iterable))


def get_distinct_colors(n=2):
    """Return ``n`` RGB tuples whose hues are evenly spaced.

    Taken from: https://stackoverflow.com/a/876872

    Saturation and value are fixed at 0.5; hue ``x / n`` is used for the
    ``x``-th color. The result is a real list (not a one-shot ``map``
    iterator), so it can be indexed, measured and iterated repeatedly.
    """
    hsv_tuples = [(x * 1.0 / n, 0.5, 0.5) for x in range(n)]
    return [colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples]


def set_class(root, prop, agent_class):
    """Set ``class=agent_class`` on every element tagged ``prop`` under ``root``.

    ``root`` itself is included in the search. A ``None`` root is a no-op.
    """
    if root is None:
        return
    # Element.iter(tag) walks root and all descendants in document order.
    for elem in root.iter(prop):
        elem.set('class', agent_class)


def set_geom_class(root, name):
    """Assign class ``name`` to every ``geom`` element under ``root``."""
    set_class(root, 'geom', name)


def set_motor_class(root, name):
    """Assign class ``name`` to every ``motor`` element under ``root``."""
    set_class(root, 'motor', name)


def add_prefix(root, prop, prefix, force_set=False):
    """Prepend ``prefix + '/'`` to attribute ``prop`` across the tree at ``root``.

    Elements that lack the attribute are left alone unless ``force_set`` is
    true, in which case they receive ``prefix/anon<random integer>``. The
    random suffix is not checked for collisions. A ``None`` root is a no-op.
    """
    if root is None:
        return
    # iter() yields root first and then descendants in document order.
    for elem in root.iter():
        current = elem.get(prop)
        if current is not None:
            elem.set(prop, prefix + '/' + current)
        elif force_set:
            # Integer upper bound: avoids handing a float (1e10) to randint.
            anon_id = np.random.randint(1, 10 ** 10)
            elem.set(prop, prefix + '/' + 'anon' + str(anon_id))


def tuple_to_str(tp):
    """Join the items of ``tp`` into a single space-separated string."""
    return " ".join(map(str, tp))


def create_multiagent_xml(
    world_xml, all_agent_xmls, agent_scopes=None,
    outdir='/tmp/', outpath=None, ini_pos=None, rgb=None
):
    """Merge several agent XML files into a world XML file and write the result.

    For each agent file this:
      * adds a nested ``default`` (``class=<scope>``) under the world's
        ``default`` element, carrying the agent's own default children with
        the agent color applied to ``geom`` defaults;
      * moves the agent's top-level ``body`` into the world's ``worldbody``,
        tagging its geoms with the scope class and prefixing all names;
      * merges the agent's ``actuator`` and ``tendon`` children into the world.

    Args:
        world_xml: path of the world XML file; its root must contain
            ``default`` and ``worldbody`` children.
        all_agent_xmls: paths of the agent XML files; each root must contain
            a ``body`` child.
        agent_scopes: per-agent prefix/class names; defaults to
            ``agent0``, ``agent1``, ...
        outdir: directory used to build the output path when ``outpath`` is
            not given.
        outpath: explicit output file path.
        ini_pos: per-agent position tuples; by default the first coordinate
            is spread between ``n_agents`` and ``-n_agents`` with the other
            two fixed at ``0`` and ``0.75``.
        rgb: per-agent RGB tuples; defaults to ``get_distinct_colors``.

    Returns:
        ``(xml_bytes, outpath)``: the serialized merged tree and the path the
        file was written to.
    """
    world = ET.parse(world_xml)
    world_root = world.getroot()
    world_default = world_root.find('default')
    world_body = world_root.find('worldbody')
    world_actuator = None
    world_tendons = None

    n_agents = len(all_agent_xmls)
    if rgb is None:
        rgb = get_distinct_colors(n_agents)
    rgb_strings = [tuple_to_str(color) for color in rgb]
    if agent_scopes is None:
        agent_scopes = ['agent' + str(i) for i in range(n_agents)]
    if ini_pos is None:
        ini_pos = [
            (-i, 0, 0.75)
            for i in np.linspace(-n_agents, n_agents, n_agents)
        ]

    for i in range(n_agents):
        scope = agent_scopes[i]
        agent_default = ET.SubElement(
            world_default, 'default', attrib={'class': scope}
        )
        rgba = rgb_strings[i] + " 1"
        agent_xml = ET.parse(all_agent_xmls[i])

        # Copy the agent's defaults under its own class, coloring geoms.
        default = agent_xml.find('default')
        default_children = list(default) if default is not None else []
        color_set = False
        for child in default_children:
            if child.tag == 'geom':
                child.set('rgba', rgba)
                color_set = True
            agent_default.append(child)
        if not color_set:
            # No geom default to recolor: add one carrying the agent color.
            ET.SubElement(
                agent_default, 'geom',
                attrib={'contype': '1', 'conaffinity': '1', 'rgba': rgba}
            )

        agent_body = agent_xml.find('body')
        # The position is only overridden when the body already declares one.
        if agent_body.get('pos'):
            agent_body.set('pos', tuple_to_str(ini_pos[i]))
        # add class to all geoms
        set_geom_class(agent_body, scope)
        # add prefix to all names, important to map joints
        add_prefix(agent_body, 'name', scope, force_set=True)
        # add agent body to xml
        world_body.append(agent_body)

        # get agent actuators (skipped when the agent declares none)
        agent_actuator = agent_xml.find('actuator')
        if agent_actuator is not None:
            # add same prefix to all motor joints
            add_prefix(agent_actuator, 'joint', scope)
            add_prefix(agent_actuator, 'name', scope)
            set_motor_class(agent_actuator, scope)
            if world_actuator is None:
                # The first agent's actuator element is adopted as a whole.
                world_root.append(agent_actuator)
                world_actuator = world_root.find('actuator')
            else:
                for motor in list(agent_actuator):
                    world_actuator.append(motor)

        # get agent tendons if they exist; an empty element is skipped.
        # Explicit None/len checks replace deprecated Element truth-testing.
        agent_tendon = agent_xml.find('tendon')
        if agent_tendon is not None and len(agent_tendon) > 0:
            # add same prefix to all tendon joints
            add_prefix(agent_tendon, 'joint', scope)
            add_prefix(agent_tendon, 'name', scope)
            if world_tendons is None:
                world_root.append(agent_tendon)
                world_tendons = world_root.find('tendon')
            else:
                for tendon in list(agent_tendon):
                    world_tendons.append(tendon)

    if outpath is None:
        # <world stem>.<agent stem>...<agent stem>.xml inside outdir
        agent_stems = [
            path.split("/")[-1].split(".xml")[0] for path in all_agent_xmls
        ]
        world_stem = world_xml.split("/")[-1].split(".xml")[0]
        outname = world_stem + '.' + ".".join(agent_stems) + ".xml"
        outpath = outdir + '/' + outname
    world.write(outpath)
    return ET.tostring(world_root), outpath