Convert connections attribute to list

This commit is contained in:
Marko Durkovic 2022-04-21 15:15:10 +02:00
parent f7d69e35d5
commit 34ffe96692
14 changed files with 72 additions and 74 deletions

View File

@ -23,7 +23,7 @@ def offset_timestamps(src: Path, dst: Path, offset: int) -> None:
""" """
with Reader(src) as reader, Writer(dst) as writer: with Reader(src) as reader, Writer(dst) as writer:
conn_map = {} conn_map = {}
for conn in reader.connections.values(): for conn in reader.connections:
ext = cast(ConnectionExtRosbag2, conn.ext) ext = cast(ConnectionExtRosbag2, conn.ext)
conn_map[conn.id] = writer.add_connection( conn_map[conn.id] = writer.add_connection(
conn.topic, conn.topic,

View File

@ -22,7 +22,7 @@ def remove_topic(src: Path, dst: Path, topic: str) -> None:
""" """
with Reader(src) as reader, Writer(dst) as writer: with Reader(src) as reader, Writer(dst) as writer:
conn_map = {} conn_map = {}
for conn in reader.connections.values(): for conn in reader.connections:
if conn.topic == topic: if conn.topic == topic:
continue continue
ext = cast(ConnectionExtRosbag2, conn.ext) ext = cast(ConnectionExtRosbag2, conn.ext)

View File

@ -20,7 +20,7 @@ def process_bag(src: Path) -> None:
""" """
with Reader(src) as reader: with Reader(src) as reader:
typs = {} typs = {}
for conn in reader.connections.values(): for conn in reader.connections:
typs.update(get_types_from_msg(conn.msgdef, conn.msgtype)) typs.update(get_types_from_msg(conn.msgdef, conn.msgtype))
register_types(typs) register_types(typs)

View File

@ -37,8 +37,8 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically
# create reader instance # create reader instance
with Reader('/home/ros/rosbag_2020_03_24.bag') as reader: with Reader('/home/ros/rosbag_2020_03_24.bag') as reader:
# topic and msgtype information is available on .connections dictionary # topic and msgtype information is available on .connections list
for connection in reader.connections.values(): for connection in reader.connections:
print(connection.topic, connection.msgtype) print(connection.topic, connection.msgtype)
# iterate over messages # iterate over messages
@ -48,7 +48,7 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically
print(msg.header.frame_id) print(msg.header.frame_id)
# messages() accepts connection filters # messages() accepts connection filters
connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu'] connections = [x for x in reader.connections if x.topic == '/imu_raw/Imu']
for connection, timestamp, rawdata in reader.messages(connections=connections): for connection, timestamp, rawdata in reader.messages(connections=connections):
msg = deserialize_cdr(ros1_to_cdr(rawdata, connection.msgtype), connection.msgtype) msg = deserialize_cdr(ros1_to_cdr(rawdata, connection.msgtype), connection.msgtype)
print(msg.header.frame_id) print(msg.header.frame_id)

View File

@ -52,8 +52,8 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are used to r
# create reader instance and open for reading # create reader instance and open for reading
with Reader('/home/ros/rosbag_2020_03_24') as reader: with Reader('/home/ros/rosbag_2020_03_24') as reader:
# topic and msgtype information is available on .connections dict # topic and msgtype information is available on .connections list
for connection in reader.connections.values(): for connection in reader.connections:
print(connection.topic, connection.msgtype) print(connection.topic, connection.msgtype)
# iterate over messages # iterate over messages
@ -63,7 +63,7 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are used to r
print(msg.header.frame_id) print(msg.header.frame_id)
# messages() accepts connection filters # messages() accepts connection filters
connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu'] connections = [x for x in reader.connections if x.topic == '/imu_raw/Imu']
for connection, timestamp, rawdata in reader.messages(connections=connections): for connection, timestamp, rawdata in reader.messages(connections=connections):
msg = deserialize_cdr(rawdata, connection.msgtype) msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id) print(msg.header.frame_id)

View File

@ -111,10 +111,10 @@ def convert_1to2(src: Path, dst: Path) -> None:
typs: dict[str, Any] = {} typs: dict[str, Any] = {}
connmap: dict[int, Connection] = {} connmap: dict[int, Connection] = {}
for rconn in reader.connections.values(): for rconn in reader.connections:
candidate = upgrade_connection(rconn) candidate = upgrade_connection(rconn)
assert isinstance(candidate.ext, ConnectionExtRosbag2) assert isinstance(candidate.ext, ConnectionExtRosbag2)
for conn in writer.connections.values(): for conn in writer.connections:
assert isinstance(conn.ext, ConnectionExtRosbag2) assert isinstance(conn.ext, ConnectionExtRosbag2)
if ( if (
conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and
@ -147,10 +147,10 @@ def convert_2to1(src: Path, dst: Path) -> None:
""" """
with Reader2(src) as reader, Writer1(dst) as writer: with Reader2(src) as reader, Writer1(dst) as writer:
connmap: dict[int, Connection] = {} connmap: dict[int, Connection] = {}
for rconn in reader.connections.values(): for rconn in reader.connections:
candidate = downgrade_connection(rconn) candidate = downgrade_connection(rconn)
assert isinstance(candidate.ext, ConnectionExtRosbag1) assert isinstance(candidate.ext, ConnectionExtRosbag1)
for conn in writer.connections.values(): for conn in writer.connections:
assert isinstance(conn.ext, ConnectionExtRosbag1) assert isinstance(conn.ext, ConnectionExtRosbag1)
if ( if (
conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and

View File

@ -143,11 +143,11 @@ class AnyReader:
typs: dict[str, Any] = {} typs: dict[str, Any] = {}
for reader in self.readers: for reader in self.readers:
assert isinstance(reader, Reader1) assert isinstance(reader, Reader1)
for connection in reader.connections.values(): for connection in reader.connections:
typs.update(get_types_from_msg(connection.msgdef, connection.msgtype)) typs.update(get_types_from_msg(connection.msgdef, connection.msgtype))
register_types(typs, self.typestore) register_types(typs, self.typestore)
self.connections = [y for x in self.readers for y in x.connections.values()] self.connections = [y for x in self.readers for y in x.connections]
self.isopen = True self.isopen = True
def close(self) -> None: def close(self) -> None:

View File

@ -342,7 +342,7 @@ class Reader:
raise ReaderError(f'File {str(self.path)!r} does not exist.') raise ReaderError(f'File {str(self.path)!r} does not exist.')
self.bio: Optional[BinaryIO] = None self.bio: Optional[BinaryIO] = None
self.connections: dict[int, Connection] = {} self.connections: list[Connection] = []
self.indexes: dict[int, list[IndexData]] self.indexes: dict[int, list[IndexData]]
self.chunk_infos: list[ChunkInfo] = [] self.chunk_infos: list[ChunkInfo] = []
self.chunks: dict[int, Chunk] = {} self.chunks: dict[int, Chunk] = {}
@ -384,7 +384,7 @@ class Reader:
self.bio.seek(index_pos) self.bio.seek(index_pos)
try: try:
self.connections = dict(self.read_connection() for _ in range(conn_count)) self.connections = [self.read_connection() for _ in range(conn_count)]
self.chunk_infos = [self.read_chunk_info() for _ in range(chunk_count)] self.chunk_infos = [self.read_chunk_info() for _ in range(chunk_count)]
except ReaderError as err: except ReaderError as err:
raise ReaderError(f'Bag index looks damaged: {err.args}') from None raise ReaderError(f'Bag index looks damaged: {err.args}') from None
@ -402,14 +402,15 @@ class Reader:
self.indexes = { self.indexes = {
cid: list(heapq.merge(*x, key=lambda x: x.time)) for cid, x in indexes.items() cid: list(heapq.merge(*x, key=lambda x: x.time)) for cid, x in indexes.items()
} }
assert all(self.indexes[x] for x in self.connections) assert all(self.indexes[x.id] for x in self.connections)
for cid, connection in self.connections.items(): self.connections = [
self.connections[cid] = Connection( Connection(
*connection[0:5], *x[0:5],
len(self.indexes[cid]), len(self.indexes[x.id]),
*connection[6:], *x[6:],
) ) for x in self.connections
]
except ReaderError: except ReaderError:
self.close() self.close()
raise raise
@ -445,7 +446,7 @@ class Reader:
"""Topic information.""" """Topic information."""
topics = {} topics = {}
for topic, group in groupby( for topic, group in groupby(
sorted(self.connections.values(), key=lambda x: x.topic), sorted(self.connections, key=lambda x: x.topic),
key=lambda x: x.topic, key=lambda x: x.topic,
): ):
connections = list(group) connections = list(group)
@ -462,7 +463,7 @@ class Reader:
) )
return topics return topics
def read_connection(self) -> tuple[int, Connection]: def read_connection(self) -> Connection:
"""Read connection record from current position.""" """Read connection record from current position."""
assert self.bio assert self.bio
header = Header.read(self.bio, RecordType.CONNECTION) header = Header.read(self.bio, RecordType.CONNECTION)
@ -477,7 +478,7 @@ class Reader:
callerid = header.get_string('callerid') if 'callerid' in header else None callerid = header.get_string('callerid') if 'callerid' in header else None
latching = int(header.get_string('latching')) if 'latching' in header else None latching = int(header.get_string('latching')) if 'latching' in header else None
return conn, Connection( return Connection(
conn, conn,
topic, topic,
normalize_msgtype(typ), normalize_msgtype(typ),
@ -593,7 +594,9 @@ class Reader:
raise ReaderError('Rosbag is not open.') raise ReaderError('Rosbag is not open.')
if not connections: if not connections:
connections = self.connections.values() connections = self.connections
connmap = {x.id: x for x in self.connections}
indexes = [self.indexes[x.id] for x in connections] indexes = [self.indexes[x.id] for x in connections]
for entry in heapq.merge(*indexes): for entry in heapq.merge(*indexes):
@ -624,7 +627,7 @@ class Reader:
raise ReaderError('Expected to find message data.') raise ReaderError('Expected to find message data.')
data = read_bytes(chunk, read_uint32(chunk)) data = read_bytes(chunk, read_uint32(chunk))
connection = self.connections[header.get_uint32('conn')] connection = connmap[header.get_uint32('conn')]
assert entry.time == header.get_time('time') assert entry.time == header.get_time('time')
yield connection, entry.time, data yield connection, entry.time, data

View File

@ -159,7 +159,7 @@ class Writer:
self.bio: Optional[BinaryIO] = None self.bio: Optional[BinaryIO] = None
self.compressor: Callable[[bytes], bytes] = lambda x: x self.compressor: Callable[[bytes], bytes] = lambda x: x
self.compression_format = 'none' self.compression_format = 'none'
self.connections: dict[int, Connection] = {} self.connections: list[Connection] = []
self.chunks: list[WriteChunk] = [ self.chunks: list[WriteChunk] = [
WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)), WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)),
] ]
@ -258,7 +258,7 @@ class Writer:
self, self,
) )
if any(x[1:] == connection[1:] for x in self.connections.values()): if any(x[1:] == connection[1:] for x in self.connections):
raise WriterError( raise WriterError(
f'Connections can only be added once with same arguments: {connection!r}.', f'Connections can only be added once with same arguments: {connection!r}.',
) )
@ -266,7 +266,7 @@ class Writer:
bio = self.chunks[-1].data bio = self.chunks[-1].data
self.write_connection(connection, bio) self.write_connection(connection, bio)
self.connections[connection.id] = connection self.connections.append(connection)
return connection return connection
def write(self, connection: Connection, timestamp: int, data: bytes) -> None: def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
@ -284,7 +284,7 @@ class Writer:
if not self.bio: if not self.bio:
raise WriterError('Bag was not opened.') raise WriterError('Bag was not opened.')
if connection not in self.connections.values(): if connection not in self.connections:
raise WriterError(f'There is no connection {connection!r}.') from None raise WriterError(f'There is no connection {connection!r}.') from None
chunk = self.chunks[-1] chunk = self.chunks[-1]
@ -367,7 +367,7 @@ class Writer:
index_pos = self.bio.tell() index_pos = self.bio.tell()
for connection in self.connections.values(): for connection in self.connections:
self.write_connection(connection, self.bio) self.write_connection(connection, self.bio)
for chunk in self.chunks: for chunk in self.chunks:

