From abd0c1fa73dad325a8303b47b0a3fcdefdb79e97 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Sun, 2 May 2021 14:46:31 +0200 Subject: [PATCH] Add serde --- docs/api/rosbags.rst | 1 + docs/api/rosbags.serde.rst | 6 + docs/index.rst | 1 + docs/topics/serde.rst | 31 +++ src/rosbags/serde/__init__.py | 19 ++ src/rosbags/serde/cdr.py | 443 ++++++++++++++++++++++++++++++++ src/rosbags/serde/messages.py | 72 ++++++ src/rosbags/serde/primitives.py | 55 ++++ src/rosbags/serde/ros1.py | 180 +++++++++++++ src/rosbags/serde/serdes.py | 102 ++++++++ src/rosbags/serde/typing.py | 35 +++ src/rosbags/serde/utils.py | 103 ++++++++ tests/cdr.py | 441 +++++++++++++++++++++++++++++++ tests/test_serde.py | 382 +++++++++++++++++++++++++++ 14 files changed, 1871 insertions(+) create mode 100644 docs/api/rosbags.serde.rst create mode 100644 docs/topics/serde.rst create mode 100644 src/rosbags/serde/__init__.py create mode 100644 src/rosbags/serde/cdr.py create mode 100644 src/rosbags/serde/messages.py create mode 100644 src/rosbags/serde/primitives.py create mode 100644 src/rosbags/serde/ros1.py create mode 100644 src/rosbags/serde/serdes.py create mode 100644 src/rosbags/serde/typing.py create mode 100644 src/rosbags/serde/utils.py create mode 100644 tests/cdr.py create mode 100644 tests/test_serde.py diff --git a/docs/api/rosbags.rst b/docs/api/rosbags.rst index b106356c..c2d05d62 100644 --- a/docs/api/rosbags.rst +++ b/docs/api/rosbags.rst @@ -4,4 +4,5 @@ Rosbags namespace .. toctree:: :maxdepth: 4 + rosbags.serde rosbags.typesys diff --git a/docs/api/rosbags.serde.rst b/docs/api/rosbags.serde.rst new file mode 100644 index 00000000..0fe5f96e --- /dev/null +++ b/docs/api/rosbags.serde.rst @@ -0,0 +1,6 @@ +rosbags.serde +============= + +.. automodule:: rosbags.serde + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 83f2482d..a73cf508 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,6 +11,7 @@ :hidden: topics/typesys + topics/serde .. toctree:: diff --git a/docs/topics/serde.rst b/docs/topics/serde.rst new file mode 100644 index 00000000..69b95ebd --- /dev/null +++ b/docs/topics/serde.rst @@ -0,0 +1,31 @@ +Serialization and deserialization +================================= + +The serialization and deserialization system :py:mod:`rosbags.serde` supports multiple raw message formats. For each format it provides a pair of functions, one for serialization and one for deserialization. In addition to the data to process each function usually only requires the message type name. + +Deserialization +--------------- + +Deserialize a CDR bytes object using :py:func:`deserialize_cdr() `: + +.. code-block:: python + + from rosbags.serde import deserialize_cdr + + # rawdata is of type bytes and contains serialized message + msg = deserialize_cdr(rawdata, 'geometry_msgs/msg/Quaternion') + +Serialization +--------------- + +Serialize a message with CDR using :py:func:`serialize_cdr() `: + +.. code-block:: python + + from rosbags.serde import serialize_cdr + + # serialize message with system endianess + serialized = serialize_cdr(msg, 'geometry_msgs/msg/Quaternion') + + # serialize message with explicit endianess + serialized = serialize_cdr(msg, 'geometry_msgs/msg/Quaternion', little_endian=False) diff --git a/src/rosbags/serde/__init__.py b/src/rosbags/serde/__init__.py new file mode 100644 index 00000000..55cd0961 --- /dev/null +++ b/src/rosbags/serde/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Rosbags message serialization and deserialization. + +Serializers and deserializers convert between python messages objects and +the common rosbag serialization formats. Computationally cheap functions +convert directly between different serialization formats. + +""" + +from .messages import SerdeError +from .serdes import deserialize_cdr, ros1_to_cdr, serialize_cdr + +__all__ = [ + 'SerdeError', + 'deserialize_cdr', + 'ros1_to_cdr', + 'serialize_cdr', +] diff --git a/src/rosbags/serde/cdr.py b/src/rosbags/serde/cdr.py new file mode 100644 index 00000000..ba9cc6d8 --- /dev/null +++ b/src/rosbags/serde/cdr.py @@ -0,0 +1,443 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Code generators for CDR. + +Common Data Representation `CDR`_ is the serialization format used by most ROS2 +middlewares. + +.. _CDR: https://www.omg.org/cgi-bin/doc?formal/02-06-51 + +""" + +from __future__ import annotations + +import sys +from itertools import tee +from typing import TYPE_CHECKING, Iterator, Optional, Tuple, cast + +from .typing import Field +from .utils import SIZEMAP, Valtype, align, align_after, compile_lines + +if TYPE_CHECKING: + from typing import Callable, List + + +def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]: + """Generate cdr size calculation function. + + Args: + fields: Fields of message. + + Returns: + Size calculation function and static size. + + """ + # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements + size = 0 + is_stat = True + + aligned = 8 + icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None])) + next(inext) + lines = [ + 'import sys', + 'from rosbags.serde.messages import get_msgdef', + 'def getsize_cdr(pos, message):', + ] + for fcurr, fnext in zip(icurr, inext): + fieldname, desc = fcurr + + if desc.valtype == Valtype.MESSAGE: + if desc.args.size_cdr: + lines.append(f' pos += {desc.args.size_cdr}') + size += desc.args.size_cdr + else: + lines.append(f' func = get_msgdef("{desc.args.name}").getsize_cdr') + lines.append(f' pos = func(pos, message.{fieldname})') + is_stat = False + aligned = align_after(desc) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + lines.append(f' pos += 4 + len(message.{fieldname}.encode()) + 1') + aligned = 1 + is_stat = False + else: + lines.append(f' pos += {SIZEMAP[desc.args]}') + aligned = SIZEMAP[desc.args] + size += SIZEMAP[desc.args] + + elif desc.valtype == Valtype.ARRAY: + subdesc = desc.args[1] + + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + lines.append(f' val = message.{fieldname}') + for idx in range(desc.args[0]): + lines.append(' pos = (pos + 4 - 1) & -4') + lines.append(f' pos += 4 + len(val[{idx}].encode()) + 1') + aligned = 1 + is_stat = False + else: + lines.append(f' pos += {desc.args[0] * SIZEMAP[subdesc.args]}') + size += desc.args[0] * SIZEMAP[subdesc.args] + + else: + assert subdesc.valtype == Valtype.MESSAGE + anext = align(subdesc) + anext_after = align_after(subdesc) + + if subdesc.args.size_cdr: + for _ in range(desc.args[0]): + if anext > anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + size = (size + anext - 1) & -anext + lines.append(f' pos += {subdesc.args.size_cdr}') + size += subdesc.args.size_cdr + else: + lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') + lines.append(f' val = message.{fieldname}') + for idx in range(desc.args[0]): + if anext > anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos = func(pos, val[{idx}])') + is_stat = False + aligned = align_after(subdesc) + else: + assert desc.valtype == Valtype.SEQUENCE + lines.append(' pos += 4') + aligned = 4 + subdesc = desc.args + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + lines.append(f' for val in message.{fieldname}:') + lines.append(' pos = (pos + 4 - 1) & -4') + lines.append(' pos += 4 + len(val.encode()) + 1') + aligned = 1 + else: + anext = align(subdesc) + if aligned < anext: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + aligned = anext + lines.append(f' pos += len(message.{fieldname}) * {SIZEMAP[subdesc.args]}') + + else: + assert subdesc.valtype == Valtype.MESSAGE + anext = align(subdesc) + anext_after = align_after(subdesc) + lines.append(f' val = message.{fieldname}') + if subdesc.args.size_cdr: + if aligned < anext <= anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(' for _ in val:') + if anext > anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos += {subdesc.args.size_cdr}') + + else: + lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') + if aligned < anext <= anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(' for item in val:') + if anext > anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(' pos = func(pos, item)') + aligned = align_after(subdesc) + + aligned = min([aligned, 4]) + is_stat = False + + if fnext and aligned < (anext := align(fnext.descriptor)): + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + aligned = anext + is_stat = False + lines.append(' return pos') + return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore + + +def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable: + """Generate cdr serialization function. + + Args: + fields: Fields of message. + endianess: Endianess of rawdata. + + Returns: + Serializer function. + + """ + # pylint: disable=too-many-branches,too-many-locals,too-many-statements + aligned = 8 + icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None])) + next(inext) + lines = [ + 'import sys', + 'import numpy', + 'from rosbags.serde.messages import SerdeError, get_msgdef', + f'from rosbags.serde.primitives import pack_bool_{endianess}', + f'from rosbags.serde.primitives import pack_int8_{endianess}', + f'from rosbags.serde.primitives import pack_int16_{endianess}', + f'from rosbags.serde.primitives import pack_int32_{endianess}', + f'from rosbags.serde.primitives import pack_int64_{endianess}', + f'from rosbags.serde.primitives import pack_uint8_{endianess}', + f'from rosbags.serde.primitives import pack_uint16_{endianess}', + f'from rosbags.serde.primitives import pack_uint32_{endianess}', + f'from rosbags.serde.primitives import pack_uint64_{endianess}', + f'from rosbags.serde.primitives import pack_float32_{endianess}', + f'from rosbags.serde.primitives import pack_float64_{endianess}', + 'def serialize_cdr(rawdata, pos, message):', + ] + for fcurr, fnext in zip(icurr, inext): + fieldname, desc = fcurr + + lines.append(f' val = message.{fieldname}') + if desc.valtype == Valtype.MESSAGE: + lines.append(f' func = get_msgdef("{desc.args.name}").serialize_cdr_{endianess}') + lines.append(' pos = func(rawdata, pos, val)') + aligned = align_after(desc) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + lines.append(' bval = memoryview(val.encode())') + lines.append(' length = len(bval) + 1') + lines.append(f' pack_int32_{endianess}(rawdata, pos, length)') + lines.append(' pos += 4') + lines.append(' rawdata[pos:pos + length - 1] = bval') + lines.append(' pos += length') + aligned = 1 + else: + lines.append(f' pack_{desc.args}_{endianess}(rawdata, pos, val)') + lines.append(f' pos += {SIZEMAP[desc.args]}') + aligned = SIZEMAP[desc.args] + + elif desc.valtype == Valtype.ARRAY: + subdesc = desc.args[1] + lines.append(f' if len(val) != {desc.args[0]}:') + lines.append(' raise SerdeError(\'Unexpected array length\')') + + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + for idx in range(desc.args[0]): + lines.append(f' bval = memoryview(val[{idx}].encode())') + lines.append(' length = len(bval) + 1') + lines.append(' pos = (pos + 4 - 1) & -4') + lines.append(f' pack_int32_{endianess}(rawdata, pos, length)') + lines.append(' pos += 4') + lines.append(' rawdata[pos:pos + length - 1] = bval') + lines.append(' pos += length') + aligned = 1 + else: + if (endianess == 'le') != (sys.byteorder == 'little'): + lines.append(' val = val.byteswap()') + size = desc.args[0] * SIZEMAP[subdesc.args] + lines.append(f' rawdata[pos:pos + {size}] = val.view(numpy.uint8)') + lines.append(f' pos += {size}') + + else: + assert subdesc.valtype == Valtype.MESSAGE + anext = align(subdesc) + anext_after = align_after(subdesc) + lines.append( + f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', + ) + for idx in range(desc.args[0]): + if anext > anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos = func(rawdata, pos, val[{idx}])') + aligned = align_after(subdesc) + else: + assert desc.valtype == Valtype.SEQUENCE + lines.append(f' pack_int32_{endianess}(rawdata, pos, len(val))') + lines.append(' pos += 4') + aligned = 4 + subdesc = desc.args + + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + lines.append(' for item in val:') + lines.append(' bval = memoryview(item.encode())') + lines.append(' length = len(bval) + 1') + lines.append(' pos = (pos + 4 - 1) & -4') + lines.append(f' pack_int32_{endianess}(rawdata, pos, length)') + lines.append(' pos += 4') + lines.append(' rawdata[pos:pos + length - 1] = bval') + lines.append(' pos += length') + aligned = 1 + else: + lines.append(f' size = len(val) * {SIZEMAP[subdesc.args]}') + if (endianess == 'le') != (sys.byteorder == 'little'): + lines.append(' val = val.byteswap()') + if aligned < (anext := align(subdesc)): + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(' rawdata[pos:pos + size] = val.view(numpy.uint8)') + lines.append(' pos += size') + aligned = anext + + if subdesc.valtype == Valtype.MESSAGE: + anext = align(subdesc) + lines.append( + f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', + ) + lines.append(' for item in val:') + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(' pos = func(rawdata, pos, item)') + aligned = align_after(subdesc) + + aligned = min([4, aligned]) + + if fnext and aligned < (anext := align(fnext.descriptor)): + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + aligned = anext + lines.append(' return pos') + return compile_lines(lines).serialize_cdr # type: ignore + + +def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable: + """Generate cdr deserialization function. + + Args: + fields: Fields of message. + endianess: Endianess of rawdata. + + Returns: + Deserializer function. + + """ + # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements + aligned = 8 + icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None])) + next(inext) + lines = [ + 'import sys', + 'import numpy', + 'from rosbags.serde.messages import SerdeError, get_msgdef', + f'from rosbags.serde.primitives import unpack_bool_{endianess}', + f'from rosbags.serde.primitives import unpack_int8_{endianess}', + f'from rosbags.serde.primitives import unpack_int16_{endianess}', + f'from rosbags.serde.primitives import unpack_int32_{endianess}', + f'from rosbags.serde.primitives import unpack_int64_{endianess}', + f'from rosbags.serde.primitives import unpack_uint8_{endianess}', + f'from rosbags.serde.primitives import unpack_uint16_{endianess}', + f'from rosbags.serde.primitives import unpack_uint32_{endianess}', + f'from rosbags.serde.primitives import unpack_uint64_{endianess}', + f'from rosbags.serde.primitives import unpack_float32_{endianess}', + f'from rosbags.serde.primitives import unpack_float64_{endianess}', + 'def deserialize_cdr(rawdata, pos, cls):', + ] + + funcname = f'deserialize_cdr_{endianess}' + lines.append(' values = []') + for fcurr, fnext in zip(icurr, inext): + desc = fcurr[1] + + if desc.valtype == Valtype.MESSAGE: + lines.append(f' msgdef = get_msgdef("{desc.args.name}")') + lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') + lines.append(' values.append(obj)') + aligned = align_after(desc) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]') + lines.append(' string = bytes(rawdata[pos + 4:pos + 4 + length - 1]).decode()') + lines.append(' values.append(string)') + lines.append(' pos += 4 + length') + aligned = 1 + else: + lines.append(f' value = unpack_{desc.args}_{endianess}(rawdata, pos)[0]') + lines.append(' values.append(value)') + lines.append(f' pos += {SIZEMAP[desc.args]}') + aligned = SIZEMAP[desc.args] + + elif desc.valtype == Valtype.ARRAY: + subdesc = desc.args[1] + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + lines.append(' value = []') + for idx in range(desc.args[0]): + if idx: + lines.append(' pos = (pos + 4 - 1) & -4') + lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]') + lines.append( + ' value.append(bytes(rawdata[pos + 4:pos + 4 + length - 1]).decode())', + ) + lines.append(' pos += 4 + length') + lines.append(' values.append(value)') + aligned = 1 + else: + size = desc.args[0] * SIZEMAP[subdesc.args] + lines.append( + f' val = numpy.frombuffer(rawdata, ' + f'dtype=numpy.{subdesc.args}, count={desc.args[0]}, offset=pos)', + ) + if (endianess == 'le') != (sys.byteorder == 'little'): + lines.append(' val = val.byteswap()') + lines.append(' values.append(val)') + lines.append(f' pos += {size}') + else: + assert subdesc.valtype == Valtype.MESSAGE + anext = align(subdesc) + anext_after = align_after(subdesc) + lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') + lines.append(' value = []') + for _ in range(desc.args[0]): + if anext > anext_after: + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') + lines.append(' value.append(obj)') + lines.append(' values.append(value)') + aligned = align_after(subdesc) + + else: + assert desc.valtype == Valtype.SEQUENCE + lines.append(f' size = unpack_int32_{endianess}(rawdata, pos)[0]') + lines.append(' pos += 4') + aligned = 4 + subdesc = desc.args + + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + lines.append(' value = []') + lines.append(' for _ in range(size):') + lines.append(' pos = (pos + 4 - 1) & -4') + lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]') + lines.append( + ' value.append(bytes(rawdata[pos + 4:pos + 4 + length - 1])' + '.decode())', + ) + lines.append(' pos += 4 + length') + lines.append(' values.append(value)') + aligned = 1 + else: + lines.append(f' length = size * {SIZEMAP[subdesc.args]}') + if aligned < (anext := align(subdesc)): + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append( + f' val = numpy.frombuffer(rawdata, ' + f'dtype=numpy.{subdesc.args}, count=size, offset=pos)', + ) + if (endianess == 'le') != (sys.byteorder == 'little'): + lines.append(' val = val.byteswap()') + lines.append(' values.append(val)') + lines.append(' pos += length') + aligned = anext + + if subdesc.valtype == Valtype.MESSAGE: + anext = align(subdesc) + lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') + lines.append(' value = []') + lines.append(' for _ in range(size):') + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') + lines.append(' value.append(obj)') + lines.append(' values.append(value)') + aligned = align_after(subdesc) + + aligned = min([4, aligned]) + + if fnext and aligned < (anext := align(fnext.descriptor)): + lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + aligned = anext + + lines.append(' return cls(*values), pos') + return compile_lines(lines).deserialize_cdr # type: ignore diff --git a/src/rosbags/serde/messages.py b/src/rosbags/serde/messages.py new file mode 100644 index 00000000..3c2193d7 --- /dev/null +++ b/src/rosbags/serde/messages.py @@ -0,0 +1,72 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Runtime message loader and cache.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from rosbags.typesys import types + +from .cdr import generate_deserialize_cdr, generate_getsize_cdr, generate_serialize_cdr +from .ros1 import generate_ros1_to_cdr +from .typing import Field, Msgdef +from .utils import Descriptor, Valtype + +if TYPE_CHECKING: + from typing import Any, Dict + + +MSGDEFCACHE: Dict[str, Msgdef] = {} + + +class SerdeError(Exception): + """Serialization and Deserialization Error.""" + + +def get_msgdef(typename: str) -> Msgdef: + """Retrieve message definition for typename. + + Message definitions are cached globally and generated as needed. + + Args: + typename: Msgdef type name to load. + + Returns: + Message definition. + + """ + if typename not in MSGDEFCACHE: + entries = types.FIELDDEFS[typename] + + def fixup(entry: Any) -> Descriptor: + if entry[0] == Valtype.BASE: + return Descriptor(Valtype.BASE, entry[1]) + if entry[0] == Valtype.MESSAGE: + return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1])) + if entry[0] == Valtype.ARRAY: + return Descriptor(Valtype.ARRAY, (entry[1], fixup(entry[2]))) + if entry[0] == Valtype.SEQUENCE: + return Descriptor(Valtype.SEQUENCE, fixup(entry[1])) + raise SerdeError( # pragma: no cover + f'Unknown field type {entry[0]!r} encountered.', + ) + + fields = [Field(name, fixup(desc)) for name, desc in entries] + + getsize_cdr, size_cdr = generate_getsize_cdr(fields) + + MSGDEFCACHE[typename] = Msgdef( + typename, + fields, + getattr(types, typename.replace('/', '__')), + size_cdr, + getsize_cdr, + generate_serialize_cdr(fields, 'le'), + generate_serialize_cdr(fields, 'be'), + generate_deserialize_cdr(fields, 'le'), + generate_deserialize_cdr(fields, 'be'), + generate_ros1_to_cdr(fields, typename, False), + generate_ros1_to_cdr(fields, typename, True), + ) + return MSGDEFCACHE[typename] diff --git a/src/rosbags/serde/primitives.py b/src/rosbags/serde/primitives.py new file mode 100644 index 00000000..21e04338 --- /dev/null +++ b/src/rosbags/serde/primitives.py @@ -0,0 +1,55 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Serialization primitives. + +These functions are used by generated code to serialize and desesialize +primitive values. + +""" + +from struct import Struct + +pack_bool_le = Struct('?').pack_into +pack_int8_le = Struct('b').pack_into +pack_int16_le = Struct('h').pack_into +pack_int32_be = Struct('>i').pack_into +pack_int64_be = Struct('>q').pack_into +pack_uint8_be = Struct('B').pack_into +pack_uint16_be = Struct('>H').pack_into +pack_uint32_be = Struct('>I').pack_into +pack_uint64_be = Struct('>Q').pack_into +pack_float32_be = Struct('>f').pack_into +pack_float64_be = Struct('>d').pack_into +unpack_bool_be = Struct('?').unpack_from +unpack_int8_be = Struct('b').unpack_from +unpack_int16_be = Struct('>h').unpack_from +unpack_int32_be = Struct('>i').unpack_from +unpack_int64_be = Struct('>q').unpack_from +unpack_uint8_be = Struct('B').unpack_from +unpack_uint16_be = Struct('>H').unpack_from +unpack_uint32_be = Struct('>I').unpack_from +unpack_uint64_be = Struct('>Q').unpack_from +unpack_float32_be = Struct('>f').unpack_from +unpack_float64_be = Struct('>d').unpack_from diff --git a/src/rosbags/serde/ros1.py b/src/rosbags/serde/ros1.py new file mode 100644 index 00000000..fb1c212b --- /dev/null +++ b/src/rosbags/serde/ros1.py @@ -0,0 +1,180 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Code generators for ROS1. + +`ROS1`_ uses a serialization format. This module supports fast byte-level +conversion of ROS1 to CDR. + +.. _ROS1: http://wiki.ros.org/ROS/Technical%20Overview + +""" + +from __future__ import annotations + +from itertools import tee +from typing import TYPE_CHECKING, Iterator, Optional, Tuple, cast + +from .typing import Field +from .utils import SIZEMAP, Valtype, align, align_after, compile_lines + +if TYPE_CHECKING: + from typing import Callable, List # pylint: disable=ungrouped-imports + + +def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Callable: + """Generate CDR serialization function. + + Args: + fields: Fields of message. + typename: Message type name. + copy: Generate serialization or sizing function. + + Returns: + ROS1 to CDR conversion function. + + """ + # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements + aligned = 8 + icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None])) + next(inext) + funcname = 'ros1_to_cdr' if copy else 'getsize_ros1_to_cdr' + lines = [ + 'import sys', + 'import numpy', + 'from rosbags.serde.messages import SerdeError, get_msgdef', + 'from rosbags.serde.primitives import pack_bool_le', + 'from rosbags.serde.primitives import pack_int8_le', + 'from rosbags.serde.primitives import pack_int16_le', + 'from rosbags.serde.primitives import pack_int32_le', + 'from rosbags.serde.primitives import pack_int64_le', + 'from rosbags.serde.primitives import pack_uint8_le', + 'from rosbags.serde.primitives import pack_uint16_le', + 'from rosbags.serde.primitives import pack_uint32_le', + 'from rosbags.serde.primitives import pack_uint64_le', + 'from rosbags.serde.primitives import pack_float32_le', + 'from rosbags.serde.primitives import pack_float64_le', + 'from rosbags.serde.primitives import unpack_int32_le', + f'def {funcname}(input, ipos, output, opos):', + ] + + if typename == 'std_msgs/msg/Header': + lines.append(' ipos += 4') + + for fcurr, fnext in zip(icurr, inext): + _, desc = fcurr + + if desc.valtype == Valtype.MESSAGE: + lines.append(f' func = get_msgdef("{desc.args.name}").{funcname}') + lines.append(' ipos, opos = func(input, ipos, output, opos)') + aligned = align_after(desc) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + lines.append(' length = unpack_int32_le(input, ipos)[0] + 1') + if copy: + lines.append(' pack_int32_le(output, opos, length)') + lines.append(' ipos += 4') + lines.append(' opos += 4') + if copy: + lines.append(' output[opos:opos + length - 1] = input[ipos:ipos + length - 1]') + lines.append(' ipos += length - 1') + lines.append(' opos += length') + aligned = 1 + else: + size = SIZEMAP[desc.args] + if copy: + lines.append(f' output[opos:opos + {size}] = input[ipos:ipos + {size}]') + lines.append(f' ipos += {size}') + lines.append(f' opos += {size}') + aligned = size + + elif desc.valtype == Valtype.ARRAY: + subdesc = desc.args[1] + + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + for _ in range(desc.args[0]): + lines.append(' opos = (opos + 4 - 1) & -4') + lines.append(' length = unpack_int32_le(input, ipos)[0] + 1') + if copy: + lines.append(' pack_int32_le(output, opos, length)') + lines.append(' ipos += 4') + lines.append(' opos += 4') + if copy: + lines.append( + ' output[opos:opos + length - 1] = input[ipos:ipos + length - 1]', + ) + lines.append(' ipos += length - 1') + lines.append(' opos += length') + aligned = 1 + else: + size = desc.args[0] * SIZEMAP[subdesc.args] + if copy: + lines.append(f' output[opos:opos + {size}] = input[ipos:ipos + {size}]') + lines.append(f' ipos += {size}') + lines.append(f' opos += {size}') + aligned = SIZEMAP[subdesc.args] + + if subdesc.valtype == Valtype.MESSAGE: + anext = align(subdesc) + anext_after = align_after(subdesc) + + lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') + for _ in range(desc.args[0]): + if anext > anext_after: + lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + lines.append(' ipos, opos = func(input, ipos, output, opos)') + aligned = anext_after + else: + assert desc.valtype == Valtype.SEQUENCE + lines.append(' size = unpack_int32_le(input, ipos)[0]') + if copy: + lines.append(' pack_int32_le(output, opos, size)') + lines.append(' ipos += 4') + lines.append(' opos += 4') + subdesc = desc.args + aligned = 4 + + if subdesc.valtype == Valtype.BASE: + if subdesc.args == 'string': + lines.append(' for _ in range(size):') + lines.append(' length = unpack_int32_le(input, ipos)[0] + 1') + lines.append(' opos = (opos + 4 - 1) & -4') + if copy: + lines.append(' pack_int32_le(output, opos, length)') + lines.append(' ipos += 4') + lines.append(' opos += 4') + if copy: + lines.append( + ' output[opos:opos + length - 1] = input[ipos:ipos + length - 1]', + ) + lines.append(' ipos += length - 1') + lines.append(' opos += length') + aligned = 1 + else: + if aligned < (anext := align(subdesc)): + lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + lines.append(f' length = size * {SIZEMAP[subdesc.args]}') + if copy: + lines.append(' output[opos:opos + length] = input[ipos:ipos + length]') + lines.append(' ipos += length') + lines.append(' opos += length') + aligned = anext + + else: + assert subdesc.valtype == Valtype.MESSAGE + anext = align(subdesc) + lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') + lines.append(' for _ in range(size):') + lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + lines.append(' ipos, opos = func(input, ipos, output, opos)') + aligned = align_after(subdesc) + + aligned = min([aligned, 4]) + + if fnext and aligned < (anext := align(fnext.descriptor)): + lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + aligned = anext + + lines.append(' return ipos, opos') + return getattr(compile_lines(lines), funcname) diff --git a/src/rosbags/serde/serdes.py b/src/rosbags/serde/serdes.py new file mode 100644 index 00000000..260edc08 --- /dev/null +++ b/src/rosbags/serde/serdes.py @@ -0,0 +1,102 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Serialization, deserializion and conversion functions.""" + +from __future__ import annotations + +import sys +from struct import pack_into +from typing import TYPE_CHECKING + +from .messages import get_msgdef + +if TYPE_CHECKING: + from typing import Any + + +def deserialize_cdr(rawdata: bytes, typename: str) -> Any: + """Deserialize raw data into a message object. + + Args: + rawdata: Serialized data. + typename: Message type name. + + Returns: + Deserialized message object. + + """ + little_endian = bool(rawdata[1]) + + msgdef = get_msgdef(typename) + func = msgdef.deserialize_cdr_le if little_endian else msgdef.deserialize_cdr_be + message, pos = func(rawdata[4:], 0, msgdef.cls) + assert pos + 4 + 3 >= len(rawdata) + return message + + +def serialize_cdr( + message: Any, + typename: str, + little_endian: bool = sys.byteorder == 'little', +) -> memoryview: + """Serialize message object to bytes. + + Args: + message: Message object. + typename: Message type name. + little_endian: Should use little endianess. + + Returns: + Serialized bytes. + + """ + msgdef = get_msgdef(typename) + size = 4 + msgdef.getsize_cdr(0, message) + rawdata = memoryview(bytearray(size)) + pack_into('BB', rawdata, 0, 0, little_endian) + + func = msgdef.serialize_cdr_le if little_endian else msgdef.serialize_cdr_be + + pos = func(rawdata[4:], 0, message) + assert pos + 4 == size + return rawdata.toreadonly() + + +def ros1_to_cdr(raw: bytes, typename: str) -> memoryview: + """Convert serialized ROS1 message directly to CDR. + + This should be reasonably fast as conversions happen on a byte-level + without going through deserialization and serialization. + + Args: + raw: ROS1 serialized message. + typename: Message type name. + + Returns: + CDR serialized message. + + """ + msgdef = get_msgdef(typename) + + ipos, opos = msgdef.getsize_ros1_to_cdr( + raw, + 0, + None, + 0, + ) + assert ipos == len(raw) + + raw = memoryview(raw) + size = 4 + opos + rawdata = memoryview(bytearray(size)) + pack_into('BB', rawdata, 0, 0, True) + + ipos, opos = msgdef.ros1_to_cdr( + raw, + 0, + rawdata[4:], + 0, + ) + assert ipos == len(raw) + assert opos + 4 == size + return rawdata.toreadonly() diff --git a/src/rosbags/serde/typing.py b/src/rosbags/serde/typing.py new file mode 100644 index 00000000..a3b0d70c --- /dev/null +++ b/src/rosbags/serde/typing.py @@ -0,0 +1,35 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Python types used in this package.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from typing import Any, Callable, List # pylint: disable=ungrouped-imports + + from .utils import Descriptor + + +class Field(NamedTuple): + """Metadata of a field.""" + + name: str + descriptor: Descriptor + + +class Msgdef(NamedTuple): + """Metadata of a message.""" + + name: str + fields: List[Field] + cls: Any + size_cdr: int + getsize_cdr: Callable + serialize_cdr_le: Callable + serialize_cdr_be: Callable + deserialize_cdr_le: Callable + deserialize_cdr_be: Callable + getsize_ros1_to_cdr: Callable + ros1_to_cdr: Callable diff --git a/src/rosbags/serde/utils.py b/src/rosbags/serde/utils.py new file mode 100644 index 00000000..15362b23 --- /dev/null +++ b/src/rosbags/serde/utils.py @@ -0,0 +1,103 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Helpers used by code generators.""" + +from __future__ import annotations + +from enum import IntEnum +from importlib.util import module_from_spec, spec_from_loader +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from types import ModuleType + from typing import Any, Dict, List + + +class Valtype(IntEnum): + """Msg field value types.""" + + BASE = 1 + MESSAGE = 2 + ARRAY = 3 + SEQUENCE = 4 + + +class Descriptor(NamedTuple): + """Value type descriptor.""" + + valtype: Valtype + args: Any # Union[Descriptor, Msgdef, Tuple[int, Descriptor], str] + + +SIZEMAP: Dict[str, int] = { + 'bool': 1, + 'int8': 1, + 'int16': 2, + 'int32': 4, + 'int64': 8, + 'uint8': 1, + 'uint16': 2, + 'uint32': 4, + 'uint64': 8, + 'float32': 4, + 'float64': 8, +} + + +def align(entry: Descriptor) -> int: + """Get alignment requirement for entry. + + Args: + entry: Field. + + Returns: + Required alignment in bytes. + + """ + if entry.valtype == Valtype.BASE: + if entry.args == 'string': + return 4 + return SIZEMAP[entry.args] + if entry.valtype == Valtype.MESSAGE: + return align(entry.args.fields[0].descriptor) + if entry.valtype == Valtype.ARRAY: + return align(entry.args[1]) + assert entry.valtype == Valtype.SEQUENCE + return 4 + + +def align_after(entry: Descriptor) -> int: + """Get alignment after entry. + + Args: + entry: Field. + + Returns: + Memory alignment after entry. + + """ + if entry.valtype == Valtype.BASE: + if entry.args == 'string': + return 1 + return SIZEMAP[entry.args] + if entry.valtype == Valtype.MESSAGE: + return align_after(entry.args.fields[-1].descriptor) + if entry.valtype == Valtype.ARRAY: + return align_after(entry.args[1]) + assert entry.valtype == Valtype.SEQUENCE + return min([4, align_after(entry.args)]) + + +def compile_lines(lines: List[str]) -> ModuleType: + """Compile lines of code to module. + + Args: + lines: Lines of python code. + + Returns: + Compiled and loaded module. + + """ + module = module_from_spec(spec_from_loader('tmpmod', loader=None)) + exec('\n'.join(lines), module.__dict__) # pylint: disable=exec-used + return module diff --git a/tests/cdr.py b/tests/cdr.py new file mode 100644 index 00000000..f1361891 --- /dev/null +++ b/tests/cdr.py @@ -0,0 +1,441 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Reference CDR message serializer and deserializer.""" + +from __future__ import annotations + +import sys +from struct import Struct, pack_into, unpack_from +from typing import TYPE_CHECKING, Dict, List, Union, cast + +import numpy + +from rosbags.serde.messages import SerdeError, get_msgdef +from rosbags.serde.typing import Msgdef +from rosbags.serde.utils import SIZEMAP, Valtype + +if TYPE_CHECKING: + from typing import Any, Tuple + + from rosbags.serde.typing import Descriptor + +Array = Union[List[Msgdef], List[str], numpy.ndarray] +BasetypeMap = Dict[str, Struct] +BASETYPEMAP_LE: BasetypeMap = { + 'bool': Struct('?'), + 'int8': Struct('b'), + 'int16': Struct('h'), + 'int32': Struct('>i'), + 'int64': Struct('>q'), + 'uint8': Struct('B'), + 'uint16': Struct('>H'), + 'uint32': Struct('>I'), + 'uint64': Struct('>Q'), + 'float32': Struct('>f'), + 'float64': Struct('>d'), +} + + +def deserialize_number(rawdata: bytes, bmap: BasetypeMap, pos: int, basetype: str) \ + -> Tuple[Union[bool, float, int], int]: + """Deserialize a single boolean, float, or int. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Read position. + basetype: Number type string. + + Returns: + Deserialized number and new read position. + + """ + dtype, size = bmap[basetype], SIZEMAP[basetype] + pos = (pos + size - 1) & -size + return dtype.unpack_from(rawdata, pos)[0], pos + size + + +def deserialize_string(rawdata: bytes, bmap: BasetypeMap, pos: int) \ + -> Tuple[str, int]: + """Deserialize a string value. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Read position. + + Returns: + Deserialized string and new read position. + + """ + pos = (pos + 4 - 1) & -4 + length = bmap['int32'].unpack_from(rawdata, pos)[0] + val = bytes(rawdata[pos + 4:pos + 4 + length - 1]) + return val.decode(), pos + 4 + length + + +def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, desc: Descriptor) \ + -> Tuple[Array, int]: + """Deserialize an array of items of same type. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Read position. + num: Number of elements. + desc: Element type descriptor. + + Returns: + Deserialized array and new read position. + + Raises: + SerdeError: Unexpected element type. + + """ + if desc.valtype == Valtype.BASE: + if desc.args == 'string': + strs = [] + while (num := num - 1) >= 0: + val, pos = deserialize_string(rawdata, bmap, pos) + strs.append(val) + return strs, pos + + size = SIZEMAP[desc.args] + pos = (pos + size - 1) & -size + ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) + if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'): + ndarr = ndarr.byteswap() # no inplace on readonly array + return ndarr, pos + num * SIZEMAP[desc.args] + + if desc.valtype == Valtype.MESSAGE: + msgs = [] + while (num := num - 1) >= 0: + msg, pos = deserialize_message(rawdata, bmap, pos, desc.args) + msgs.append(msg) + return msgs, pos + + raise SerdeError(f'Nested arrays {desc!r} are not supported.') + + +def deserialize_message(rawdata: bytes, bmap: BasetypeMap, pos: int, msgdef: Msgdef) \ + -> Tuple[Msgdef, int]: + """Deserialize a message. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Read position. + msgdef: Message definition. + + Returns: + Deserialized message and new read position. + + """ + values: List[Any] = [] + + for _, desc in msgdef.fields: + if desc.valtype == Valtype.MESSAGE: + obj, pos = deserialize_message(rawdata, bmap, pos, desc.args) + values.append(obj) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + val, pos = deserialize_string(rawdata, bmap, pos) + values.append(val) + else: + num, pos = deserialize_number(rawdata, bmap, pos, desc.args) + values.append(num) + + elif desc.valtype == Valtype.ARRAY: + arr, pos = deserialize_array(rawdata, bmap, pos, *desc.args) + values.append(arr) + + elif desc.valtype == Valtype.SEQUENCE: + size, pos = deserialize_number(rawdata, bmap, pos, 'int32') + arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args) + values.append(arr) + + return msgdef.cls(*values), pos + + +def deserialize(rawdata: bytes, typename: str) -> Msgdef: + """Deserialize raw data into a message object. + + Args: + rawdata: Serialized data. + typename: Type to deserialize. + + Returns: + Deserialized message object. + + """ + _, little_endian = unpack_from('BB', rawdata, 0) + + msgdef = get_msgdef(typename) + obj, _ = deserialize_message( + rawdata[4:], + BASETYPEMAP_LE if little_endian else BASETYPEMAP_BE, + 0, + msgdef, + ) + + return obj + + +def serialize_number( + rawdata: memoryview, + bmap: BasetypeMap, + pos: int, + basetype: str, + val: Union[bool, float, int], +) -> int: + """Serialize a single boolean, float, or int. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Write position. + basetype: Number type string. + val: Value to serialize. + + Returns: + Next write position. + + """ + dtype, size = bmap[basetype], SIZEMAP[basetype] + pos = (pos + size - 1) & -size + dtype.pack_into(rawdata, pos, val) + return pos + size + + +def serialize_string(rawdata: memoryview, bmap: BasetypeMap, pos: int, val: str) \ + -> int: + """Deserialize a string value. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Write position. + val: Value to serialize. + + Returns: + Next write position. + + """ + bval = memoryview(val.encode()) + length = len(bval) + 1 + + pos = (pos + 4 - 1) & -4 + bmap['int32'].pack_into(rawdata, pos, length) + rawdata[pos + 4:pos + 4 + length - 1] = bval + return pos + 4 + length + + +def serialize_array( + rawdata: memoryview, + bmap: BasetypeMap, + pos: int, + desc: Descriptor, + val: Array, +) -> int: + """Serialize an array of items of same type. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Write position. + desc: Element type descriptor. + val: Value to serialize. + + Returns: + Next write position. + + Raises: + SerdeError: Unexpected element type. + + """ + if desc.valtype == Valtype.BASE: + if desc.args == 'string': + for item in val: + pos = serialize_string(rawdata, bmap, pos, cast(str, item)) + return pos + + size = SIZEMAP[desc.args] + pos = (pos + size - 1) & -size + size *= len(val) + val = cast(numpy.ndarray, val) + if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'): + val = val.byteswap() # no inplace on readonly array + rawdata[pos:pos + size] = memoryview(val.tobytes()) + return pos + size + + if desc.valtype == Valtype.MESSAGE: + for item in val: + pos = serialize_message(rawdata, bmap, pos, item, desc.args) + return pos + + raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover + + +def serialize_message( + rawdata: memoryview, + bmap: BasetypeMap, + pos: int, + message: Any, + msgdef: Msgdef, +) -> int: + """Serialize a message. + + Args: + rawdata: Serialized data. + bmap: Basetype metadata. + pos: Write position. + message: Message object. + msgdef: Message definition. + + Returns: + Next write position. + + """ + for fieldname, desc in msgdef.fields: + val = getattr(message, fieldname) + if desc.valtype == Valtype.MESSAGE: + pos = serialize_message(rawdata, bmap, pos, val, desc.args) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + pos = serialize_string(rawdata, bmap, pos, val) + else: + pos = serialize_number(rawdata, bmap, pos, desc.args, val) + + elif desc.valtype == Valtype.ARRAY: + pos = serialize_array(rawdata, bmap, pos, desc.args[1], val) + + elif desc.valtype == Valtype.SEQUENCE: + size = len(val) + pos = serialize_number(rawdata, bmap, pos, 'int32', size) + pos = serialize_array(rawdata, bmap, pos, desc.args, val) + + return pos + + +def get_array_size(desc: Descriptor, val: Array, size: int) -> int: + """Calculate size of an array. + + Args: + desc: Element type descriptor. + val: Array to calculate size of. + size: Current size of message. + + Returns: + Size of val in bytes. + + Raises: + SerdeError: Unexpected element type. + + """ + if desc.valtype == Valtype.BASE: + if desc.args == 'string': + for item in val: + size = (size + 4 - 1) & -4 + size += 4 + len(item) + 1 + return size + + isize = SIZEMAP[desc.args] + size = (size + isize - 1) & -isize + return size + isize * len(val) + + if desc.valtype == Valtype.MESSAGE: + for item in val: + size = get_size(item, desc.args, size) + return size + + raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover + + +def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int: + """Calculate size of serialzied message. + + Args: + message: Message object. + msgdef: Message definition. + size: Current size of message. + + Returns: + Size of message in bytes. + + Raises: + SerdeError: Unexpected array length in message. + + """ + for fieldname, desc in msgdef.fields: + val = getattr(message, fieldname) + if desc.valtype == Valtype.MESSAGE: + size = get_size(val, desc.args, size) + + elif desc.valtype == Valtype.BASE: + if desc.args == 'string': + size = (size + 4 - 1) & -4 + size += 4 + len(val.encode()) + 1 + else: + isize = SIZEMAP[desc.args] + size = (size + isize - 1) & -isize + size += isize + + elif desc.valtype == Valtype.ARRAY: + if len(val) != desc.args[0]: + raise SerdeError(f'Unexpected array length: {len(val)} != {desc.args[0]}.') + size = get_array_size(desc.args[1], val, size) + + elif desc.valtype == Valtype.SEQUENCE: + size = (size + 4 - 1) & -4 + size += 4 + size = get_array_size(desc.args, val, size) + + return size + + +def serialize( + message: Any, + typename: str, + little_endian: bool = sys.byteorder == 'little', +) -> memoryview: + """Serialize message object to bytes. + + Args: + message: Message object. + typename: Type to serialize. + little_endian: Should use little endianess. + + Returns: + Serialized bytes. + + """ + msgdef = get_msgdef(typename) + size = 4 + get_size(message, msgdef) + rawdata = memoryview(bytearray(size)) + + pack_into('BB', rawdata, 0, 0, little_endian) + pos = serialize_message( + rawdata[4:], + BASETYPEMAP_LE if little_endian else BASETYPEMAP_BE, + 0, + message, + msgdef, + ) + assert pos + 4 == size + return rawdata.toreadonly() diff --git a/tests/test_serde.py b/tests/test_serde.py new file mode 100644 index 00000000..6353b277 --- /dev/null +++ b/tests/test_serde.py @@ -0,0 +1,382 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Serializer and deserializer tests.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import numpy +import pytest + +from rosbags.serde import SerdeError, deserialize_cdr, ros1_to_cdr, serialize_cdr +from rosbags.serde.messages import get_msgdef +from rosbags.typesys import get_types_from_msg, register_types + +from .cdr import deserialize, serialize + +if TYPE_CHECKING: + from typing import Any, Tuple, Union + +MSG_POLY = ( + ( + b'\x00\x01\x00\x00' # header + b'\x02\x00\x00\x00' # number of points = 2 + b'\x00\x00\x80\x3f' # x = 1 + b'\x00\x00\x00\x40' # y = 2 + b'\x00\x00\x40\x40' # z = 3 + b'\x00\x00\xa0\x3f' # x = 1.25 + b'\x00\x00\x10\x40' # y = 2.25 + b'\x00\x00\x50\x40' # z = 3.25 + ), + 'geometry_msgs/msg/Polygon', + True, +) + +MSG_MAGN = ( + ( + b'\x00\x01\x00\x00' # header + b'\xc4\x02\x00\x00\x00\x01\x00\x00' # timestamp = 708s 256ns + b'\x06\x00\x00\x00foo42\x00' # frameid 'foo42' + b'\x00\x00\x00\x00\x00\x00' # padding + b'\x00\x00\x00\x00\x00\x00\x60\x40' # x = 128 + b'\x00\x00\x00\x00\x00\x00\x60\x40' # y = 128 + b'\x00\x00\x00\x00\x00\x00\x60\x40' # z = 128 + b'\x00\x00\x00\x00\x00\x00\xF0\x3F' # covariance matrix = 3x3 diag + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\xF0\x3F' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\xF0\x3F' + ), + 'sensor_msgs/msg/MagneticField', + True, +) + +MSG_MAGN_BIG = ( + ( + b'\x00\x00\x00\x00' # header + b'\x00\x00\x02\xc4\x00\x00\x01\x00' # timestamp = 708s 256ns + b'\x00\x00\x00\x06foo42\x00' # frameid 'foo42' + b'\x00\x00\x00\x00\x00\x00' # padding + b'\x40\x60\x00\x00\x00\x00\x00\x00' # x = 128 + b'\x40\x60\x00\x00\x00\x00\x00\x00' # y = 128 + b'\x40\x60\x00\x00\x00\x00\x00\x00' # z = 128 + b'\x3F\xF0\x00\x00\x00\x00\x00\x00' # covariance matrix = 3x3 diag + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x3F\xF0\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x3F\xF0\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00' # garbage + ), + 'sensor_msgs/msg/MagneticField', + False, +) + +MSG_JOINT = ( + ( + b'\x00\x01\x00\x00' # header + b'\xc4\x02\x00\x00\x00\x01\x00\x00' # timestamp = 708s 256ns + b'\x04\x00\x00\x00bar\x00' # frameid 'bar' + b'\x02\x00\x00\x00' # number of strings + b'\x02\x00\x00\x00a\x00' # string 'a' + b'\x00\x00' # padding + b'\x02\x00\x00\x00b\x00' # string 'b' + b'\x00\x00' # padding + b'\x00\x00\x00\x00' # number of points + b'\x00\x00\x00' # garbage + ), + 'trajectory_msgs/msg/JointTrajectory', + True, +) + +MESSAGES = [MSG_POLY, MSG_MAGN, MSG_MAGN_BIG, MSG_JOINT] + +STATIC_64_64 = """ +uint64[2] u64 +""" + +STATIC_64_16 = """ +uint64 u64 +uint16 u16 +""" + +STATIC_16_64 = """ +uint16 u16 +uint64 u64 +""" + +DYNAMIC_64_64 = """ +uint64[] u64 +""" + +DYNAMIC_64_B_64 = """ +uint64 u64 +bool b +float64 f64 +""" + +DYNAMIC_64_S = """ +uint64 u64 +string s +""" + +DYNAMIC_S_64 = """ +string s +uint64 u64 +""" + +CUSTOM = """ +string base_str +float32 base_f32 +test_msgs/msg/static_64_64 msg_s66 +test_msgs/msg/static_64_16 msg_s61 +test_msgs/msg/static_16_64 msg_s16 +test_msgs/msg/dynamic_64_64 msg_d66 +test_msgs/msg/dynamic_64_b_64 msg_d6b6 +test_msgs/msg/dynamic_64_s msg_d6s +test_msgs/msg/dynamic_s_64 msg_ds6 + +string[2] arr_base_str +float32[2] arr_base_f32 +test_msgs/msg/static_64_64[2] arr_msg_s66 +test_msgs/msg/static_64_16[2] arr_msg_s61 +test_msgs/msg/static_16_64[2] arr_msg_s16 +test_msgs/msg/dynamic_64_64[2] arr_msg_d66 +test_msgs/msg/dynamic_64_b_64[2] arr_msg_d6b6 +test_msgs/msg/dynamic_64_s[2] arr_msg_d6s +test_msgs/msg/dynamic_s_64[2] arr_msg_ds6 + +string[] seq_base_str +float32[] seq_base_f32 +test_msgs/msg/static_64_64[] seq_msg_s66 +test_msgs/msg/static_64_16[] seq_msg_s61 +test_msgs/msg/static_16_64[] seq_msg_s16 +test_msgs/msg/dynamic_64_64[] seq_msg_d66 +test_msgs/msg/dynamic_64_b_64[] seq_msg_d6b6 +test_msgs/msg/dynamic_64_s[] seq_msg_d6s +test_msgs/msg/dynamic_s_64[] seq_msg_ds6 +""" + + +@pytest.fixture() +def _comparable(): + """Make messages containing numpy arrays comparable. + + Notes: + This solution is necessary as numpy.ndarray is not directly patchable. + """ + frombuffer = numpy.frombuffer + + def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool: + return (getattr(self, '_mock_wraps') == getattr(other, '_mock_wraps', other)).all() + + class CNDArray(MagicMock): + """Mock ndarray.""" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.__eq__ = arreq # type: ignore + + def byteswap(self, *args: Any) -> 'CNDArray': + """Wrap return value also in mock.""" + return CNDArray(wraps=self._mock_wraps.byteswap(*args)) + + def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: + return CNDArray(wraps=frombuffer(*args, **kwargs)) + + with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer): + yield + + +@pytest.mark.parametrize('message', MESSAGES) +def test_serde(message: Tuple[bytes, str, bool]): + """Test serialization deserialization roundtrip.""" + rawdata, typ, is_little = message + + serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little) + assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little) + assert serdeser == rawdata[0:len(serdeser)] + assert len(rawdata) - len(serdeser) < 4 + assert all(x == 0 for x in rawdata[len(serdeser):]) + + +@pytest.mark.usefixtures('_comparable') +def test_deserializer(): + """Test deserializer.""" + msg = deserialize_cdr(*MSG_POLY[:2]) + assert msg == deserialize(*MSG_POLY[:2]) + assert len(msg.points) == 2 + assert msg.points[0].x == 1 + assert msg.points[0].y == 2 + assert msg.points[0].z == 3 + assert msg.points[1].x == 1.25 + assert msg.points[1].y == 2.25 + assert msg.points[1].z == 3.25 + + msg = deserialize_cdr(*MSG_MAGN[:2]) + assert msg == deserialize(*MSG_MAGN[:2]) + assert 'MagneticField' in repr(msg) + assert msg.header.stamp.sec == 708 + assert msg.header.stamp.nanosec == 256 + assert msg.header.frame_id == 'foo42' + field = msg.magnetic_field + assert (field.x, field.y, field.z) == (128., 128., 128.) + assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all() + + msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2]) + assert msg_big == deserialize(*MSG_MAGN_BIG[:2]) + assert msg.magnetic_field == msg_big.magnetic_field + + +@pytest.mark.usefixtures('_comparable') +def test_serializer(): + """Test serializer.""" + + class Foo: # pylint: disable=too-few-public-methods + """Dummy class.""" + + data = 7 + + msg = Foo() + ret = serialize_cdr(msg, 'std_msgs/msg/Int8', True) + assert ret == serialize(msg, 'std_msgs/msg/Int8', True) + assert ret == b'\x00\x01\x00\x00\x07' + + ret = serialize_cdr(msg, 'std_msgs/msg/Int8', False) + assert ret == serialize(msg, 'std_msgs/msg/Int8', False) + assert ret == b'\x00\x00\x00\x00\x07' + + ret = serialize_cdr(msg, 'std_msgs/msg/Int16', True) + assert ret == serialize(msg, 'std_msgs/msg/Int16', True) + assert ret == b'\x00\x01\x00\x00\x07\x00' + + ret = serialize_cdr(msg, 'std_msgs/msg/Int16', False) + assert ret == serialize(msg, 'std_msgs/msg/Int16', False) + assert ret == b'\x00\x00\x00\x00\x00\x07' + + +@pytest.mark.usefixtures('_comparable') +def test_serializer_errors(): + """Test seralizer with broken messages.""" + + class Foo: # pylint: disable=too-few-public-methods + """Dummy class.""" + + coef = numpy.array([1, 2, 3, 4]) + + msg = Foo() + ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True) + assert ret == serialize(msg, 'shape_msgs/msg/Plane', True) + + msg.coef = numpy.array([1, 2, 3, 4, 4]) + with pytest.raises(SerdeError, match='array length'): + serialize_cdr(msg, 'shape_msgs/msg/Plane', True) + + +@pytest.mark.usefixtures('_comparable') +def test_custom_type(): + """Test custom type.""" + cname = 'test_msgs/msg/custom' + register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64'))) + register_types(dict(get_types_from_msg(STATIC_64_16, 'test_msgs/msg/static_64_16'))) + register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) + register_types(dict(get_types_from_msg(DYNAMIC_64_64, 'test_msgs/msg/dynamic_64_64'))) + register_types(dict(get_types_from_msg(DYNAMIC_64_B_64, 'test_msgs/msg/dynamic_64_b_64'))) + register_types(dict(get_types_from_msg(DYNAMIC_64_S, 'test_msgs/msg/dynamic_64_s'))) + register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64'))) + register_types(dict(get_types_from_msg(CUSTOM, cname))) + + static_64_64 = get_msgdef('test_msgs/msg/static_64_64').cls + static_64_16 = get_msgdef('test_msgs/msg/static_64_16').cls + static_16_64 = get_msgdef('test_msgs/msg/static_16_64').cls + dynamic_64_64 = get_msgdef('test_msgs/msg/dynamic_64_64').cls + dynamic_64_b_64 = get_msgdef('test_msgs/msg/dynamic_64_b_64').cls + dynamic_64_s = get_msgdef('test_msgs/msg/dynamic_64_s').cls + dynamic_s_64 = get_msgdef('test_msgs/msg/dynamic_s_64').cls + custom = get_msgdef('test_msgs/msg/custom').cls + + msg = custom( + 'str', + 1.5, + static_64_64(numpy.array([64, 64], dtype=numpy.uint64)), + static_64_16(64, 16), + static_16_64(16, 64), + dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)), + dynamic_64_b_64(64, True, 1.25), + dynamic_64_s(64, 's'), + dynamic_s_64('s', 64), + # arrays + ['str_1', ''], + numpy.array([1.5, 0.75], dtype=numpy.float32), + [ + static_64_64(numpy.array([64, 64], dtype=numpy.uint64)), + static_64_64(numpy.array([64, 64], dtype=numpy.uint64)), + ], + [static_64_16(64, 16), static_64_16(64, 16)], + [static_16_64(16, 64), static_16_64(16, 64)], + [ + dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)), + dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)), + ], + [ + dynamic_64_b_64(64, True, 1.25), + dynamic_64_b_64(64, True, 1.25), + ], + [dynamic_64_s(64, 's'), dynamic_64_s(64, 's')], + [dynamic_s_64('s', 64), dynamic_s_64('s', 64)], + # sequences + ['str_1', ''], + numpy.array([1.5, 0.75], dtype=numpy.float32), + [ + static_64_64(numpy.array([64, 64], dtype=numpy.uint64)), + static_64_64(numpy.array([64, 64], dtype=numpy.uint64)), + ], + [static_64_16(64, 16), static_64_16(64, 16)], + [static_16_64(16, 64), static_16_64(16, 64)], + [ + dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)), + dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)), + ], + [ + dynamic_64_b_64(64, True, 1.25), + dynamic_64_b_64(64, True, 1.25), + ], + [dynamic_64_s(64, 's'), dynamic_64_s(64, 's')], + [dynamic_s_64('s', 64), dynamic_s_64('s', 64)], + ) + + res = deserialize_cdr(serialize_cdr(msg, cname), cname) + assert res == deserialize(serialize(msg, cname), cname) + assert res == msg + + +def test_ros1_to_cdr(): + """Test ROS1 to CDR conversion.""" + register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) + msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') + msg_cdr = ( + b'\x00\x01\x00\x00' + b'\x01\x00' + b'\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x02' + ) + assert ros1_to_cdr(msg_ros, 'test_msgs/msg/static_16_64') == msg_cdr + + register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64'))) + msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02') + msg_cdr = ( + b'\x00\x01\x00\x00' + b'\x02\x00\x00\x00X\x00' + b'\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x02' + ) + assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr