Unify rosbag1 and rosbag2 connection class

This commit is contained in:
Marko Durkovic
2022-04-13 09:40:22 +02:00
parent dee7e9c2fc
commit 16d1758327
13 changed files with 301 additions and 150 deletions
+57 -32
View File
@@ -4,19 +4,17 @@
from __future__ import annotations
from dataclasses import asdict
from typing import TYPE_CHECKING
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
from rosbags.rosbag1 import Reader as Reader1
from rosbags.rosbag1 import ReaderError as ReaderError1
from rosbags.rosbag1 import Writer as Writer1
from rosbags.rosbag1 import WriterError as WriterError1
from rosbags.rosbag1.reader import Connection as Connection1
from rosbags.rosbag2 import Reader as Reader2
from rosbags.rosbag2 import ReaderError as ReaderError2
from rosbags.rosbag2 import Writer as Writer2
from rosbags.rosbag2 import WriterError as WriterError2
from rosbags.rosbag2.connection import Connection as Connection2
from rosbags.serde import cdr_to_ros1, ros1_to_cdr
from rosbags.typesys import get_types_from_msg, register_types
from rosbags.typesys.msg import generate_msgdef
@@ -48,7 +46,7 @@ class ConverterError(Exception):
"""Converter Error."""
def upgrade_connection(rconn: Connection1) -> Connection2:
def upgrade_connection(rconn: Connection) -> Connection:
"""Convert rosbag1 connection to rosbag2 connection.
Args:
@@ -58,17 +56,22 @@ def upgrade_connection(rconn: Connection1) -> Connection2:
Rosbag2 connection.
"""
return Connection2(
-1,
0,
assert isinstance(rconn.ext, ConnectionExtRosbag1)
return Connection(
rconn.id,
rconn.topic,
rconn.msgtype,
'cdr',
LATCH if rconn.latching else '',
'',
'',
0,
ConnectionExtRosbag2(
'cdr',
LATCH if rconn.ext.latching else '',
),
)
def downgrade_connection(rconn: Connection2) -> Connection1:
def downgrade_connection(rconn: Connection) -> Connection:
"""Convert rosbag2 connection to rosbag1 connection.
Args:
@@ -78,15 +81,19 @@ def downgrade_connection(rconn: Connection2) -> Connection1:
Rosbag1 connection.
"""
assert isinstance(rconn.ext, ConnectionExtRosbag2)
msgdef, md5sum = generate_msgdef(rconn.msgtype)
return Connection1(
-1,
return Connection(
rconn.id,
rconn.topic,
rconn.msgtype,
msgdef,
md5sum,
None,
int('durability: 1' in rconn.offered_qos_profiles),
-1,
ConnectionExtRosbag1(
None,
int('durability: 1' in rconn.ext.offered_qos_profiles),
),
)
@@ -100,13 +107,26 @@ def convert_1to2(src: Path, dst: Path) -> None:
"""
with Reader1(src) as reader, Writer2(dst) as writer:
typs: dict[str, Any] = {}
connmap: dict[int, Connection2] = {}
connmap: dict[int, Connection] = {}
for rconn in reader.connections.values():
candidate = upgrade_connection(rconn)
existing = next((x for x in writer.connections.values() if x == candidate), None)
wconn = existing if existing else writer.add_connection(**asdict(candidate))
connmap[rconn.id] = wconn
assert isinstance(candidate.ext, ConnectionExtRosbag2)
for conn in writer.connections.values():
assert isinstance(conn.ext, ConnectionExtRosbag2)
if (
conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and
conn.ext == candidate.ext
):
break
else:
conn = writer.add_connection(
candidate.topic,
candidate.msgtype,
candidate.ext.serialization_format,
candidate.ext.offered_qos_profiles,
)
connmap[rconn.id] = conn
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
register_types(typs)
@@ -124,22 +144,27 @@ def convert_2to1(src: Path, dst: Path) -> None:
"""
with Reader2(src) as reader, Writer1(dst) as writer:
connmap: dict[int, Connection1] = {}
connmap: dict[int, Connection] = {}
for rconn in reader.connections.values():
candidate = downgrade_connection(rconn)
# yapf: disable
existing = next(
(
x
for x in writer.connections.values()
if x.topic == candidate.topic
if x.md5sum == candidate.md5sum
if x.latching == candidate.latching
),
None,
)
# yapf: enable
connmap[rconn.id] = existing if existing else writer.add_connection(*candidate[1:])
assert isinstance(candidate.ext, ConnectionExtRosbag1)
for conn in writer.connections.values():
assert isinstance(conn.ext, ConnectionExtRosbag1)
if (
conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and
conn.ext.latching == candidate.ext.latching
):
break
else:
conn = writer.add_connection(
candidate.topic,
candidate.msgtype,
candidate.msgdef,
candidate.md5sum,
candidate.ext.callerid,
candidate.ext.latching,
)
connmap[rconn.id] = conn
for rconn, timestamp, data in reader.messages():
data = cdr_to_ros1(data, rconn.msgtype)
+36
View File
@@ -0,0 +1,36 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Shared interfaces."""
from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Optional, Union
class ConnectionExtRosbag1(NamedTuple):
"""Rosbag1 specific connection extensions."""
callerid: Optional[str]
latching: Optional[int]
class ConnectionExtRosbag2(NamedTuple):
"""Rosbag2 specific connection extensions."""
serialization_format: str
offered_qos_profiles: str
class Connection(NamedTuple):
"""Connection information."""
id: int
topic: str
msgtype: str
msgdef: str
md5sum: str
msgcount: int
ext: Union[ConnectionExtRosbag1, ConnectionExtRosbag2]
View File
+13 -14
View File
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, NamedTuple
from lz4.frame import decompress as lz4_decompress
from rosbags.interfaces import Connection, ConnectionExtRosbag1
from rosbags.typesys.msg import normalize_msgtype
if TYPE_CHECKING:
@@ -50,18 +51,6 @@ class RecordType(IntEnum):
CONNECTION = 7
class Connection(NamedTuple):
"""Connection information."""
id: int
topic: str
msgtype: str
msgdef: str
md5sum: str
callerid: Optional[str]
latching: Optional[int]
class ChunkInfo(NamedTuple):
"""Chunk information."""
@@ -427,6 +416,13 @@ class Reader:
}
assert all(self.indexes[x] for x in self.connections)
for cid, connection in self.connections.items():
self.connections[cid] = Connection(
*connection[0:5],
len(self.indexes[cid]),
connection[6],
)
self.topics = {}
for topic, group in groupby(
sorted(self.connections.values(), key=lambda x: x.topic),
@@ -499,8 +495,11 @@ class Reader:
normalize_msgtype(typ),
msgdef,
md5sum,
callerid,
latching,
0,
ConnectionExtRosbag1(
callerid,
latching,
),
)
def read_chunk_info(self) -> ChunkInfo:
+12 -7
View File
@@ -15,9 +15,10 @@ from typing import TYPE_CHECKING, Any, Dict
from lz4.frame import compress as lz4_compress
from rosbags.interfaces import Connection, ConnectionExtRosbag1
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
from .reader import Connection, RecordType
from .reader import RecordType
if TYPE_CHECKING:
from types import TracebackType
@@ -249,8 +250,11 @@ class Writer:
denormalize_msgtype(msgtype),
msgdef,
md5sum,
callerid,
latching,
-1,
ConnectionExtRosbag1(
callerid,
latching,
),
)
if any(x[1:] == connection[1:] for x in self.connections.values()):
@@ -314,10 +318,11 @@ class Writer:
header.set_string('type', connection.msgtype)
header.set_string('md5sum', connection.md5sum)
header.set_string('message_definition', connection.msgdef)
if connection.callerid is not None:
header.set_string('callerid', connection.callerid)
if connection.latching is not None:
header.set_string('latching', str(connection.latching))
assert isinstance(connection.ext, ConnectionExtRosbag1)
if connection.ext.callerid is not None:
header.set_string('callerid', connection.ext.callerid)
if connection.ext.latching is not None:
header.set_string('latching', str(connection.ext.latching))
header.write(bio)
def write_chunk(self, chunk: WriteChunk) -> None:
-18
View File
@@ -1,18 +0,0 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 connection."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class Connection:
"""Connection metadata."""
id: int = field(compare=False) # pylint: disable=invalid-name
count: int = field(compare=False)
topic: str
msgtype: str
serialization_format: str
offered_qos_profiles: str
+10 -5
View File
@@ -14,7 +14,7 @@ import zstandard
from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError
from .connection import Connection
from rosbags.interfaces import Connection, ConnectionExtRosbag2
if TYPE_CHECKING:
from types import TracebackType
@@ -139,15 +139,20 @@ class Reader:
self.connections = {
idx + 1: Connection(
id=idx + 1,
count=x['message_count'],
topic=x['topic_metadata']['name'],
msgtype=x['topic_metadata']['type'],
serialization_format=x['topic_metadata']['serialization_format'],
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
msgdef='',
md5sum='',
msgcount=x['message_count'],
ext=ConnectionExtRosbag2(
serialization_format=x['topic_metadata']['serialization_format'],
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
),
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
}
noncdr = {
fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr'
fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
if (fmt := x.ext.serialization_format) != 'cdr'
}
if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
+21 -11
View File
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML
from .connection import Connection
from rosbags.interfaces import Connection, ConnectionExtRosbag2
if TYPE_CHECKING:
from types import TracebackType
@@ -82,6 +82,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None
self.connections: dict[int, Connection] = {}
self.counts: dict[int, int] = {}
self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None
@@ -152,16 +153,25 @@ class Writer: # pylint: disable=too-many-instance-attributes
connection = Connection(
id=len(self.connections.values()) + 1,
count=0,
topic=topic,
msgtype=msgtype,
serialization_format=serialization_format,
offered_qos_profiles=offered_qos_profiles,
msgdef='',
md5sum='',
msgcount=0,
ext=ConnectionExtRosbag2(
serialization_format=serialization_format,
offered_qos_profiles=offered_qos_profiles,
),
)
if connection in self.connections.values():
raise WriterError(f'Connection can only be added once: {connection!r}.')
for conn in self.connections.values():
if (
conn.topic == connection.topic and conn.msgtype == connection.msgtype and
conn.ext == connection.ext
):
raise WriterError(f'Connection can only be added once: {connection!r}.')
self.connections[connection.id] = connection
self.counts[connection.id] = 0
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
@@ -191,7 +201,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
(connection.id, timestamp, data),
)
connection.count += 1
self.counts[connection.id] += 1
def close(self) -> None:
"""Close rosbag2 after writing.
@@ -237,11 +247,11 @@ class Writer: # pylint: disable=too-many-instance-attributes
'topic_metadata': {
'name': x.topic,
'type': x.msgtype,
'serialization_format': x.serialization_format,
'offered_qos_profiles': x.offered_qos_profiles,
'serialization_format': x.ext.serialization_format,
'offered_qos_profiles': x.ext.offered_qos_profiles,
},
'message_count': x.count,
} for x in self.connections.values()
'message_count': self.counts[x.id],
} for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
],
'compression_format': self.compression_format,
'compression_mode': self.compression_mode,