From f33e65b14ae17a7f0988630b4f1bf237645ff760 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Sun, 1 Aug 2021 18:22:36 +0200 Subject: [PATCH] Change to connection oriented reader API --- README.rst | 4 +- docs/topics/rosbag1.rst | 15 ++--- docs/topics/rosbag2.rst | 31 +++++----- src/rosbags/convert/converter.py | 48 +++++++++++----- src/rosbags/rosbag1/reader.py | 15 ++--- src/rosbags/rosbag2/connection.py | 18 ++++++ src/rosbags/rosbag2/reader.py | 94 ++++++++++++++++++++----------- src/rosbags/rosbag2/writer.py | 72 ++++++++++++++--------- tests/test_convert.py | 51 +++++++++++++---- tests/test_reader.py | 51 +++++++++-------- tests/test_reader1.py | 18 +++--- tests/test_roundtrip.py | 9 +-- tests/test_writer.py | 36 ++++++------ 13 files changed, 290 insertions(+), 172 deletions(-) create mode 100644 src/rosbags/rosbag2/connection.py diff --git a/README.rst b/README.rst index 7a8e2a1b..f25b40c1 100644 --- a/README.rst +++ b/README.rst @@ -41,8 +41,8 @@ Read and deserialize rosbag2 messages: # create reader instance and open for reading with Reader('/home/ros/rosbag_2020_03_24') as reader: - for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): - msg = deserialize_cdr(rawdata, msgtype) + for connection, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): + msg = deserialize_cdr(rawdata, connection.msgtype) print(msg.header.frame_id) diff --git a/docs/topics/rosbag1.rst b/docs/topics/rosbag1.rst index 02ab7144..33d5fdd3 100644 --- a/docs/topics/rosbag1.rst +++ b/docs/topics/rosbag1.rst @@ -13,15 +13,16 @@ Instances of the :py:class:`Reader ` class are typically # create reader instance with Reader('/home/ros/rosbag_2020_03_24.bag') as reader: - # topic and msgtype information is available on .topics dictionary - for topic, info in reader.topics.items(): - print(topic, info) + # topic and msgtype information is available on .connections dictionary + for connection in reader.connections.values(): + print(connection.topic, connection.msgtype) # iterate over messages - for topic, msgtype, timestamp, rawdata in reader.messages(): - if topic == '/imu_raw/Imu': + for connection, timestamp, rawdata in reader.messages(): + if connection.topic == '/imu_raw/Imu': print(timestamp) - # messages() accepts topic filters - for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): + # messages() accepts connection filters + connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu'] + for connection, timestamp, rawdata in reader.messages(connections=connections): print(timestamp) diff --git a/docs/topics/rosbag2.rst b/docs/topics/rosbag2.rst index 68c7edb2..3e788bb3 100644 --- a/docs/topics/rosbag2.rst +++ b/docs/topics/rosbag2.rst @@ -27,16 +27,18 @@ Instances of the :py:class:`Writer ` class can create an from rosbags.rosbag2 import Writer from rosbags.serde import serialize_cdr + from rosbags.typesys.types import std_msgs__msg__String as String # create writer instance and open for writing with Writer('/home/ros/rosbag_2020_03_24') as writer: - # add new topic - topic = '/imu_raw/Imu' - msgtype = 'sensor_msgs/msg/Imu' - writer.add_topic(topic, msgtype, 'cdr') + # add new connection + topic = '/chatter' + msgtype = String.__msgtype__ + connection = writer.add_connection(topic, msgtype, 'cdr', '') # serialize and write message - writer.write(topic, timestamp, serialize_cdr(message, msgtype)) + message = String('hello world') + writer.write(connection, timestamp, serialize_cdr(message, msgtype)) Reading rosbag2 --------------- @@ -49,17 +51,18 @@ Instances of the :py:class:`Reader ` class are used to r # create reader instance and open for reading with Reader('/home/ros/rosbag_2020_03_24') as reader: - # topic and msgtype information is available on .topics dict - for topic, msgtype in reader.topics.items(): - print(topic, msgtype) + # topic and msgtype information is available on .connections dict + for connection in reader.connections.values(): + print(connection.topic, connection.msgtype) # iterate over messages - for topic, msgtype, timestamp, rawdata in reader.messages(): - if topic == '/imu_raw/Imu': - msg = deserialize_cdr(rawdata, msgtype) + for connection, timestamp, rawdata in reader.messages(): + if connection.topic == '/imu_raw/Imu': + msg = deserialize_cdr(rawdata, connection.msgtype) print(msg.header.frame_id) - # messages() accepts topic filters - for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): - msg = deserialize_cdr(rawdata, msgtype) + # messages() accepts connection filters + connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu'] + for connection, timestamp, rawdata in reader.messages(connections=connections): + msg = deserialize_cdr(rawdata, connection.msgtype) print(msg.header.frame_id) diff --git a/src/rosbags/convert/converter.py b/src/rosbags/convert/converter.py index 8c46464f..8f76567a 100644 --- a/src/rosbags/convert/converter.py +++ b/src/rosbags/convert/converter.py @@ -4,10 +4,12 @@ from __future__ import annotations +from dataclasses import asdict from typing import TYPE_CHECKING from rosbags.rosbag1 import Reader, ReaderError from rosbags.rosbag2 import Writer, WriterError +from rosbags.rosbag2.connection import Connection as WConnection from rosbags.serde import ros1_to_cdr from rosbags.typesys import get_types_from_msg, register_types @@ -15,6 +17,8 @@ if TYPE_CHECKING: from pathlib import Path from typing import Any, Dict, Optional + from rosbags.rosbag1.reader import Connection as RConnection + LATCH = """ - history: 3 depth: 0 @@ -38,6 +42,26 @@ class ConverterError(Exception): """Converter Error.""" +def convert_connection(rconn: RConnection) -> WConnection: + """Convert rosbag1 connection to rosbag2 connection. + + Args: + rconn: Rosbag1 connection. + + Returns: + Rosbag2 connection. + + """ + return WConnection( + -1, + 0, + rconn.topic, + rconn.msgtype, + 'cdr', + LATCH if rconn.latching else '', + ) + + def convert(src: Path, dst: Optional[Path]) -> None: """Convert Rosbag1 to Rosbag2. @@ -57,21 +81,19 @@ def convert(src: Path, dst: Optional[Path]) -> None: try: with Reader(src) as reader, Writer(dst) as writer: typs: Dict[str, Any] = {} - for name, topic in reader.topics.items(): - connection = next( # pragma: no branch - x for x in reader.connections.values() if x.topic == name - ) - writer.add_topic( - name, - topic.msgtype, - offered_qos_profiles=LATCH if connection.latching else '', - ) - typs.update(get_types_from_msg(topic.msgdef, topic.msgtype)) + connmap: Dict[int, WConnection] = {} + + for rconn in reader.connections.values(): + candidate = convert_connection(rconn) + existing = next((x for x in writer.connections.values() if x == candidate), None) + wconn = existing if existing else writer.add_connection(**asdict(candidate)) + connmap[rconn.cid] = wconn + typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype)) register_types(typs) - for topic, msgtype, timestamp, data in reader.messages(): - data = ros1_to_cdr(data, msgtype) - writer.write(topic, timestamp, data) + for rconn, timestamp, data in reader.messages(): + data = ros1_to_cdr(data, rconn.msgtype) + writer.write(connmap[rconn.cid], timestamp, data) except ReaderError as err: raise ConverterError(f'Reading source bag: {err}') from err except WriterError as err: diff --git a/src/rosbags/rosbag1/reader.py b/src/rosbags/rosbag1/reader.py index fb9fb0c4..b928cff3 100644 --- a/src/rosbags/rosbag1/reader.py +++ b/src/rosbags/rosbag1/reader.py @@ -249,7 +249,7 @@ class Header(dict): raise ReaderError(f'Could not read time field {name!r}.') from err @classmethod - def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> 'Header': + def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header: """Read header from file handle. Args: @@ -588,7 +588,7 @@ class Reader: topics: Optional[Iterable[str]] = None, start: Optional[int] = None, stop: Optional[int] = None, - ) -> Generator[Tuple[str, str, int, bytes], None, None]: + ) -> Generator[Tuple[Connection, int, bytes], None, None]: """Read messages from bag. Args: @@ -598,7 +598,7 @@ class Reader: stop: Yield only messages before this timestamp (ns). Yields: - Tuples of topic name, type, timestamp (ns), and rawdata. + Tuples of connection, timestamp (ns), and rawdata. Raises: ReaderError: Bag not open or data corrupt. @@ -635,13 +635,10 @@ class Reader: if have != RecordType.MSGDATA: raise ReaderError('Expected to find message data.') - connection = self.connections[header.get_uint32('conn')] - time = header.get_time('time') - data = read_bytes(chunk, read_uint32(chunk)) - - assert entry.time == time - yield connection.topic, connection.msgtype, time, data + connection = self.connections[header.get_uint32('conn')] + assert entry.time == header.get_time('time') + yield connection, entry.time, data def __enter__(self) -> Reader: """Open rosbag1 when entering contextmanager.""" diff --git a/src/rosbags/rosbag2/connection.py b/src/rosbags/rosbag2/connection.py new file mode 100644 index 00000000..76d080fc --- /dev/null +++ b/src/rosbags/rosbag2/connection.py @@ -0,0 +1,18 @@ +# Copyright 2020-2021 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Rosbag2 connection.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class Connection: + """Connection metadata.""" + id: int = field(compare=False) # pylint: disable=invalid-name + count: int = field(compare=False) + topic: str + msgtype: str + serialization_format: str + offered_qos_profiles: str diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index a38e0f47..03973657 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -13,9 +13,11 @@ from typing import TYPE_CHECKING import zstandard from ruamel.yaml import YAML, YAMLError +from .connection import Connection + if TYPE_CHECKING: from types import TracebackType - from typing import Any, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union + from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union class ReaderError(Exception): @@ -96,11 +98,21 @@ class Reader: if missing: raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}') - topics = [x['topic_metadata'] for x in self.metadata['topics_with_message_count']] - noncdr = {y for x in topics if (y := x['serialization_format']) != 'cdr'} + self.connections = { + idx + 1: Connection( + id=idx + 1, + count=x['message_count'], + topic=x['topic_metadata']['name'], + msgtype=x['topic_metadata']['type'], + serialization_format=x['topic_metadata']['serialization_format'], + offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''), + ) for idx, x in enumerate(self.metadata['topics_with_message_count']) + } + noncdr = { + y for x in self.connections.values() if (y := x.serialization_format) != 'cdr' + } if noncdr: raise ReaderError(f'Serialization format {noncdr!r} is not supported.') - self.topics = {x['name']: x['type'] for x in topics} if self.compression_mode and (cfmt := self.compression_format) != 'zstd': raise ReaderError(f'Compression format {cfmt!r} is not supported.') @@ -149,22 +161,31 @@ class Reader: mode = self.metadata.get('compression_mode', '').lower() return mode if mode != 'none' else None + @property + def topics(self) -> Dict[str, Connection]: + """Topic information. + + For the moment this a dictionary mapping topic names to connections. + + """ + return {x.topic: x for x in self.connections.values()} + def messages( # pylint: disable=too-many-locals self, - topics: Iterable[str] = (), + connections: Iterable[Connection] = (), start: Optional[int] = None, stop: Optional[int] = None, - ) -> Generator[Tuple[str, str, int, bytes], None, None]: + ) -> Generator[Tuple[Connection, int, bytes], None, None]: """Read messages from bag. Args: - topics: Iterable with topic names to filter for. An empty iterable - yields all messages. + connections: Iterable with connections to filter for. An empty + iterable disables filtering on connections. start: Yield only messages at or after this timestamp (ns). stop: Yield only messages before this timestamp (ns). Yields: - Tuples of topic name, type, timestamp (ns), and rawdata. + Tuples of connection, timestamp (ns), and rawdata. Raises: ReaderError: Bag not open. @@ -173,7 +194,32 @@ class Reader: if not self.bio: raise ReaderError('Rosbag is not open.') - topics = tuple(topics) + query = [ + 'SELECT topics.id,messages.timestamp,messages.data', + 'FROM messages JOIN topics ON messages.topic_id=topics.id', + ] + args: List[Any] = [] + clause = 'WHERE' + + if connections: + topics = {x.topic for x in connections} + query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})') + args += topics + clause = 'AND' + + if start is not None: + query.append(f'{clause} messages.timestamp >= ?') + args.append(start) + clause = 'AND' + + if stop is not None: + query.append(f'{clause} messages.timestamp < ?') + args.append(stop) + clause = 'AND' + + query.append('ORDER BY timestamp') + querystr = ' '.join(query) + for filepath in self.paths: with decompress(filepath, self.compression_mode == 'file') as path: conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True) @@ -186,34 +232,16 @@ class Reader: if cur.fetchone()[0] != 2: raise ReaderError(f'Cannot open database {path} or database missing tables.') - query = [ - 'SELECT topics.name,topics.type,messages.timestamp,messages.data', - 'FROM messages JOIN topics ON messages.topic_id=topics.id', - ] - args: List[Any] = [] - - if topics: - query.append(f'WHERE topics.name IN ({",".join("?" for _ in topics)})') - args += topics - - if start is not None: - query.append(f'{"AND" if args else "WHERE"} messages.timestamp >= ?') - args.append(start) - - if stop is not None: - query.append(f'{"AND" if args else "WHERE"} messages.timestamp < ?') - args.append(stop) - - query.append('ORDER BY timestamp') - cur.execute(' '.join(query), args) + cur.execute(querystr, args) if self.compression_mode == 'message': decomp = zstandard.ZstdDecompressor().decompress for row in cur: - topic, msgtype, timestamp, data = row - yield topic, msgtype, timestamp, decomp(data) + cid, timestamp, data = row + yield self.connections[cid], timestamp, decomp(data) else: - yield from cur + for cid, timestamp, data in cur: + yield self.connections[cid], timestamp, data def __enter__(self) -> Reader: """Open rosbag2 when entering contextmanager.""" diff --git a/src/rosbags/rosbag2/writer.py b/src/rosbags/rosbag2/writer.py index 23d4c701..18becac6 100644 --- a/src/rosbags/rosbag2/writer.py +++ b/src/rosbags/rosbag2/writer.py @@ -1,6 +1,6 @@ # Copyright 2020-2021 Ternaris. # SPDX-License-Identifier: Apache-2.0 -"""Rosbag2 reader.""" +"""Rosbag2 writer.""" from __future__ import annotations @@ -12,6 +12,8 @@ from typing import TYPE_CHECKING import zstandard from ruamel.yaml import YAML +from .connection import Connection + if TYPE_CHECKING: from types import TracebackType from typing import Any, Dict, Literal, Optional, Type, Union @@ -77,10 +79,9 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compression_mode = '' self.compression_format = '' self.compressor: Optional[zstandard.ZstdCompressor] = None - self.topics: Dict[str, Any] = {} + self.connections: Dict[int, Connection] = {} self.conn = None self.cursor: Optional[sqlite3.Cursor] = None - self.topics = {} def set_compression(self, mode: CompressionMode, fmt: CompressionFormat): """Enable compression on bag. @@ -118,22 +119,27 @@ class Writer: # pylint: disable=too-many-instance-attributes self.conn.executescript(self.SQLITE_SCHEMA) self.cursor = self.conn.cursor() - def add_topic( + def add_connection( self, - name: str, - typ: str, + topic: str, + msgtype: str, serialization_format: str = 'cdr', offered_qos_profiles: str = '', - ): - """Add a topic. + **_kw: Any, + ) -> Connection: + """Add a connection. This function can only be called after opening a bag. Args: - name: Topic name. - typ: Message type. + topic: Topic name. + msgtype: Message type. serialization_format: Serialization format. offered_qos_profiles: QOS Profile. + _kw: Ignored to allow consuming dicts from connection objects. + + Returns: + Connection object. Raises: WriterError: Bag not open or topic previously registered. @@ -141,17 +147,28 @@ class Writer: # pylint: disable=too-many-instance-attributes """ if not self.cursor: raise WriterError('Bag was not opened.') - if name in self.topics: - raise WriterError(f'Topics can only be added once: {name!r}.') - meta = (len(self.topics) + 1, name, typ, serialization_format, offered_qos_profiles) - self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) - self.topics[name] = [*meta, 0] - def write(self, topic: str, timestamp: int, data: bytes): + connection = Connection( + id=len(self.connections.values()) + 1, + count=0, + topic=topic, + msgtype=msgtype, + serialization_format=serialization_format, + offered_qos_profiles=offered_qos_profiles, + ) + if connection in self.connections.values(): + raise WriterError(f'Connection can only be added once: {connection!r}.') + + self.connections[connection.id] = connection + meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles) + self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) + return connection + + def write(self, connection: Connection, timestamp: int, data: bytes): """Write message to rosbag2. Args: - topic: Topic message belongs to. + connection: Connection to write message to. timestamp: Message timestamp (ns). data: Serialized message data. @@ -161,19 +178,18 @@ class Writer: # pylint: disable=too-many-instance-attributes """ if not self.cursor: raise WriterError('Bag was not opened.') - if topic not in self.topics: - raise WriterError(f'Tried to write to unknown topic {topic!r}.') + if connection not in self.connections.values(): + raise WriterError(f'Tried to write to unknown connection {connection!r}.') if self.compression_mode == 'message': assert self.compressor data = self.compressor.compress(data) - tmeta = self.topics[topic] self.cursor.execute( 'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)', - (tmeta[0], timestamp, data), + (connection.id, timestamp, data), ) - tmeta[-1] += 1 + connection.count += 1 def close(self): """Close rosbag2 after writing. @@ -214,13 +230,13 @@ class Writer: # pylint: disable=too-many-instance-attributes 'topics_with_message_count': [ { 'topic_metadata': { - 'name': x[1], - 'type': x[2], - 'serialization_format': x[3], - 'offered_qos_profiles': x[4], + 'name': x.topic, + 'type': x.msgtype, + 'serialization_format': x.serialization_format, + 'offered_qos_profiles': x.offered_qos_profiles, }, - 'message_count': x[5], - } for x in self.topics.values() + 'message_count': x.count, + } for x in self.connections.values() ], 'compression_format': self.compression_format, 'compression_mode': self.compression_mode, diff --git a/tests/test_convert.py b/tests/test_convert.py index db3cefc8..057a6092 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -76,17 +76,29 @@ def test_convert(tmp_path: Path): patch('rosbags.convert.converter.register_types') as register_types, \ patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr: + connections = [ + Mock(topic='/topic', msgtype='typ', latching=False), + Mock(topic='/topic', msgtype='typ', latching=True), + ] + + wconnections = [ + Mock(topic='/topic', msgtype='typ'), + Mock(topic='/topic', msgtype='typ'), + ] + reader.return_value.__enter__.return_value.connections = { - 0: Mock(topic='/topic', latching=False), - 1: Mock(topic='/latched', latching=True), - } - reader.return_value.__enter__.return_value.topics = { - '/topic': Mock(msgtype='typ', msgdef='def'), - '/latched': Mock(msgtype='typ', msgdef='def'), + 1: connections[0], + 2: connections[1], } + reader.return_value.__enter__.return_value.messages.return_value = [ - ('/topic', 'typ', 42, b'\x42'), - ('/latched', 'typ', 43, b'\x43'), + (connections[0], 42, b'\x42'), + (connections[1], 43, b'\x43'), + ] + + writer.return_value.__enter__.return_value.add_connection.side_effect = [ + wconnections[0], + wconnections[1], ] ros1_to_cdr.return_value = b'666' @@ -97,14 +109,29 @@ def test_convert(tmp_path: Path): reader.return_value.__enter__.return_value.messages.assert_called_with() writer.assert_called_with(Path('foo')) - writer.return_value.__enter__.return_value.add_topic.assert_has_calls( + writer.return_value.__enter__.return_value.add_connection.assert_has_calls( [ - call('/topic', 'typ', offered_qos_profiles=''), - call('/latched', 'typ', offered_qos_profiles=LATCH), + call( + id=-1, + count=0, + topic='/topic', + msgtype='typ', + serialization_format='cdr', + offered_qos_profiles='', + ), + call( + id=-1, + count=0, + topic='/topic', + msgtype='typ', + serialization_format='cdr', + offered_qos_profiles=LATCH, + ), ], ) writer.return_value.__enter__.return_value.write.assert_has_calls( - [call('/topic', 42, b'666'), call('/latched', 43, b'666')], + [call(wconnections[0], 42, b'666'), + call(wconnections[1], 43, b'666')], ) register_types.assert_called_with({'typ': 'def'}) diff --git a/tests/test_reader.py b/tests/test_reader.py index 66d1d02d..2862954b 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -126,25 +126,26 @@ def test_reader(bag: Path): assert reader.message_count == 4 if reader.compression_mode: assert reader.compression_format == 'zstd' - + assert [*reader.connections.keys()] == [1, 2, 3] + assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint'] gen = reader.messages() - topic, msgtype, timestamp, rawdata = next(gen) - assert topic == '/poly' - assert msgtype == 'geometry_msgs/msg/Polygon' + connection, timestamp, rawdata = next(gen) + assert connection.topic == '/poly' + assert connection.msgtype == 'geometry_msgs/msg/Polygon' assert timestamp == 666 assert rawdata == MSG_POLY[0] for idx in range(2): - topic, msgtype, timestamp, rawdata = next(gen) - assert topic == '/magn' - assert msgtype == 'sensor_msgs/msg/MagneticField' + connection, timestamp, rawdata = next(gen) + assert connection.topic == '/magn' + assert connection.msgtype == 'sensor_msgs/msg/MagneticField' assert timestamp == 708 assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0] - topic, msgtype, timestamp, rawdata = next(gen) - assert topic == '/joint' - assert msgtype == 'trajectory_msgs/msg/JointTrajectory' + connection, timestamp, rawdata = next(gen) + assert connection.topic == '/joint' + assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory' with pytest.raises(StopIteration): next(gen) @@ -153,32 +154,32 @@ def test_reader(bag: Path): def test_message_filters(bag: Path): """Test reader filters messages.""" with Reader(bag) as reader: - - gen = reader.messages(['/magn']) - topic, _, _, _ = next(gen) - assert topic == '/magn' - topic, _, _, _ = next(gen) - assert topic == '/magn' + magn_connections = [x for x in reader.connections.values() if x.topic == '/magn'] + gen = reader.messages(connections=magn_connections) + connection, _, _ = next(gen) + assert connection.topic == '/magn' + connection, _, _ = next(gen) + assert connection.topic == '/magn' with pytest.raises(StopIteration): next(gen) gen = reader.messages(start=667) - topic, _, _, _ = next(gen) - assert topic == '/magn' - topic, _, _, _ = next(gen) - assert topic == '/magn' - topic, _, _, _ = next(gen) - assert topic == '/joint' + connection, _, _ = next(gen) + assert connection.topic == '/magn' + connection, _, _ = next(gen) + assert connection.topic == '/magn' + connection, _, _ = next(gen) + assert connection.topic == '/joint' with pytest.raises(StopIteration): next(gen) gen = reader.messages(stop=667) - topic, _, _, _ = next(gen) - assert topic == '/poly' + connection, _, _ = next(gen) + assert connection.topic == '/poly' with pytest.raises(StopIteration): next(gen) - gen = reader.messages(['/magn'], stop=667) + gen = reader.messages(connections=magn_connections, stop=667) with pytest.raises(StopIteration): next(gen) diff --git a/tests/test_reader1.py b/tests/test_reader1.py index 294e052b..f0a08124 100644 --- a/tests/test_reader1.py +++ b/tests/test_reader1.py @@ -141,7 +141,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too- header['index_pos'] = pack(' (path / 'compress_message.db3').stat().st_size @@ -76,7 +77,7 @@ def test_failure_cases(tmp_path: Path): bag = Writer(tmp_path / 'topic') with pytest.raises(WriterError, match='was not opened'): - bag.add_topic('/tf', 'tf_msgs/msg/tf2') + bag.add_connection('/tf', 'tf_msgs/msg/tf2') bag = Writer(tmp_path / 'write') with pytest.raises(WriterError, match='was not opened'): @@ -84,11 +85,12 @@ def test_failure_cases(tmp_path: Path): bag = Writer(tmp_path / 'topic') bag.open() - bag.add_topic('/tf', 'tf_msgs/msg/tf2') + bag.add_connection('/tf', 'tf_msgs/msg/tf2') with pytest.raises(WriterError, match='only be added once'): - bag.add_topic('/tf', 'tf_msgs/msg/tf2') + bag.add_connection('/tf', 'tf_msgs/msg/tf2') bag = Writer(tmp_path / 'notopic') bag.open() - with pytest.raises(WriterError, match='unknown topic'): - bag.write('/test', 42, b'\x00') + connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '') + with pytest.raises(WriterError, match='unknown connection'): + bag.write(connection, 42, b'\x00')