diff --git a/src/rosbags/typesys/__init__.py b/src/rosbags/typesys/__init__.py index 1dc19c22..413df8b0 100644 --- a/src/rosbags/typesys/__init__.py +++ b/src/rosbags/typesys/__init__.py @@ -18,11 +18,12 @@ Supported formats: from .base import TypesysError from .idl import get_types_from_idl -from .msg import get_types_from_msg +from .msg import generate_msgdef, get_types_from_msg from .register import register_types __all__ = [ 'TypesysError', + 'generate_msgdef', 'get_types_from_idl', 'get_types_from_msg', 'register_types', diff --git a/src/rosbags/typesys/msg.py b/src/rosbags/typesys/msg.py index 75a24c4d..514d77e7 100644 --- a/src/rosbags/typesys/msg.py +++ b/src/rosbags/typesys/msg.py @@ -12,11 +12,13 @@ Rosbag1 connection information. from __future__ import annotations +from hashlib import md5 from pathlib import PurePosixPath as Path from typing import TYPE_CHECKING -from .base import Nodetype, parse_message_definition +from .base import Nodetype, TypesysError, parse_message_definition from .peg import Rule, Visitor, parse_grammar +from .types import FIELDDEFS if TYPE_CHECKING: from typing import Any, List @@ -129,6 +131,20 @@ def normalize_fieldtype(typename: str, field: Fielddesc, names: List[str]) -> Fi return (ftype, (ifield, args[1])) +def denormalize_msgtype(typename: str) -> str: + """Undo message tyoename normalization. + + Args: + typename: Normalized message typename. + + Returns: + ROS1 style name. + + """ + assert '/msg/' in typename + return str((path := Path(typename)).parent.parent / path.name) + + class VisitorMSG(Visitor): """MSG file visitor.""" @@ -223,3 +239,99 @@ def get_types_from_msg(text: str, name: str) -> Typesdict: """ return parse_message_definition(VisitorMSG(), f'MSG: {name}\n{text}') + + +def gendefhash(typename: str, subdefs: dict[str, tuple[str, str]]) -> tuple[str, str]: + """Generate message definition and hash for type. + + The subdefs argument will be filled with child definitions. + + Args: + typename: Name of type to generate definition for. + subdefs: Child definitions. + + Returns: + Message definition and hash. + + Raises: + TypesysError: Type does not exist. + + """ + # pylint: disable=too-many-branches + typemap = { + 'builtin_interfaces/msg/Time': 'time', + 'builtin_interfaces/msg/Duration': 'duration', + } + + deftext: list[str] = [] + hashtext: list[str] = [] + if typename not in FIELDDEFS: + raise TypesysError(f'Type {typename!r} is unknown.') + + for name, typ, value in FIELDDEFS[typename][0]: + deftext.append(f'{typ} {name}={value}') + hashtext.append(f'{typ} {name}={value}') + + for name, (ftype, args) in FIELDDEFS[typename][1]: + if ftype == Nodetype.BASE: + deftext.append(f'{args} {name}') + hashtext.append(f'{args} {name}') + elif ftype == Nodetype.NAME: + assert isinstance(args, str) + subname = args + if subname in typemap: + deftext.append(f'{typemap[subname]} {name}') + hashtext.append(f'{typemap[subname]} {name}') + else: + if subname not in subdefs: + subdefs[subname] = ('', '') + subdefs[subname] = gendefhash(subname, subdefs) + deftext.append(f'{denormalize_msgtype(subname)} {name}') + hashtext.append(f'{subdefs[subname][1]} {name}') + else: + assert isinstance(args, tuple) + subdesc, num = args + count = '' if num is None else str(num) + subtype, subname = subdesc + if subtype == Nodetype.BASE: + deftext.append(f'{subname}[{count}] {name}') + hashtext.append(f'{subname}[{count}] {name}') + elif subname in typemap: + deftext.append(f'{typemap[subname]}[{count}] {name}') + hashtext.append(f'{typemap[subname]}[{count}] {name}') + else: + if subname not in subdefs: + subdefs[subname] = ('', '') + subdefs[subname] = gendefhash(subname, subdefs) + deftext.append(f'{denormalize_msgtype(subname)}[{count}] {name}') + hashtext.append(f'{subdefs[subname][1]} {name}') + + if typename == 'std_msgs/msg/Header': + deftext.insert(0, 'uint32 seq') + hashtext.insert(0, 'uint32 seq') + + deftext.append('') + return '\n'.join(deftext), md5('\n'.join(hashtext).encode()).hexdigest() + + +def generate_msgdef(typename: str) -> tuple[str, str]: + """Generate message definition for type. + + Args: + typename: Name of type to generate definition for. + + Returns: + Message definition. + + """ + subdefs: dict[str, tuple[str, str]] = {} + msgdef, md5sum = gendefhash(typename, subdefs) + + msgdef = ''.join( + [ + msgdef, + *[f'{"=" * 80}\nMSG: {denormalize_msgtype(k)}\n{v[0]}' for k, v in subdefs.items()], + ], + ) + + return msgdef, md5sum diff --git a/tests/test_parse.py b/tests/test_parse.py index cc78bbfc..e0fb995c 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -4,7 +4,13 @@ import pytest -from rosbags.typesys import TypesysError, get_types_from_idl, get_types_from_msg, register_types +from rosbags.typesys import ( + TypesysError, + generate_msgdef, + get_types_from_idl, + get_types_from_msg, + register_types, +) from rosbags.typesys.base import Nodetype from rosbags.typesys.types import FIELDDEFS @@ -181,3 +187,42 @@ def test_register_types(): with pytest.raises(TypesysError, match='different definition'): register_types({'foo': [[], [('x', (1, 'bool'))]]}) + + +def test_generate_msgdef(): + """Test message definition generator.""" + res = generate_msgdef('std_msgs/msg/Header') + assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed') + + res = generate_msgdef('geometry_msgs/msg/PointStamped') + assert res[0].split(f'{"=" * 80}\n') == [ + 'std_msgs/Header header\ngeometry_msgs/Point point\n', + 'MSG: std_msgs/Header\nuint32 seq\ntime stamp\nstring frame_id\n', + 'MSG: geometry_msgs/Point\nfloat64 x\nfloat64 y\nfloat64 z\n', + ] + + res = generate_msgdef('geometry_msgs/msg/Twist') + assert res[0].split(f'{"=" * 80}\n') == [ + 'geometry_msgs/Vector3 linear\ngeometry_msgs/Vector3 angular\n', + 'MSG: geometry_msgs/Vector3\nfloat64 x\nfloat64 y\nfloat64 z\n', + ] + + res = generate_msgdef('shape_msgs/msg/Mesh') + assert res[0].split(f'{"=" * 80}\n') == [ + 'shape_msgs/MeshTriangle[] triangles\ngeometry_msgs/Point[] vertices\n', + 'MSG: shape_msgs/MeshTriangle\nuint32[3] vertex_indices\n', + 'MSG: geometry_msgs/Point\nfloat64 x\nfloat64 y\nfloat64 z\n', + ] + + res = generate_msgdef('shape_msgs/msg/Plane') + assert res[0] == 'float64[4] coef\n' + + res = generate_msgdef('sensor_msgs/msg/MultiEchoLaserScan') + assert len(res[0].split('=' * 80)) == 3 + + register_types(get_types_from_msg('time[3] times\nuint8 foo=42', 'foo_msgs/Timelist')) + res = generate_msgdef('foo_msgs/msg/Timelist') + assert res[0] == 'uint8 foo=42\ntime[3] times\n' + + with pytest.raises(TypesysError, match='is unknown'): + generate_msgdef('foo_msgs/msg/Badname')