From 34ffe9669257bed40bfc60c3e338719efeb6d294 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Thu, 21 Apr 2022 15:15:10 +0200 Subject: [PATCH] Convert connections attribute to list --- docs/examples/edit_rosbags_edit_timestamps.py | 2 +- docs/examples/edit_rosbags_remove_topic.py | 2 +- docs/examples/register_types_rosbag1.py | 2 +- docs/topics/rosbag1.rst | 6 +-- docs/topics/rosbag2.rst | 6 +-- src/rosbags/convert/converter.py | 8 ++-- src/rosbags/highlevel/anyreader.py | 4 +- src/rosbags/rosbag1/reader.py | 31 +++++++------- src/rosbags/rosbag1/writer.py | 10 ++--- src/rosbags/rosbag2/reader.py | 15 +++---- src/rosbags/rosbag2/writer.py | 12 +++--- tests/test_convert.py | 42 +++++++++---------- tests/test_reader.py | 4 +- tests/test_reader1.py | 2 +- 14 files changed, 72 insertions(+), 74 deletions(-) diff --git a/docs/examples/edit_rosbags_edit_timestamps.py b/docs/examples/edit_rosbags_edit_timestamps.py index 6de70aae..04af43db 100644 --- a/docs/examples/edit_rosbags_edit_timestamps.py +++ b/docs/examples/edit_rosbags_edit_timestamps.py @@ -23,7 +23,7 @@ def offset_timestamps(src: Path, dst: Path, offset: int) -> None: """ with Reader(src) as reader, Writer(dst) as writer: conn_map = {} - for conn in reader.connections.values(): + for conn in reader.connections: ext = cast(ConnectionExtRosbag2, conn.ext) conn_map[conn.id] = writer.add_connection( conn.topic, diff --git a/docs/examples/edit_rosbags_remove_topic.py b/docs/examples/edit_rosbags_remove_topic.py index aa00dabd..4824343c 100644 --- a/docs/examples/edit_rosbags_remove_topic.py +++ b/docs/examples/edit_rosbags_remove_topic.py @@ -22,7 +22,7 @@ def remove_topic(src: Path, dst: Path, topic: str) -> None: """ with Reader(src) as reader, Writer(dst) as writer: conn_map = {} - for conn in reader.connections.values(): + for conn in reader.connections: if conn.topic == topic: continue ext = cast(ConnectionExtRosbag2, conn.ext) diff --git a/docs/examples/register_types_rosbag1.py b/docs/examples/register_types_rosbag1.py index 0be64334..c183a014 100644 --- a/docs/examples/register_types_rosbag1.py +++ b/docs/examples/register_types_rosbag1.py @@ -20,7 +20,7 @@ def process_bag(src: Path) -> None: """ with Reader(src) as reader: typs = {} - for conn in reader.connections.values(): + for conn in reader.connections: typs.update(get_types_from_msg(conn.msgdef, conn.msgtype)) register_types(typs) diff --git a/docs/topics/rosbag1.rst b/docs/topics/rosbag1.rst index 0e9b7c94..73b0aa62 100644 --- a/docs/topics/rosbag1.rst +++ b/docs/topics/rosbag1.rst @@ -37,8 +37,8 @@ Instances of the :py:class:`Reader ` class are typically # create reader instance with Reader('/home/ros/rosbag_2020_03_24.bag') as reader: - # topic and msgtype information is available on .connections dictionary - for connection in reader.connections.values(): + # topic and msgtype information is available on .connections list + for connection in reader.connections: print(connection.topic, connection.msgtype) # iterate over messages @@ -48,7 +48,7 @@ Instances of the :py:class:`Reader ` class are typically print(msg.header.frame_id) # messages() accepts connection filters - connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu'] + connections = [x for x in reader.connections if x.topic == '/imu_raw/Imu'] for connection, timestamp, rawdata in reader.messages(connections=connections): msg = deserialize_cdr(ros1_to_cdr(rawdata, connection.msgtype), connection.msgtype) print(msg.header.frame_id) diff --git a/docs/topics/rosbag2.rst b/docs/topics/rosbag2.rst index 1f3fbe57..dbee4d06 100644 --- a/docs/topics/rosbag2.rst +++ b/docs/topics/rosbag2.rst @@ -52,8 +52,8 @@ Instances of the :py:class:`Reader ` class are used to r # create reader instance and open for reading with Reader('/home/ros/rosbag_2020_03_24') as reader: - # topic and msgtype information is available on .connections dict - for connection in reader.connections.values(): + # topic and msgtype information is available on .connections list + for connection in reader.connections: print(connection.topic, connection.msgtype) # iterate over messages @@ -63,7 +63,7 @@ Instances of the :py:class:`Reader ` class are used to r print(msg.header.frame_id) # messages() accepts connection filters - connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu'] + connections = [x for x in reader.connections if x.topic == '/imu_raw/Imu'] for connection, timestamp, rawdata in reader.messages(connections=connections): msg = deserialize_cdr(rawdata, connection.msgtype) print(msg.header.frame_id) diff --git a/src/rosbags/convert/converter.py b/src/rosbags/convert/converter.py index 093d2f03..339ad0ce 100644 --- a/src/rosbags/convert/converter.py +++ b/src/rosbags/convert/converter.py @@ -111,10 +111,10 @@ def convert_1to2(src: Path, dst: Path) -> None: typs: dict[str, Any] = {} connmap: dict[int, Connection] = {} - for rconn in reader.connections.values(): + for rconn in reader.connections: candidate = upgrade_connection(rconn) assert isinstance(candidate.ext, ConnectionExtRosbag2) - for conn in writer.connections.values(): + for conn in writer.connections: assert isinstance(conn.ext, ConnectionExtRosbag2) if ( conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and @@ -147,10 +147,10 @@ def convert_2to1(src: Path, dst: Path) -> None: """ with Reader2(src) as reader, Writer1(dst) as writer: connmap: dict[int, Connection] = {} - for rconn in reader.connections.values(): + for rconn in reader.connections: candidate = downgrade_connection(rconn) assert isinstance(candidate.ext, ConnectionExtRosbag1) - for conn in writer.connections.values(): + for conn in writer.connections: assert isinstance(conn.ext, ConnectionExtRosbag1) if ( conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and diff --git a/src/rosbags/highlevel/anyreader.py b/src/rosbags/highlevel/anyreader.py index 1a635a17..ba265b56 100644 --- a/src/rosbags/highlevel/anyreader.py +++ b/src/rosbags/highlevel/anyreader.py @@ -143,11 +143,11 @@ class AnyReader: typs: dict[str, Any] = {} for reader in self.readers: assert isinstance(reader, Reader1) - for connection in reader.connections.values(): + for connection in reader.connections: typs.update(get_types_from_msg(connection.msgdef, connection.msgtype)) register_types(typs, self.typestore) - self.connections = [y for x in self.readers for y in x.connections.values()] + self.connections = [y for x in self.readers for y in x.connections] self.isopen = True def close(self) -> None: diff --git a/src/rosbags/rosbag1/reader.py b/src/rosbags/rosbag1/reader.py index 1f3058aa..483e731f 100644 --- a/src/rosbags/rosbag1/reader.py +++ b/src/rosbags/rosbag1/reader.py @@ -342,7 +342,7 @@ class Reader: raise ReaderError(f'File {str(self.path)!r} does not exist.') self.bio: Optional[BinaryIO] = None - self.connections: dict[int, Connection] = {} + self.connections: list[Connection] = [] self.indexes: dict[int, list[IndexData]] self.chunk_infos: list[ChunkInfo] = [] self.chunks: dict[int, Chunk] = {} @@ -384,7 +384,7 @@ class Reader: self.bio.seek(index_pos) try: - self.connections = dict(self.read_connection() for _ in range(conn_count)) + self.connections = [self.read_connection() for _ in range(conn_count)] self.chunk_infos = [self.read_chunk_info() for _ in range(chunk_count)] except ReaderError as err: raise ReaderError(f'Bag index looks damaged: {err.args}') from None @@ -402,14 +402,15 @@ class Reader: self.indexes = { cid: list(heapq.merge(*x, key=lambda x: x.time)) for cid, x in indexes.items() } - assert all(self.indexes[x] for x in self.connections) + assert all(self.indexes[x.id] for x in self.connections) - for cid, connection in self.connections.items(): - self.connections[cid] = Connection( - *connection[0:5], - len(self.indexes[cid]), - *connection[6:], - ) + self.connections = [ + Connection( + *x[0:5], + len(self.indexes[x.id]), + *x[6:], + ) for x in self.connections + ] except ReaderError: self.close() raise @@ -445,7 +446,7 @@ class Reader: """Topic information.""" topics = {} for topic, group in groupby( - sorted(self.connections.values(), key=lambda x: x.topic), + sorted(self.connections, key=lambda x: x.topic), key=lambda x: x.topic, ): connections = list(group) @@ -462,7 +463,7 @@ class Reader: ) return topics - def read_connection(self) -> tuple[int, Connection]: + def read_connection(self) -> Connection: """Read connection record from current position.""" assert self.bio header = Header.read(self.bio, RecordType.CONNECTION) @@ -477,7 +478,7 @@ class Reader: callerid = header.get_string('callerid') if 'callerid' in header else None latching = int(header.get_string('latching')) if 'latching' in header else None - return conn, Connection( + return Connection( conn, topic, normalize_msgtype(typ), @@ -593,7 +594,9 @@ class Reader: raise ReaderError('Rosbag is not open.') if not connections: - connections = self.connections.values() + connections = self.connections + + connmap = {x.id: x for x in self.connections} indexes = [self.indexes[x.id] for x in connections] for entry in heapq.merge(*indexes): @@ -624,7 +627,7 @@ class Reader: raise ReaderError('Expected to find message data.') data = read_bytes(chunk, read_uint32(chunk)) - connection = self.connections[header.get_uint32('conn')] + connection = connmap[header.get_uint32('conn')] assert entry.time == header.get_time('time') yield connection, entry.time, data diff --git a/src/rosbags/rosbag1/writer.py b/src/rosbags/rosbag1/writer.py index 8150b36d..d1ba14ed 100644 --- a/src/rosbags/rosbag1/writer.py +++ b/src/rosbags/rosbag1/writer.py @@ -159,7 +159,7 @@ class Writer: self.bio: Optional[BinaryIO] = None self.compressor: Callable[[bytes], bytes] = lambda x: x self.compression_format = 'none' - self.connections: dict[int, Connection] = {} + self.connections: list[Connection] = [] self.chunks: list[WriteChunk] = [ WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)), ] @@ -258,7 +258,7 @@ class Writer: self, ) - if any(x[1:] == connection[1:] for x in self.connections.values()): + if any(x[1:] == connection[1:] for x in self.connections): raise WriterError( f'Connections can only be added once with same arguments: {connection!r}.', ) @@ -266,7 +266,7 @@ class Writer: bio = self.chunks[-1].data self.write_connection(connection, bio) - self.connections[connection.id] = connection + self.connections.append(connection) return connection def write(self, connection: Connection, timestamp: int, data: bytes) -> None: @@ -284,7 +284,7 @@ class Writer: if not self.bio: raise WriterError('Bag was not opened.') - if connection not in self.connections.values(): + if connection not in self.connections: raise WriterError(f'There is no connection {connection!r}.') from None chunk = self.chunks[-1] @@ -367,7 +367,7 @@ class Writer: index_pos = self.bio.tell() - for connection in self.connections.values(): + for connection in self.connections: self.write_connection(connection, self.bio) for chunk in self.chunks: diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index 6ac3e90a..0f261a99 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -136,8 +136,8 @@ class Reader: if missing := [x for x in self.paths if not x.exists()]: raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}') - self.connections = { - idx + 1: Connection( + self.connections = [ + Connection( id=idx + 1, topic=x['topic_metadata']['name'], msgtype=x['topic_metadata']['type'], @@ -150,9 +150,9 @@ class Reader: ), owner=self, ) for idx, x in enumerate(self.metadata['topics_with_message_count']) - } + ] noncdr = { - fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2) + fmt for x in self.connections if isinstance(x.ext, ConnectionExtRosbag2) if (fmt := x.ext.serialization_format) != 'cdr' } if noncdr: @@ -209,10 +209,7 @@ class Reader: @property def topics(self) -> dict[str, TopicInfo]: """Topic information.""" - return { - x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) - for x in self.connections.values() - } + return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections} def messages( # pylint: disable=too-many-locals self, @@ -278,7 +275,7 @@ class Reader: cur.execute('SELECT name,id FROM topics') connmap: dict[int, Connection] = { - row[1]: next((x for x in self.connections.values() if x.topic == row[0]), + row[1]: next((x for x in self.connections if x.topic == row[0]), None) # type: ignore for row in cur } diff --git a/src/rosbags/rosbag2/writer.py b/src/rosbags/rosbag2/writer.py index 5c5becc0..6efaf197 100644 --- a/src/rosbags/rosbag2/writer.py +++ b/src/rosbags/rosbag2/writer.py @@ -81,7 +81,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compression_mode = '' self.compression_format = '' self.compressor: Optional[zstandard.ZstdCompressor] = None - self.connections: dict[int, Connection] = {} + self.connections: list[Connection] = [] self.counts: dict[int, int] = {} self.conn: Optional[sqlite3.Connection] = None self.cursor: Optional[sqlite3.Cursor] = None @@ -152,7 +152,7 @@ class Writer: # pylint: disable=too-many-instance-attributes raise WriterError('Bag was not opened.') connection = Connection( - id=len(self.connections.values()) + 1, + id=len(self.connections) + 1, topic=topic, msgtype=msgtype, msgdef='', @@ -164,14 +164,14 @@ class Writer: # pylint: disable=too-many-instance-attributes ), owner=self, ) - for conn in self.connections.values(): + for conn in self.connections: if ( conn.topic == connection.topic and conn.msgtype == connection.msgtype and conn.ext == connection.ext ): raise WriterError(f'Connection can only be added once: {connection!r}.') - self.connections[connection.id] = connection + self.connections.append(connection) self.counts[connection.id] = 0 meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles) self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) @@ -191,7 +191,7 @@ class Writer: # pylint: disable=too-many-instance-attributes """ if not self.cursor: raise WriterError('Bag was not opened.') - if connection not in self.connections.values(): + if connection not in self.connections: raise WriterError(f'Tried to write to unknown connection {connection!r}.') if self.compression_mode == 'message': @@ -252,7 +252,7 @@ class Writer: # pylint: disable=too-many-instance-attributes 'offered_qos_profiles': x.ext.offered_qos_profiles, }, 'message_count': self.counts[x.id], - } for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2) + } for x in self.connections if isinstance(x.ext, ConnectionExtRosbag2) ], 'compression_format': self.compression_format, 'compression_mode': self.compression_mode, diff --git a/tests/test_convert.py b/tests/test_convert.py index d4c69293..8efa2192 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -146,12 +146,12 @@ def test_convert_1to2(tmp_path: Path) -> None: Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None), ] - readerinst.connections = { - 1: connections[0], - 2: connections[1], - 3: connections[2], - 4: connections[3], - } + readerinst.connections = [ + connections[0], + connections[1], + connections[2], + connections[3], + ] readerinst.messages.return_value = [ (connections[0], 42, b'\x42'), @@ -160,14 +160,13 @@ def test_convert_1to2(tmp_path: Path) -> None: (connections[3], 45, b'\x45'), ] - writerinst.connections = {} + writerinst.connections = [] def add_connection(*_: Any) -> Connection: # noqa: ANN401 """Mock for Writer.add_connection.""" - writerinst.connections = { - conn.id: conn - for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) - } + writerinst.connections = [ + conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) + ] return wconnections[len(writerinst.connections) - 1] writerinst.add_connection.side_effect = add_connection @@ -311,12 +310,12 @@ def test_convert_2to1(tmp_path: Path) -> None: ), ] - readerinst.connections = { - 1: connections[0], - 2: connections[1], - 3: connections[2], - 4: connections[3], - } + readerinst.connections = [ + connections[0], + connections[1], + connections[2], + connections[3], + ] readerinst.messages.return_value = [ (connections[0], 42, b'\x42'), @@ -325,14 +324,13 @@ def test_convert_2to1(tmp_path: Path) -> None: (connections[3], 45, b'\x45'), ] - writerinst.connections = {} + writerinst.connections = [] def add_connection(*_: Any) -> Connection: # noqa: ANN401 """Mock for Writer.add_connection.""" - writerinst.connections = { - conn.id: conn - for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) - } + writerinst.connections = [ + conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) + ] return wconnections[len(writerinst.connections) - 1] writerinst.add_connection.side_effect = add_connection diff --git a/tests/test_reader.py b/tests/test_reader.py index 4f15f2ff..84ecddfd 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -126,7 +126,7 @@ def test_reader(bag: Path) -> None: assert reader.message_count == 4 if reader.compression_mode: assert reader.compression_format == 'zstd' - assert [*reader.connections.keys()] == [1, 2, 3] + assert [x.id for x in reader.connections] == [1, 2, 3] assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint'] gen = reader.messages() @@ -154,7 +154,7 @@ def test_reader(bag: Path) -> None: def test_message_filters(bag: Path) -> None: """Test reader filters messages.""" with Reader(bag) as reader: - magn_connections = [x for x in reader.connections.values() if x.topic == '/magn'] + magn_connections = [x for x in reader.connections if x.topic == '/magn'] gen = reader.messages(connections=magn_connections) connection, _, _ = next(gen) assert connection.topic == '/magn' diff --git a/tests/test_reader1.py b/tests/test_reader1.py index 550f417e..2a9cbe1d 100644 --- a/tests/test_reader1.py +++ b/tests/test_reader1.py @@ -274,7 +274,7 @@ def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements assert msgs[0][2] == b'MSGCONTENT5' assert msgs[1][2] == b'MSGCONTENT10' - connections = [x for x in reader.connections.values() if x.topic == '/topic0'] + connections = [x for x in reader.connections if x.topic == '/topic0'] msgs = list(reader.messages(connections)) assert len(msgs) == 1 assert msgs[0][2] == b'MSGCONTENT10'