From 16d1758327469047e8559361465148fbfc3d3081 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Wed, 13 Apr 2022 09:40:22 +0200 Subject: [PATCH] Unify rosbag1 and rosbag2 connection class --- docs/examples/edit_rosbags_edit_timestamps.py | 8 +- docs/examples/edit_rosbags_remove_topic.py | 8 +- src/rosbags/convert/converter.py | 89 ++++++--- src/rosbags/interfaces/__init__.py | 36 ++++ src/rosbags/interfaces/py.typed | 0 src/rosbags/rosbag1/reader.py | 27 ++- src/rosbags/rosbag1/writer.py | 19 +- src/rosbags/rosbag2/connection.py | 18 -- src/rosbags/rosbag2/reader.py | 15 +- src/rosbags/rosbag2/writer.py | 32 +-- tests/test_convert.py | 185 +++++++++++++----- tests/test_roundtrip.py | 4 +- tests/test_writer.py | 10 +- 13 files changed, 301 insertions(+), 150 deletions(-) create mode 100644 src/rosbags/interfaces/__init__.py create mode 100644 src/rosbags/interfaces/py.typed delete mode 100644 src/rosbags/rosbag2/connection.py diff --git a/docs/examples/edit_rosbags_edit_timestamps.py b/docs/examples/edit_rosbags_edit_timestamps.py index 36e2cdc1..6de70aae 100644 --- a/docs/examples/edit_rosbags_edit_timestamps.py +++ b/docs/examples/edit_rosbags_edit_timestamps.py @@ -2,8 +2,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast +from rosbags.interfaces import ConnectionExtRosbag2 from rosbags.rosbag2 import Reader, Writer from rosbags.serde import deserialize_cdr, serialize_cdr @@ -23,11 +24,12 @@ def offset_timestamps(src: Path, dst: Path, offset: int) -> None: with Reader(src) as reader, Writer(dst) as writer: conn_map = {} for conn in reader.connections.values(): + ext = cast(ConnectionExtRosbag2, conn.ext) conn_map[conn.id] = writer.add_connection( conn.topic, conn.msgtype, - conn.serialization_format, - conn.offered_qos_profiles, + ext.serialization_format, + ext.offered_qos_profiles, ) for conn, timestamp, data in reader.messages(): diff --git a/docs/examples/edit_rosbags_remove_topic.py b/docs/examples/edit_rosbags_remove_topic.py index 0e416a84..aa00dabd 100644 --- a/docs/examples/edit_rosbags_remove_topic.py +++ b/docs/examples/edit_rosbags_remove_topic.py @@ -2,8 +2,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast +from rosbags.interfaces import ConnectionExtRosbag2 from rosbags.rosbag2 import Reader, Writer if TYPE_CHECKING: @@ -24,11 +25,12 @@ def remove_topic(src: Path, dst: Path, topic: str) -> None: for conn in reader.connections.values(): if conn.topic == topic: continue + ext = cast(ConnectionExtRosbag2, conn.ext) conn_map[conn.id] = writer.add_connection( conn.topic, conn.msgtype, - conn.serialization_format, - conn.offered_qos_profiles, + ext.serialization_format, + ext.offered_qos_profiles, ) rconns = [reader.connections[x] for x in conn_map] diff --git a/src/rosbags/convert/converter.py b/src/rosbags/convert/converter.py index 03af6168..ce47674c 100644 --- a/src/rosbags/convert/converter.py +++ b/src/rosbags/convert/converter.py @@ -4,19 +4,17 @@ from __future__ import annotations -from dataclasses import asdict from typing import TYPE_CHECKING +from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2 from rosbags.rosbag1 import Reader as Reader1 from rosbags.rosbag1 import ReaderError as ReaderError1 from rosbags.rosbag1 import Writer as Writer1 from rosbags.rosbag1 import WriterError as WriterError1 -from rosbags.rosbag1.reader import Connection as Connection1 from rosbags.rosbag2 import Reader as Reader2 from rosbags.rosbag2 import ReaderError as ReaderError2 from rosbags.rosbag2 import Writer as Writer2 from rosbags.rosbag2 import WriterError as WriterError2 -from rosbags.rosbag2.connection import Connection as Connection2 from rosbags.serde import cdr_to_ros1, ros1_to_cdr from rosbags.typesys import get_types_from_msg, register_types from rosbags.typesys.msg import generate_msgdef @@ -48,7 +46,7 @@ class ConverterError(Exception): """Converter Error.""" -def upgrade_connection(rconn: Connection1) -> Connection2: +def upgrade_connection(rconn: Connection) -> Connection: """Convert rosbag1 connection to rosbag2 connection. Args: @@ -58,17 +56,22 @@ def upgrade_connection(rconn: Connection1) -> Connection2: Rosbag2 connection. """ - return Connection2( - -1, - 0, + assert isinstance(rconn.ext, ConnectionExtRosbag1) + return Connection( + rconn.id, rconn.topic, rconn.msgtype, - 'cdr', - LATCH if rconn.latching else '', + '', + '', + 0, + ConnectionExtRosbag2( + 'cdr', + LATCH if rconn.ext.latching else '', + ), ) -def downgrade_connection(rconn: Connection2) -> Connection1: +def downgrade_connection(rconn: Connection) -> Connection: """Convert rosbag2 connection to rosbag1 connection. Args: @@ -78,15 +81,19 @@ def downgrade_connection(rconn: Connection2) -> Connection1: Rosbag1 connection. """ + assert isinstance(rconn.ext, ConnectionExtRosbag2) msgdef, md5sum = generate_msgdef(rconn.msgtype) - return Connection1( - -1, + return Connection( + rconn.id, rconn.topic, rconn.msgtype, msgdef, md5sum, - None, - int('durability: 1' in rconn.offered_qos_profiles), + -1, + ConnectionExtRosbag1( + None, + int('durability: 1' in rconn.ext.offered_qos_profiles), + ), ) @@ -100,13 +107,26 @@ def convert_1to2(src: Path, dst: Path) -> None: """ with Reader1(src) as reader, Writer2(dst) as writer: typs: dict[str, Any] = {} - connmap: dict[int, Connection2] = {} + connmap: dict[int, Connection] = {} for rconn in reader.connections.values(): candidate = upgrade_connection(rconn) - existing = next((x for x in writer.connections.values() if x == candidate), None) - wconn = existing if existing else writer.add_connection(**asdict(candidate)) - connmap[rconn.id] = wconn + assert isinstance(candidate.ext, ConnectionExtRosbag2) + for conn in writer.connections.values(): + assert isinstance(conn.ext, ConnectionExtRosbag2) + if ( + conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and + conn.ext == candidate.ext + ): + break + else: + conn = writer.add_connection( + candidate.topic, + candidate.msgtype, + candidate.ext.serialization_format, + candidate.ext.offered_qos_profiles, + ) + connmap[rconn.id] = conn typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype)) register_types(typs) @@ -124,22 +144,27 @@ def convert_2to1(src: Path, dst: Path) -> None: """ with Reader2(src) as reader, Writer1(dst) as writer: - connmap: dict[int, Connection1] = {} + connmap: dict[int, Connection] = {} for rconn in reader.connections.values(): candidate = downgrade_connection(rconn) - # yapf: disable - existing = next( - ( - x - for x in writer.connections.values() - if x.topic == candidate.topic - if x.md5sum == candidate.md5sum - if x.latching == candidate.latching - ), - None, - ) - # yapf: enable - connmap[rconn.id] = existing if existing else writer.add_connection(*candidate[1:]) + assert isinstance(candidate.ext, ConnectionExtRosbag1) + for conn in writer.connections.values(): + assert isinstance(conn.ext, ConnectionExtRosbag1) + if ( + conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and + conn.ext.latching == candidate.ext.latching + ): + break + else: + conn = writer.add_connection( + candidate.topic, + candidate.msgtype, + candidate.msgdef, + candidate.md5sum, + candidate.ext.callerid, + candidate.ext.latching, + ) + connmap[rconn.id] = conn for rconn, timestamp, data in reader.messages(): data = cdr_to_ros1(data, rconn.msgtype) diff --git a/src/rosbags/interfaces/__init__.py b/src/rosbags/interfaces/__init__.py new file mode 100644 index 00000000..df5b221c --- /dev/null +++ b/src/rosbags/interfaces/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2020-2022 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Shared interfaces.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from typing import Optional, Union + + +class ConnectionExtRosbag1(NamedTuple): + """Rosbag1 specific connection extensions.""" + + callerid: Optional[str] + latching: Optional[int] + + +class ConnectionExtRosbag2(NamedTuple): + """Rosbag2 specific connection extensions.""" + + serialization_format: str + offered_qos_profiles: str + + +class Connection(NamedTuple): + """Connection information.""" + + id: int + topic: str + msgtype: str + msgdef: str + md5sum: str + msgcount: int + ext: Union[ConnectionExtRosbag1, ConnectionExtRosbag2] diff --git a/src/rosbags/interfaces/py.typed b/src/rosbags/interfaces/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/rosbags/rosbag1/reader.py b/src/rosbags/rosbag1/reader.py index 8798c529..99ec9014 100644 --- a/src/rosbags/rosbag1/reader.py +++ b/src/rosbags/rosbag1/reader.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, NamedTuple from lz4.frame import decompress as lz4_decompress +from rosbags.interfaces import Connection, ConnectionExtRosbag1 from rosbags.typesys.msg import normalize_msgtype if TYPE_CHECKING: @@ -50,18 +51,6 @@ class RecordType(IntEnum): CONNECTION = 7 -class Connection(NamedTuple): - """Connection information.""" - - id: int - topic: str - msgtype: str - msgdef: str - md5sum: str - callerid: Optional[str] - latching: Optional[int] - - class ChunkInfo(NamedTuple): """Chunk information.""" @@ -427,6 +416,13 @@ class Reader: } assert all(self.indexes[x] for x in self.connections) + for cid, connection in self.connections.items(): + self.connections[cid] = Connection( + *connection[0:5], + len(self.indexes[cid]), + connection[6], + ) + self.topics = {} for topic, group in groupby( sorted(self.connections.values(), key=lambda x: x.topic), @@ -499,8 +495,11 @@ class Reader: normalize_msgtype(typ), msgdef, md5sum, - callerid, - latching, + 0, + ConnectionExtRosbag1( + callerid, + latching, + ), ) def read_chunk_info(self) -> ChunkInfo: diff --git a/src/rosbags/rosbag1/writer.py b/src/rosbags/rosbag1/writer.py index 8063757f..f71e0366 100644 --- a/src/rosbags/rosbag1/writer.py +++ b/src/rosbags/rosbag1/writer.py @@ -15,9 +15,10 @@ from typing import TYPE_CHECKING, Any, Dict from lz4.frame import compress as lz4_compress +from rosbags.interfaces import Connection, ConnectionExtRosbag1 from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef -from .reader import Connection, RecordType +from .reader import RecordType if TYPE_CHECKING: from types import TracebackType @@ -249,8 +250,11 @@ class Writer: denormalize_msgtype(msgtype), msgdef, md5sum, - callerid, - latching, + -1, + ConnectionExtRosbag1( + callerid, + latching, + ), ) if any(x[1:] == connection[1:] for x in self.connections.values()): @@ -314,10 +318,11 @@ class Writer: header.set_string('type', connection.msgtype) header.set_string('md5sum', connection.md5sum) header.set_string('message_definition', connection.msgdef) - if connection.callerid is not None: - header.set_string('callerid', connection.callerid) - if connection.latching is not None: - header.set_string('latching', str(connection.latching)) + assert isinstance(connection.ext, ConnectionExtRosbag1) + if connection.ext.callerid is not None: + header.set_string('callerid', connection.ext.callerid) + if connection.ext.latching is not None: + header.set_string('latching', str(connection.ext.latching)) header.write(bio) def write_chunk(self, chunk: WriteChunk) -> None: diff --git a/src/rosbags/rosbag2/connection.py b/src/rosbags/rosbag2/connection.py deleted file mode 100644 index 73d5daa3..00000000 --- a/src/rosbags/rosbag2/connection.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2020-2022 Ternaris. -# SPDX-License-Identifier: Apache-2.0 -"""Rosbag2 connection.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - - -@dataclass -class Connection: - """Connection metadata.""" - id: int = field(compare=False) # pylint: disable=invalid-name - count: int = field(compare=False) - topic: str - msgtype: str - serialization_format: str - offered_qos_profiles: str diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index b3145aa3..72ea005d 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -14,7 +14,7 @@ import zstandard from ruamel.yaml import YAML from ruamel.yaml.error import YAMLError -from .connection import Connection +from rosbags.interfaces import Connection, ConnectionExtRosbag2 if TYPE_CHECKING: from types import TracebackType @@ -139,15 +139,20 @@ class Reader: self.connections = { idx + 1: Connection( id=idx + 1, - count=x['message_count'], topic=x['topic_metadata']['name'], msgtype=x['topic_metadata']['type'], - serialization_format=x['topic_metadata']['serialization_format'], - offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''), + msgdef='', + md5sum='', + msgcount=x['message_count'], + ext=ConnectionExtRosbag2( + serialization_format=x['topic_metadata']['serialization_format'], + offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''), + ), ) for idx, x in enumerate(self.metadata['topics_with_message_count']) } noncdr = { - fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr' + fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2) + if (fmt := x.ext.serialization_format) != 'cdr' } if noncdr: raise ReaderError(f'Serialization format {noncdr!r} is not supported.') diff --git a/src/rosbags/rosbag2/writer.py b/src/rosbags/rosbag2/writer.py index 027647ac..f7263148 100644 --- a/src/rosbags/rosbag2/writer.py +++ b/src/rosbags/rosbag2/writer.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING import zstandard from ruamel.yaml import YAML -from .connection import Connection +from rosbags.interfaces import Connection, ConnectionExtRosbag2 if TYPE_CHECKING: from types import TracebackType @@ -82,6 +82,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compression_format = '' self.compressor: Optional[zstandard.ZstdCompressor] = None self.connections: dict[int, Connection] = {} + self.counts: dict[int, int] = {} self.conn: Optional[sqlite3.Connection] = None self.cursor: Optional[sqlite3.Cursor] = None @@ -152,16 +153,25 @@ class Writer: # pylint: disable=too-many-instance-attributes connection = Connection( id=len(self.connections.values()) + 1, - count=0, topic=topic, msgtype=msgtype, - serialization_format=serialization_format, - offered_qos_profiles=offered_qos_profiles, + msgdef='', + md5sum='', + msgcount=0, + ext=ConnectionExtRosbag2( + serialization_format=serialization_format, + offered_qos_profiles=offered_qos_profiles, + ), ) - if connection in self.connections.values(): - raise WriterError(f'Connection can only be added once: {connection!r}.') + for conn in self.connections.values(): + if ( + conn.topic == connection.topic and conn.msgtype == connection.msgtype and + conn.ext == connection.ext + ): + raise WriterError(f'Connection can only be added once: {connection!r}.') self.connections[connection.id] = connection + self.counts[connection.id] = 0 meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles) self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) return connection @@ -191,7 +201,7 @@ class Writer: # pylint: disable=too-many-instance-attributes 'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)', (connection.id, timestamp, data), ) - connection.count += 1 + self.counts[connection.id] += 1 def close(self) -> None: """Close rosbag2 after writing. @@ -237,11 +247,11 @@ class Writer: # pylint: disable=too-many-instance-attributes 'topic_metadata': { 'name': x.topic, 'type': x.msgtype, - 'serialization_format': x.serialization_format, - 'offered_qos_profiles': x.offered_qos_profiles, + 'serialization_format': x.ext.serialization_format, + 'offered_qos_profiles': x.ext.offered_qos_profiles, }, - 'message_count': x.count, - } for x in self.connections.values() + 'message_count': self.counts[x.id], + } for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2) ], 'compression_format': self.compression_format, 'compression_mode': self.compression_mode, diff --git a/tests/test_convert.py b/tests/test_convert.py index fa93351b..6d030d8f 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -2,18 +2,25 @@ # SPDX-License-Identifier: Apache-2.0 """Rosbag1to2 converter tests.""" +from __future__ import annotations + import sys from pathlib import Path -from unittest.mock import Mock, call, patch +from typing import TYPE_CHECKING +from unittest.mock import call, patch import pytest from rosbags.convert import ConverterError, convert from rosbags.convert.__main__ import main from rosbags.convert.converter import LATCH +from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2 from rosbags.rosbag1 import ReaderError from rosbags.rosbag2 import WriterError +if TYPE_CHECKING: + from typing import Any + def test_cliwrapper(tmp_path: Path) -> None: """Test cli wrapper.""" @@ -114,71 +121,83 @@ def test_convert_1to2(tmp_path: Path) -> None: patch('rosbags.convert.converter.register_types') as register_types, \ patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr: + readerinst = reader.return_value.__enter__.return_value + writerinst = writer.return_value.__enter__.return_value + connections = [ - Mock(topic='/topic', msgtype='typ', latching=False), - Mock(topic='/topic', msgtype='typ', latching=True), + Connection(1, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False)), + Connection(2, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, True)), + Connection(3, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False)), + Connection(4, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1('caller', False)), ] wconnections = [ - Mock(topic='/topic', msgtype='typ'), - Mock(topic='/topic', msgtype='typ'), + Connection(1, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', '')), + Connection(2, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', LATCH)), + Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', '')), ] - reader.return_value.__enter__.return_value.connections = { + readerinst.connections = { 1: connections[0], 2: connections[1], + 3: connections[2], + 4: connections[3], } - reader.return_value.__enter__.return_value.messages.return_value = [ + readerinst.messages.return_value = [ (connections[0], 42, b'\x42'), (connections[1], 43, b'\x43'), + (connections[2], 44, b'\x44'), + (connections[3], 45, b'\x45'), ] - writer.return_value.__enter__.return_value.add_connection.side_effect = [ - wconnections[0], - wconnections[1], - ] + writerinst.connections = {} + + def add_connection(*_: Any) -> Connection: # noqa: ANN401 + """Mock for Writer.add_connection.""" + writerinst.connections = { + conn.id: conn + for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) + } + return wconnections[len(writerinst.connections) - 1] + + writerinst.add_connection.side_effect = add_connection ros1_to_cdr.return_value = b'666' convert(Path('foo.bag'), None) reader.assert_called_with(Path('foo.bag')) - reader.return_value.__enter__.return_value.messages.assert_called_with() + readerinst.messages.assert_called_with() writer.assert_called_with(Path('foo')) - writer.return_value.__enter__.return_value.add_connection.assert_has_calls( + writerinst.add_connection.assert_has_calls( [ - call( - id=-1, - count=0, - topic='/topic', - msgtype='typ', - serialization_format='cdr', - offered_qos_profiles='', - ), - call( - id=-1, - count=0, - topic='/topic', - msgtype='typ', - serialization_format='cdr', - offered_qos_profiles=LATCH, - ), + call('/topic', 'typ', 'cdr', ''), + call('/topic', 'typ', 'cdr', LATCH), + call('/other', 'typ', 'cdr', ''), ], ) - writer.return_value.__enter__.return_value.write.assert_has_calls( - [call(wconnections[0], 42, b'666'), - call(wconnections[1], 43, b'666')], + writerinst.write.assert_has_calls( + [ + call(wconnections[0], 42, b'666'), + call(wconnections[1], 43, b'666'), + call(wconnections[2], 44, b'666'), + call(wconnections[2], 45, b'666'), + ], ) register_types.assert_called_with({'typ': 'def'}) - ros1_to_cdr.assert_has_calls([call(b'\x42', 'typ'), call(b'\x43', 'typ')]) + ros1_to_cdr.assert_has_calls( + [ + call(b'\x42', 'typ'), + call(b'\x43', 'typ'), + call(b'\x44', 'typ'), + call(b'\x45', 'typ'), + ], + ) - writer.return_value.__enter__.return_value.add_connection.side_effect = [ - wconnections[0], - wconnections[1], - ] + writerinst.connections.clear() ros1_to_cdr.side_effect = KeyError('exc') with pytest.raises(ConverterError, match='Converting rosbag: .*exc'): convert(Path('foo.bag'), None) @@ -204,30 +223,79 @@ def test_convert_2to1(tmp_path: Path) -> None: patch('rosbags.convert.converter.Writer1') as writer, \ patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1: + readerinst = reader.return_value.__enter__.return_value + writerinst = writer.return_value.__enter__.return_value + connections = [ - Mock(topic='/topic', msgtype='std_msgs/msg/Bool', offered_qos_profiles=''), - Mock(topic='/topic', msgtype='std_msgs/msg/Bool', offered_qos_profiles=LATCH), + Connection(1, '/topic', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '')), + Connection( + 2, + '/topic', + 'std_msgs/msg/Bool', + '', + '', + -1, + ConnectionExtRosbag2('', LATCH), + ), + Connection(3, '/other', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '')), + Connection(4, '/other', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '0')), ] wconnections = [ - Mock(topic='/topic', msgtype='typ'), - Mock(topic='/topic', msgtype='typ'), + Connection( + 1, + '/topic', + 'std_msgs/msg/Bool', + '', + '8b94c1b53db61fb6aed406028ad6332a', + -1, + ConnectionExtRosbag1(None, False), + ), + Connection( + 2, + '/topic', + 'std_msgs/msg/Bool', + '', + '8b94c1b53db61fb6aed406028ad6332a', + -1, + ConnectionExtRosbag1(None, True), + ), + Connection( + 3, + '/other', + 'std_msgs/msg/Bool', + '', + '8b94c1b53db61fb6aed406028ad6332a', + -1, + ConnectionExtRosbag1(None, False), + ), ] - reader.return_value.__enter__.return_value.connections = { + readerinst.connections = { 1: connections[0], 2: connections[1], + 3: connections[2], + 4: connections[3], } - reader.return_value.__enter__.return_value.messages.return_value = [ + readerinst.messages.return_value = [ (connections[0], 42, b'\x42'), (connections[1], 43, b'\x43'), + (connections[2], 44, b'\x44'), + (connections[3], 45, b'\x45'), ] - writer.return_value.__enter__.return_value.add_connection.side_effect = [ - wconnections[0], - wconnections[1], - ] + writerinst.connections = {} + + def add_connection(*_: Any) -> Connection: # noqa: ANN401 + """Mock for Writer.add_connection.""" + writerinst.connections = { + conn.id: conn + for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) + } + return wconnections[len(writerinst.connections) - 1] + + writerinst.add_connection.side_effect = add_connection cdr_to_ros1.return_value = b'666' @@ -255,24 +323,35 @@ def test_convert_2to1(tmp_path: Path) -> None: None, 1, ), + call( + '/other', + 'std_msgs/msg/Bool', + 'bool data\n', + '8b94c1b53db61fb6aed406028ad6332a', + None, + 0, + ), ], ) writer.return_value.__enter__.return_value.write.assert_has_calls( - [call(wconnections[0], 42, b'666'), - call(wconnections[1], 43, b'666')], + [ + call(wconnections[0], 42, b'666'), + call(wconnections[1], 43, b'666'), + call(wconnections[2], 44, b'666'), + call(wconnections[2], 45, b'666'), + ], ) cdr_to_ros1.assert_has_calls( [ call(b'\x42', 'std_msgs/msg/Bool'), call(b'\x43', 'std_msgs/msg/Bool'), + call(b'\x44', 'std_msgs/msg/Bool'), + call(b'\x45', 'std_msgs/msg/Bool'), ], ) - writer.return_value.__enter__.return_value.add_connection.side_effect = [ - wconnections[0], - wconnections[1], - ] + writerinst.connections.clear() cdr_to_ros1.side_effect = KeyError('exc') with pytest.raises(ConverterError, match='Converting rosbag: .*exc'): convert(Path('foo'), None) diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 6743deed..d3ce5d44 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -36,7 +36,9 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None: with rbag: gen = rbag.messages() rconnection, _, raw = next(gen) - assert rconnection == wconnection + assert rconnection.topic == wconnection.topic + assert rconnection.msgtype == wconnection.msgtype + assert rconnection.ext == wconnection.ext msg = deserialize_cdr(raw, rconnection.msgtype) assert getattr(msg, 'data', None) == Foo.data with pytest.raises(StopIteration): diff --git a/tests/test_writer.py b/tests/test_writer.py index 7dc7ee4d..5d2206a8 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -8,8 +8,8 @@ from typing import TYPE_CHECKING import pytest +from rosbags.interfaces import Connection, ConnectionExtRosbag2 from rosbags.rosbag2 import Writer, WriterError -from rosbags.rosbag2.connection import Connection if TYPE_CHECKING: from pathlib import Path @@ -81,7 +81,11 @@ def test_failure_cases(tmp_path: Path) -> None: bag = Writer(tmp_path / 'write') with pytest.raises(WriterError, match='was not opened'): - bag.write(Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', ''), 0, b'') + bag.write( + Connection(1, '/tf', 'tf_msgs/msg/tf2', '', '', 0, ConnectionExtRosbag2('cdr', '')), + 0, + b'', + ) bag = Writer(tmp_path / 'topic') bag.open() @@ -91,6 +95,6 @@ def test_failure_cases(tmp_path: Path) -> None: bag = Writer(tmp_path / 'notopic') bag.open() - connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '') + connection = Connection(1, '/tf', 'tf_msgs/msg/tf2', '', '', 0, ConnectionExtRosbag2('cdr', '')) with pytest.raises(WriterError, match='unknown connection'): bag.write(connection, 42, b'\x00')