tools/rossim/rosigen.py (230 lines of code) (raw):
#!/usr/bin/env python3
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import array
import importlib
import json
import re
import traceback
from threading import Thread
import numpy
import rclpy
from rclpy.node import Node
class Rosigen:
    """Periodically publish one ROS2 message per configured topic.

    The config JSON file must contain a ``"topics"`` list whose entries provide
    ``"module"`` (Python module to import), ``"type"`` (message class name inside
    that module), ``"name"`` (topic name) and ``"period_sec"`` (publish period).
    The current value of every message can be read, modified, saved to and loaded
    from a JSON values file while publishing runs in a background thread.
    """

    # Exact types (not subclasses) that are treated as leaf values of a message.
    PRIMITIVE_TYPES = (int, str, bool, float, bytes)
    PUBLISH_HISTORY_DEPTH = 10  # Message queue size with 'keep last' QoS mode

    # Path element used to address a list item, e.g. "[3]"
    _INDEX_RE = re.compile(r"^\[(\d+)\]$")
    # array.array type codes, used to convert a value before appending it
    _ARRAY_FLOAT_TYPECODES = "fd"
    _ARRAY_INT_TYPECODES = "bBhHiIlLqQ"

    def __init__(self, config_filename, values_filename=None):
        """Initialize ROS2, create publishers/timers from the config and start spinning.

        Args:
            config_filename: Path of the config JSON file.
            values_filename: Optional path of a values JSON file to load initially.

        Raises:
            Exception: Whatever loading the config/values or creating the ROS2
                entities raises. In that case the node is destroyed and ROS2 is
                shut down again before the exception propagates, so that a later
                attempt can call rclpy.init() again.
        """
        rclpy.init()
        self._node = None
        self._publishers = {}  # Dict of publishers, indexed by ROS2 topic name
        self._timers = {}  # Dict of timers used to trigger publishing, index by ROS2 topic name
        # Dict of ROS2 message values, indexed by ROS2 topic name. Each message is of a ROS2 message
        # type, structured according to the corresponding interface definition.
        self._vals = {}
        try:
            self._node = Node("Rosigen")
            config = self._load_json(config_filename)
            for topic in config["topics"]:
                module = importlib.import_module(topic["module"])
                topic_type = getattr(module, topic["type"])
                self._publishers[topic["name"]] = self._node.create_publisher(
                    topic_type, topic["name"], self.PUBLISH_HISTORY_DEPTH
                )
                self._vals[topic["name"]] = topic_type()
                self._timers[topic["name"]] = self._node.create_timer(
                    topic["period_sec"],
                    # Bind the name as a default to avoid the late-binding closure pitfall
                    lambda name=topic["name"]: self.publish_single_message(name),
                )
            if values_filename is not None:
                self.load_values(values_filename)
        except BaseException:
            # Do not leave ROS2 initialized when construction fails half-way
            if self._node is not None:
                self._node.destroy_node()
            rclpy.utilities.try_shutdown()
            raise
        self._thread = Thread(target=self._publish_thread)
        self._thread.start()

    def publish_single_message(self, name):
        """Publish the current value of the message belonging to topic `name`."""
        self._publishers[name].publish(self._vals[name])

    def _publish_thread(self):
        """Thread body: spin the node until ROS2 is shut down."""
        try:
            rclpy.spin(self._node)
        except rclpy.executors.ExternalShutdownException:
            pass

    def _is_list(self, item):
        """Return True if `item` is one of the supported sequence container types."""
        return type(item) in [numpy.ndarray, list, array.array]

    def _is_primitive_list(self, member):
        """Return True if `member` is a sequence known to hold primitive values.

        numpy arrays and array.array always qualify. A plain list only qualifies
        when it is non-empty and its first element is a primitive, since the
        element type of an empty list cannot be determined from the list itself.
        """
        return type(member) in [numpy.ndarray, array.array] or (
            type(member) is list
            and len(member) > 0
            and type(member[0]) in self.PRIMITIVE_TYPES
        )

    def _to_json_value(self, value):
        """Convert a primitive message value into something json can serialize.

        numpy scalars are converted to Python scalars. bytes are not JSON
        serializable, so a single byte is stored as int and longer byte strings as
        a list of ints; _check_primitive_type() accepts both forms when loading.
        """
        if type(value).__module__ == numpy.__name__:
            value = value.item()
        if type(value) is bytes:
            return value[0] if len(value) == 1 else list(value)
        return value

    def _save_values(self, msg):
        """Recursively convert message `msg` into a JSON serializable dict."""
        obj = {}
        for field in msg.get_fields_and_field_types():
            field_val = getattr(msg, field)
            if type(field_val) in self.PRIMITIVE_TYPES:
                obj[field] = self._to_json_value(field_val)
            elif self._is_primitive_list(field_val):
                obj[field] = [self._to_json_value(member) for member in field_val]
            elif self._is_list(field_val):
                obj[field] = [self._save_values(member) for member in field_val]
            else:
                obj[field] = self._save_values(field_val)
        return obj

    def save_values(self, filename):
        """Save the current values of all topics to JSON file `filename`."""
        vals = {topic: self._save_values(msg) for topic, msg in self._vals.items()}
        self._save_json(filename, vals)

    def _check_primitive_type(self, field_val, val):
        """Validate that `field_val` is a primitive and prepare `val` for conversion.

        Returns:
            A value that can be passed to type(field_val)(...) to obtain the new
            field value.

        Raises:
            Exception: If `field_val` is not a primitive, i.e. the path does not
                address a leaf of the message.
        """
        if type(field_val).__module__ == numpy.__name__:
            field_val = field_val.item()
        if type(field_val) not in self.PRIMITIVE_TYPES:
            raise Exception("error: invalid path")
        # If field is Boolean, check whether the value is false ('0', 'False' or 'false'), then set
        # the value to '' (blank string), so that type(field_type)(val) will return False:
        if type(field_val) is bool and val in ["0", "False", "false"]:
            val = ""
        if type(field_val) is bytes:
            if isinstance(val, (bytes, bytearray)):
                return val
            if isinstance(val, (list, tuple)):
                # Multi-byte value as written by _to_json_value()
                return [int(item) for item in val]
            # bytes(<int>) would create that many zero bytes, hence the list
            val = [int(val)]
        return val

    def _parse_index(self, token):
        """Return the integer index encoded in a path element like '[3]'."""
        match = self._INDEX_RE.match(token)
        if not match:
            raise Exception("error: invalid index")
        return int(match.group(1))

    def _grow(self, seq, index, val):
        """Append items to dynamic sequence `seq` until `index` is valid.

        The new items are filled with `val`, converted to the element type first
        when that type is known. Without the conversion a string coming from the
        command line cannot be appended to a typed array.array, and a value not
        matching the element type would be appended before being rejected.
        """
        if type(seq) is numpy.ndarray:
            # numpy arrays cannot be grown in place
            raise Exception("error: invalid index")
        if type(seq) is array.array:
            if seq.typecode in self._ARRAY_FLOAT_TYPECODES:
                filler = float(val)
            elif seq.typecode in self._ARRAY_INT_TYPECODES:
                filler = int(val)
            else:
                filler = val
        elif len(seq) > 0:
            # Use the first item as template. Raises before modifying the list if
            # the items are not primitives.
            filler = type(seq[0])(self._check_primitive_type(seq[0], val))
        else:
            # Element type of an empty list is unknown, take the value as is
            filler = val
        for _ in range(index - len(seq) + 1):
            seq.append(filler)

    def _set_value(self, msg, field, val):
        """Set primitive member `field` of `msg` to `val`.

        Args:
            msg: A message, or a sequence when `field` is an index like '[3]'.
            field: Member name, or index in square brackets for sequences.
            val: New value. It is converted to the type of the current value.

        Raises:
            Exception: If the index is malformed or the member is not a primitive.
        """
        if self._is_list(msg):
            index = self._parse_index(field)
            if len(msg) <= index:
                # for dynamic sized arrays first grow vals
                self._grow(msg, index, val)
            val = self._check_primitive_type(msg[index], val)
            msg[index] = type(msg[index])(val)
        else:
            field_val = getattr(msg, field)
            val = self._check_primitive_type(field_val, val)
            setattr(msg, field, type(field_val)(val))

    def _has_primitive_items(self, val, field):
        """Return True if val[field] exists and is a non-empty list of primitives."""
        items = val.get(field) if isinstance(val, dict) else None
        return (
            type(items) is list
            and len(items) > 0
            and type(items[0]) in self.PRIMITIVE_TYPES
        )

    def _load_primitive_list(self, field_val, items):
        """Copy the primitive values `items` into sequence `field_val` in place."""
        if type(field_val) is array.array:
            # array.array can be resized in place (append / slice deletion), so make
            # it follow the input length. Otherwise loading into an already
            # populated array fails for shorter and truncates longer input.
            del field_val[len(items):]
            length = len(items)
        else:
            # Keep the current length, unless empty: then take the length of input
            length = len(field_val) or len(items)
        for i in range(length):
            self._set_value(field_val, f"[{i}]", items[i])

    def _load_values(self, msg, val):
        """Recursively set the members of message `msg` from dict `val`."""
        for field in msg.get_fields_and_field_types():
            field_val = getattr(msg, field)
            if type(field_val) in self.PRIMITIVE_TYPES:
                self._set_value(msg, field, val[field])
            elif self._is_primitive_list(field_val):
                self._load_primitive_list(field_val, val[field])
            elif self._is_list(field_val):
                if len(field_val) == 0 and self._has_primitive_items(val, field):
                    # An empty list does not reveal its element type. When the input
                    # holds primitives, treat it as a dynamic primitive sequence.
                    self._load_primitive_list(field_val, val[field])
                else:
                    # Only existing nested messages are loaded, the list is not grown
                    for i in range(len(field_val)):
                        self._load_values(field_val[i], val[field][i])
            else:
                self._load_values(field_val, val[field])

    def load_values(self, filename):
        """Load the values of all topics from JSON file `filename`."""
        vals = self._load_json(filename)
        for topic in self._vals:
            self._load_values(self._vals[topic], vals[topic])

    def _get_fields(self, msg):
        """Return the nested dict of member names of `msg`, None marking a leaf."""
        obj = {}
        for field in msg.get_fields_and_field_types():
            field_val = getattr(msg, field)
            if type(field_val) in self.PRIMITIVE_TYPES:
                obj[field] = None  # No further auto-completion
            elif self._is_primitive_list(field_val):
                obj[field] = {f"[{i}]": None for i in range(len(field_val))}
            elif self._is_list(field_val):
                obj[field] = {
                    f"[{i}]": self._get_fields(member) for i, member in enumerate(field_val)
                }
            else:
                obj[field] = self._get_fields(field_val)
        return obj

    def get_fields(self):
        """Return the nested member names of all topics, indexed by topic name."""
        return {topic: self._get_fields(self._vals[topic]) for topic in self._vals}

    def get_value(self, path):
        """Return the value addressed by `path`.

        Args:
            path: Sequence starting with the topic name, followed by member names
                and, for sequences, indices in square brackets like '[3]'.

        Raises:
            Exception: If the path cannot be resolved. An error is printed before
                the exception is re-raised.
        """
        try:
            msg = self._vals[path[0]]
            for element in path[1:]:
                if self._is_list(msg):
                    msg = msg[self._parse_index(element)]
                else:
                    msg = getattr(msg, element)
            return msg
        except Exception:
            print("error: invalid path")
            raise

    def set_value(self, path, value):
        """Set the primitive value addressed by `path` (see get_value) to `value`."""
        msg = self.get_value(path[:-1])
        self._set_value(msg, path[-1], value)

    def _load_json(self, filename):
        """Return the parsed content of JSON file `filename`, re-raising on failure."""
        try:
            with open(filename, encoding="utf-8") as fp:
                return json.load(fp)
        except Exception:
            print("error: failed to load " + filename)
            raise

    def _save_json(self, filename, data):
        """Write `data` as JSON to `filename`. Failures are printed, not raised."""
        try:
            # Serialize before opening the file, so that a serialization error does
            # not leave a truncated file behind
            text = json.dumps(data, sort_keys=True, indent=4)
            with open(filename, "w", encoding="utf-8") as fp:
                fp.write(text)
        except Exception:
            print("error: failed to save " + filename)

    def stop(self):
        """Destroy the node, shut down ROS2 and wait for the publish thread."""
        # Destroying the node is not really necessary since it will be destroyed when garbage
        # collected. But explicitly destroying makes it more predictable and also allows us to
        # reinitialize ROS2 with a different config.
        self._node.destroy_node()
        # After ROS2 Humble, rclpy.shutdown() could fail on SIGTERM because ROS2 signal handler
        # already called it. That is why we use try_shutdown() instead, which checks whether the
        # context is already shutdown.
        rclpy.utilities.try_shutdown()
        self._thread.join()

    def topic(self, topic):
        """Return the message object holding the current value of `topic`."""
        return self._vals[topic]
if __name__ == "__main__":
    # Interactive shell: reads commands from a prompt and applies them to a Rosigen
    # instance that keeps publishing in the background.
    import argparse

    from prompt_toolkit import PromptSession
    from prompt_toolkit.completion import NestedCompleter, PathCompleter

    parser = argparse.ArgumentParser(description="Generates ROS2 messages interactively")
    parser.add_argument("-c", "--config", help="Config JSON file", required=True)
    parser.add_argument(
        "-v", "--values", help="Values JSON file. Generate one using the 'save' command."
    )
    args = parser.parse_args()

    r = Rosigen(args.config, args.values)

    def print_help():
        print("Usage:")
        print("  set <TOPIC> <MSG_MEMBERS...> <VALUE>")
        print("  get <TOPIC> <MSG_MEMBERS...>")
        print("  save <VALUE_JSON_FILE>")
        print("  load <VALUE_JSON_FILE>")
        print("  exit")

    # From here on the publish thread is running. It is not a daemon thread, so
    # r.stop() has to be called on every way out, otherwise the process never exits.
    try:
        path_completer = PathCompleter()
        cmd_completion_dict = {
            "set": r.get_fields(),
            "get": r.get_fields(),
            "save": path_completer,
            "load": path_completer,
            "exit": None,
        }
        cmd_completer = NestedCompleter.from_nested_dict(cmd_completion_dict)
        session = PromptSession()
        while True:
            cmd = session.prompt("rosigen$ ", completer=cmd_completer).split()
            try:
                if len(cmd) == 0:
                    pass
                elif cmd[0] == "exit" or cmd[0] == "quit":
                    break
                elif cmd[0] == "set":
                    # Last word is the value, everything in between is the path
                    r.set_value(cmd[1:-1], cmd[-1])
                elif cmd[0] == "get":
                    print(r.get_value(cmd[1:]))
                elif cmd[0] == "save":
                    r.save_values(cmd[1])
                elif cmd[0] == "load":
                    r.load_values(cmd[1])
                elif cmd[0] == "help":
                    print_help()
                else:
                    print("error: invalid command: " + cmd[0])
                    print_help()
            except Exception as e:
                # A failing command must not terminate the shell
                print("TRACEBACK: " + traceback.format_exc())
                print(e)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C, or Ctrl-D / end of input at the prompt: leave the shell
        pass
    finally:
        r.stop()