From 9f3da0c2be974d9060c5337408f9d909fe0eb736 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Wed, 4 May 2022 17:40:50 +0200 Subject: [PATCH] Fix serialization of empty message sequences --- src/rosbags/serde/cdr.py | 6 ++++-- tests/test_serde.py | 23 +++++++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/rosbags/serde/cdr.py b/src/rosbags/serde/cdr.py index 894cc5fc..0ef393af 100644 --- a/src/rosbags/serde/cdr.py +++ b/src/rosbags/serde/cdr.py @@ -133,7 +133,8 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]: lines.append(f' val = message.{fieldname}') if subdesc.args.size_cdr: if aligned < anext_before <= anext_after: - lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + lines.append(' if len(val):') + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' for _ in val:') if anext_before > anext_after: lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') @@ -144,7 +145,8 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]: f' func = get_msgdef("{subdesc.args.name}", typestore).getsize_cdr', ) if aligned < anext_before <= anext_after: - lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + lines.append(' if len(val):') + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' for item in val:') if anext_before > anext_after: lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') diff --git a/tests/test_serde.py b/tests/test_serde.py index 276bcd6c..8142ac79 100644 --- a/tests/test_serde.py +++ b/tests/test_serde.py @@ -180,6 +180,11 @@ uint64[] su64 uint64 u64 """ +SMSG_U64 = """ +su64_u64[] seq +uint64 u64 +""" + @pytest.fixture() def _comparable() -> Generator[None, None, None]: @@ -457,16 +462,22 @@ def test_padding_empty_sequence() -> None: def test_align_after_empty_sequence() -> None: """Test alignment after empty sequences.""" register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64'))) + register_types(dict(get_types_from_msg(SMSG_U64, 'test_msgs/msg/smsg_u64'))) - su64_b = get_msgdef('test_msgs/msg/su64_u64', types).cls - msg = su64_b(numpy.array([], dtype=numpy.uint64), 42) + su64_u64 = get_msgdef('test_msgs/msg/su64_u64', types).cls + smsg_u64 = get_msgdef('test_msgs/msg/smsg_u64', types).cls + msg1 = su64_u64(numpy.array([], dtype=numpy.uint64), 42) + msg2 = smsg_u64([], 42) - cdr = serialize_cdr(msg, msg.__msgtype__) + cdr = serialize_cdr(msg1, msg1.__msgtype__) assert cdr[4:] == b'\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x00\x00\x00\x00\x00\x00\x00' + assert serialize_cdr(msg2, msg2.__msgtype__) == cdr - ros1 = cdr_to_ros1(cdr, msg.__msgtype__) + ros1 = cdr_to_ros1(cdr, msg1.__msgtype__) assert ros1 == b'\x00\x00\x00\x00\x2a\x00\x00\x00\x00\x00\x00\x00' + assert cdr_to_ros1(cdr, msg2.__msgtype__) == ros1 - assert ros1_to_cdr(ros1, msg.__msgtype__) == cdr + assert ros1_to_cdr(ros1, msg1.__msgtype__) == cdr - assert deserialize_cdr(cdr, msg.__msgtype__) == msg + assert deserialize_cdr(cdr, msg1.__msgtype__) == msg1 + assert deserialize_cdr(cdr, msg2.__msgtype__) == msg2