View File

@ -136,8 +136,8 @@ class Reader:
if missing := [x for x in self.paths if not x.exists()]: if missing := [x for x in self.paths if not x.exists()]:
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}') raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
self.connections = { self.connections = [
idx + 1: Connection( Connection(
id=idx + 1, id=idx + 1,
topic=x['topic_metadata']['name'], topic=x['topic_metadata']['name'],
msgtype=x['topic_metadata']['type'], msgtype=x['topic_metadata']['type'],
@ -150,9 +150,9 @@ class Reader:
), ),
owner=self, owner=self,
) for idx, x in enumerate(self.metadata['topics_with_message_count']) ) for idx, x in enumerate(self.metadata['topics_with_message_count'])
} ]
noncdr = { noncdr = {
fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2) fmt for x in self.connections if isinstance(x.ext, ConnectionExtRosbag2)
if (fmt := x.ext.serialization_format) != 'cdr' if (fmt := x.ext.serialization_format) != 'cdr'
} }
if noncdr: if noncdr:
@ -209,10 +209,7 @@ class Reader:
@property @property
def topics(self) -> dict[str, TopicInfo]: def topics(self) -> dict[str, TopicInfo]:
"""Topic information.""" """Topic information."""
return { return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections}
x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x])
for x in self.connections.values()
}
def messages( # pylint: disable=too-many-locals def messages( # pylint: disable=too-many-locals
self, self,
@ -278,7 +275,7 @@ class Reader:
cur.execute('SELECT name,id FROM topics') cur.execute('SELECT name,id FROM topics')
connmap: dict[int, Connection] = { connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections.values() if x.topic == row[0]), row[1]: next((x for x in self.connections if x.topic == row[0]),
None) # type: ignore None) # type: ignore
for row in cur for row in cur
} }

View File

@ -81,7 +81,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_mode = '' self.compression_mode = ''
self.compression_format = '' self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None self.compressor: Optional[zstandard.ZstdCompressor] = None
self.connections: dict[int, Connection] = {} self.connections: list[Connection] = []
self.counts: dict[int, int] = {} self.counts: dict[int, int] = {}
self.conn: Optional[sqlite3.Connection] = None self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None self.cursor: Optional[sqlite3.Cursor] = None
@ -152,7 +152,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
raise WriterError('Bag was not opened.') raise WriterError('Bag was not opened.')
connection = Connection( connection = Connection(
id=len(self.connections.values()) + 1, id=len(self.connections) + 1,
topic=topic, topic=topic,
msgtype=msgtype, msgtype=msgtype,
msgdef='', msgdef='',
@ -164,14 +164,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
), ),
owner=self, owner=self,
) )
for conn in self.connections.values(): for conn in self.connections:
if ( if (
conn.topic == connection.topic and conn.msgtype == connection.msgtype and conn.topic == connection.topic and conn.msgtype == connection.msgtype and
conn.ext == connection.ext conn.ext == connection.ext
): ):
raise WriterError(f'Connection can only be added once: {connection!r}.') raise WriterError(f'Connection can only be added once: {connection!r}.')
self.connections[connection.id] = connection self.connections.append(connection)
self.counts[connection.id] = 0 self.counts[connection.id] = 0
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles) meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
@ -191,7 +191,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
""" """
if not self.cursor: if not self.cursor:
raise WriterError('Bag was not opened.') raise WriterError('Bag was not opened.')
if connection not in self.connections.values(): if connection not in self.connections:
raise WriterError(f'Tried to write to unknown connection {connection!r}.') raise WriterError(f'Tried to write to unknown connection {connection!r}.')
if self.compression_mode == 'message': if self.compression_mode == 'message':
@ -252,7 +252,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
'offered_qos_profiles': x.ext.offered_qos_profiles, 'offered_qos_profiles': x.ext.offered_qos_profiles,
}, },
'message_count': self.counts[x.id], 'message_count': self.counts[x.id],
} for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2) } for x in self.connections if isinstance(x.ext, ConnectionExtRosbag2)
], ],
'compression_format': self.compression_format, 'compression_format': self.compression_format,
'compression_mode': self.compression_mode, 'compression_mode': self.compression_mode,

View File

@ -146,12 +146,12 @@ def test_convert_1to2(tmp_path: Path) -> None:
Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None), Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None),
] ]
readerinst.connections = { readerinst.connections = [
1: connections[0], connections[0],
2: connections[1], connections[1],
3: connections[2], connections[2],
4: connections[3], connections[3],
} ]
readerinst.messages.return_value = [ readerinst.messages.return_value = [
(connections[0], 42, b'\x42'), (connections[0], 42, b'\x42'),
@ -160,14 +160,13 @@ def test_convert_1to2(tmp_path: Path) -> None:
(connections[3], 45, b'\x45'), (connections[3], 45, b'\x45'),
] ]
writerinst.connections = {} writerinst.connections = []
def add_connection(*_: Any) -> Connection: # noqa: ANN401 def add_connection(*_: Any) -> Connection: # noqa: ANN401
"""Mock for Writer.add_connection.""" """Mock for Writer.add_connection."""
writerinst.connections = { writerinst.connections = [
conn.id: conn conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) ]
}
return wconnections[len(writerinst.connections) - 1] return wconnections[len(writerinst.connections) - 1]
writerinst.add_connection.side_effect = add_connection writerinst.add_connection.side_effect = add_connection
@ -311,12 +310,12 @@ def test_convert_2to1(tmp_path: Path) -> None:
), ),
] ]
readerinst.connections = { readerinst.connections = [
1: connections[0], connections[0],
2: connections[1], connections[1],
3: connections[2], connections[2],
4: connections[3], connections[3],
} ]
readerinst.messages.return_value = [ readerinst.messages.return_value = [
(connections[0], 42, b'\x42'), (connections[0], 42, b'\x42'),
@ -325,14 +324,13 @@ def test_convert_2to1(tmp_path: Path) -> None:
(connections[3], 45, b'\x45'), (connections[3], 45, b'\x45'),
] ]
writerinst.connections = {} writerinst.connections = []
def add_connection(*_: Any) -> Connection: # noqa: ANN401 def add_connection(*_: Any) -> Connection: # noqa: ANN401
"""Mock for Writer.add_connection.""" """Mock for Writer.add_connection."""
writerinst.connections = { writerinst.connections = [
conn.id: conn conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
for _, conn in zip(range(len(writerinst.connections) + 1), wconnections) ]
}
return wconnections[len(writerinst.connections) - 1] return wconnections[len(writerinst.connections) - 1]
writerinst.add_connection.side_effect = add_connection writerinst.add_connection.side_effect = add_connection

View File

@ -126,7 +126,7 @@ def test_reader(bag: Path) -> None:
assert reader.message_count == 4 assert reader.message_count == 4
if reader.compression_mode: if reader.compression_mode:
assert reader.compression_format == 'zstd' assert reader.compression_format == 'zstd'
assert [*reader.connections.keys()] == [1, 2, 3] assert [x.id for x in reader.connections] == [1, 2, 3]
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint'] assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
gen = reader.messages() gen = reader.messages()
@ -154,7 +154,7 @@ def test_reader(bag: Path) -> None:
def test_message_filters(bag: Path) -> None: def test_message_filters(bag: Path) -> None:
"""Test reader filters messages.""" """Test reader filters messages."""
with Reader(bag) as reader: with Reader(bag) as reader:
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn'] magn_connections = [x for x in reader.connections if x.topic == '/magn']
gen = reader.messages(connections=magn_connections) gen = reader.messages(connections=magn_connections)
connection, _, _ = next(gen) connection, _, _ = next(gen)
assert connection.topic == '/magn' assert connection.topic == '/magn'

View File

@ -274,7 +274,7 @@ def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
assert msgs[0][2] == b'MSGCONTENT5' assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][2] == b'MSGCONTENT10' assert msgs[1][2] == b'MSGCONTENT10'
connections = [x for x in reader.connections.values() if x.topic == '/topic0'] connections = [x for x in reader.connections if x.topic == '/topic0']
msgs = list(reader.messages(connections)) msgs = list(reader.messages(connections))
assert len(msgs) == 1 assert len(msgs) == 1
assert msgs[0][2] == b'MSGCONTENT10' assert msgs[0][2] == b'MSGCONTENT10'