Unify rosbag1 and rosbag2 connection class
This commit is contained in:
@@ -4,19 +4,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
|
||||
from rosbags.rosbag1 import Reader as Reader1
|
||||
from rosbags.rosbag1 import ReaderError as ReaderError1
|
||||
from rosbags.rosbag1 import Writer as Writer1
|
||||
from rosbags.rosbag1 import WriterError as WriterError1
|
||||
from rosbags.rosbag1.reader import Connection as Connection1
|
||||
from rosbags.rosbag2 import Reader as Reader2
|
||||
from rosbags.rosbag2 import ReaderError as ReaderError2
|
||||
from rosbags.rosbag2 import Writer as Writer2
|
||||
from rosbags.rosbag2 import WriterError as WriterError2
|
||||
from rosbags.rosbag2.connection import Connection as Connection2
|
||||
from rosbags.serde import cdr_to_ros1, ros1_to_cdr
|
||||
from rosbags.typesys import get_types_from_msg, register_types
|
||||
from rosbags.typesys.msg import generate_msgdef
|
||||
@@ -48,7 +46,7 @@ class ConverterError(Exception):
|
||||
"""Converter Error."""
|
||||
|
||||
|
||||
def upgrade_connection(rconn: Connection1) -> Connection2:
|
||||
def upgrade_connection(rconn: Connection) -> Connection:
|
||||
"""Convert rosbag1 connection to rosbag2 connection.
|
||||
|
||||
Args:
|
||||
@@ -58,17 +56,22 @@ def upgrade_connection(rconn: Connection1) -> Connection2:
|
||||
Rosbag2 connection.
|
||||
|
||||
"""
|
||||
return Connection2(
|
||||
-1,
|
||||
0,
|
||||
assert isinstance(rconn.ext, ConnectionExtRosbag1)
|
||||
return Connection(
|
||||
rconn.id,
|
||||
rconn.topic,
|
||||
rconn.msgtype,
|
||||
'cdr',
|
||||
LATCH if rconn.latching else '',
|
||||
'',
|
||||
'',
|
||||
0,
|
||||
ConnectionExtRosbag2(
|
||||
'cdr',
|
||||
LATCH if rconn.ext.latching else '',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade_connection(rconn: Connection2) -> Connection1:
|
||||
def downgrade_connection(rconn: Connection) -> Connection:
|
||||
"""Convert rosbag2 connection to rosbag1 connection.
|
||||
|
||||
Args:
|
||||
@@ -78,15 +81,19 @@ def downgrade_connection(rconn: Connection2) -> Connection1:
|
||||
Rosbag1 connection.
|
||||
|
||||
"""
|
||||
assert isinstance(rconn.ext, ConnectionExtRosbag2)
|
||||
msgdef, md5sum = generate_msgdef(rconn.msgtype)
|
||||
return Connection1(
|
||||
-1,
|
||||
return Connection(
|
||||
rconn.id,
|
||||
rconn.topic,
|
||||
rconn.msgtype,
|
||||
msgdef,
|
||||
md5sum,
|
||||
None,
|
||||
int('durability: 1' in rconn.offered_qos_profiles),
|
||||
-1,
|
||||
ConnectionExtRosbag1(
|
||||
None,
|
||||
int('durability: 1' in rconn.ext.offered_qos_profiles),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -100,13 +107,26 @@ def convert_1to2(src: Path, dst: Path) -> None:
|
||||
"""
|
||||
with Reader1(src) as reader, Writer2(dst) as writer:
|
||||
typs: dict[str, Any] = {}
|
||||
connmap: dict[int, Connection2] = {}
|
||||
connmap: dict[int, Connection] = {}
|
||||
|
||||
for rconn in reader.connections.values():
|
||||
candidate = upgrade_connection(rconn)
|
||||
existing = next((x for x in writer.connections.values() if x == candidate), None)
|
||||
wconn = existing if existing else writer.add_connection(**asdict(candidate))
|
||||
connmap[rconn.id] = wconn
|
||||
assert isinstance(candidate.ext, ConnectionExtRosbag2)
|
||||
for conn in writer.connections.values():
|
||||
assert isinstance(conn.ext, ConnectionExtRosbag2)
|
||||
if (
|
||||
conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and
|
||||
conn.ext == candidate.ext
|
||||
):
|
||||
break
|
||||
else:
|
||||
conn = writer.add_connection(
|
||||
candidate.topic,
|
||||
candidate.msgtype,
|
||||
candidate.ext.serialization_format,
|
||||
candidate.ext.offered_qos_profiles,
|
||||
)
|
||||
connmap[rconn.id] = conn
|
||||
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
|
||||
register_types(typs)
|
||||
|
||||
@@ -124,22 +144,27 @@ def convert_2to1(src: Path, dst: Path) -> None:
|
||||
|
||||
"""
|
||||
with Reader2(src) as reader, Writer1(dst) as writer:
|
||||
connmap: dict[int, Connection1] = {}
|
||||
connmap: dict[int, Connection] = {}
|
||||
for rconn in reader.connections.values():
|
||||
candidate = downgrade_connection(rconn)
|
||||
# yapf: disable
|
||||
existing = next(
|
||||
(
|
||||
x
|
||||
for x in writer.connections.values()
|
||||
if x.topic == candidate.topic
|
||||
if x.md5sum == candidate.md5sum
|
||||
if x.latching == candidate.latching
|
||||
),
|
||||
None,
|
||||
)
|
||||
# yapf: enable
|
||||
connmap[rconn.id] = existing if existing else writer.add_connection(*candidate[1:])
|
||||
assert isinstance(candidate.ext, ConnectionExtRosbag1)
|
||||
for conn in writer.connections.values():
|
||||
assert isinstance(conn.ext, ConnectionExtRosbag1)
|
||||
if (
|
||||
conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and
|
||||
conn.ext.latching == candidate.ext.latching
|
||||
):
|
||||
break
|
||||
else:
|
||||
conn = writer.add_connection(
|
||||
candidate.topic,
|
||||
candidate.msgtype,
|
||||
candidate.msgdef,
|
||||
candidate.md5sum,
|
||||
candidate.ext.callerid,
|
||||
candidate.ext.latching,
|
||||
)
|
||||
connmap[rconn.id] = conn
|
||||
|
||||
for rconn, timestamp, data in reader.messages():
|
||||
data = cdr_to_ros1(data, rconn.msgtype)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright 2020-2022 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Shared interfaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class ConnectionExtRosbag1(NamedTuple):
|
||||
"""Rosbag1 specific connection extensions."""
|
||||
|
||||
callerid: Optional[str]
|
||||
latching: Optional[int]
|
||||
|
||||
|
||||
class ConnectionExtRosbag2(NamedTuple):
|
||||
"""Rosbag2 specific connection extensions."""
|
||||
|
||||
serialization_format: str
|
||||
offered_qos_profiles: str
|
||||
|
||||
|
||||
class Connection(NamedTuple):
|
||||
"""Connection information."""
|
||||
|
||||
id: int
|
||||
topic: str
|
||||
msgtype: str
|
||||
msgdef: str
|
||||
md5sum: str
|
||||
msgcount: int
|
||||
ext: Union[ConnectionExtRosbag1, ConnectionExtRosbag2]
|
||||
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, NamedTuple
|
||||
|
||||
from lz4.frame import decompress as lz4_decompress
|
||||
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag1
|
||||
from rosbags.typesys.msg import normalize_msgtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,18 +51,6 @@ class RecordType(IntEnum):
|
||||
CONNECTION = 7
|
||||
|
||||
|
||||
class Connection(NamedTuple):
|
||||
"""Connection information."""
|
||||
|
||||
id: int
|
||||
topic: str
|
||||
msgtype: str
|
||||
msgdef: str
|
||||
md5sum: str
|
||||
callerid: Optional[str]
|
||||
latching: Optional[int]
|
||||
|
||||
|
||||
class ChunkInfo(NamedTuple):
|
||||
"""Chunk information."""
|
||||
|
||||
@@ -427,6 +416,13 @@ class Reader:
|
||||
}
|
||||
assert all(self.indexes[x] for x in self.connections)
|
||||
|
||||
for cid, connection in self.connections.items():
|
||||
self.connections[cid] = Connection(
|
||||
*connection[0:5],
|
||||
len(self.indexes[cid]),
|
||||
connection[6],
|
||||
)
|
||||
|
||||
self.topics = {}
|
||||
for topic, group in groupby(
|
||||
sorted(self.connections.values(), key=lambda x: x.topic),
|
||||
@@ -499,8 +495,11 @@ class Reader:
|
||||
normalize_msgtype(typ),
|
||||
msgdef,
|
||||
md5sum,
|
||||
callerid,
|
||||
latching,
|
||||
0,
|
||||
ConnectionExtRosbag1(
|
||||
callerid,
|
||||
latching,
|
||||
),
|
||||
)
|
||||
|
||||
def read_chunk_info(self) -> ChunkInfo:
|
||||
|
||||
@@ -15,9 +15,10 @@ from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from lz4.frame import compress as lz4_compress
|
||||
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag1
|
||||
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
||||
|
||||
from .reader import Connection, RecordType
|
||||
from .reader import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
@@ -249,8 +250,11 @@ class Writer:
|
||||
denormalize_msgtype(msgtype),
|
||||
msgdef,
|
||||
md5sum,
|
||||
callerid,
|
||||
latching,
|
||||
-1,
|
||||
ConnectionExtRosbag1(
|
||||
callerid,
|
||||
latching,
|
||||
),
|
||||
)
|
||||
|
||||
if any(x[1:] == connection[1:] for x in self.connections.values()):
|
||||
@@ -314,10 +318,11 @@ class Writer:
|
||||
header.set_string('type', connection.msgtype)
|
||||
header.set_string('md5sum', connection.md5sum)
|
||||
header.set_string('message_definition', connection.msgdef)
|
||||
if connection.callerid is not None:
|
||||
header.set_string('callerid', connection.callerid)
|
||||
if connection.latching is not None:
|
||||
header.set_string('latching', str(connection.latching))
|
||||
assert isinstance(connection.ext, ConnectionExtRosbag1)
|
||||
if connection.ext.callerid is not None:
|
||||
header.set_string('callerid', connection.ext.callerid)
|
||||
if connection.ext.latching is not None:
|
||||
header.set_string('latching', str(connection.ext.latching))
|
||||
header.write(bio)
|
||||
|
||||
def write_chunk(self, chunk: WriteChunk) -> None:
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2020-2022 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag2 connection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class Connection:
|
||||
"""Connection metadata."""
|
||||
id: int = field(compare=False) # pylint: disable=invalid-name
|
||||
count: int = field(compare=False)
|
||||
topic: str
|
||||
msgtype: str
|
||||
serialization_format: str
|
||||
offered_qos_profiles: str
|
||||
@@ -14,7 +14,7 @@ import zstandard
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.error import YAMLError
|
||||
|
||||
from .connection import Connection
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
@@ -139,15 +139,20 @@ class Reader:
|
||||
self.connections = {
|
||||
idx + 1: Connection(
|
||||
id=idx + 1,
|
||||
count=x['message_count'],
|
||||
topic=x['topic_metadata']['name'],
|
||||
msgtype=x['topic_metadata']['type'],
|
||||
serialization_format=x['topic_metadata']['serialization_format'],
|
||||
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
|
||||
msgdef='',
|
||||
md5sum='',
|
||||
msgcount=x['message_count'],
|
||||
ext=ConnectionExtRosbag2(
|
||||
serialization_format=x['topic_metadata']['serialization_format'],
|
||||
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
|
||||
),
|
||||
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
|
||||
}
|
||||
noncdr = {
|
||||
fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr'
|
||||
fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
|
||||
if (fmt := x.ext.serialization_format) != 'cdr'
|
||||
}
|
||||
if noncdr:
|
||||
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
|
||||
import zstandard
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from .connection import Connection
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
@@ -82,6 +82,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compression_format = ''
|
||||
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
||||
self.connections: dict[int, Connection] = {}
|
||||
self.counts: dict[int, int] = {}
|
||||
self.conn: Optional[sqlite3.Connection] = None
|
||||
self.cursor: Optional[sqlite3.Cursor] = None
|
||||
|
||||
@@ -152,16 +153,25 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
connection = Connection(
|
||||
id=len(self.connections.values()) + 1,
|
||||
count=0,
|
||||
topic=topic,
|
||||
msgtype=msgtype,
|
||||
serialization_format=serialization_format,
|
||||
offered_qos_profiles=offered_qos_profiles,
|
||||
msgdef='',
|
||||
md5sum='',
|
||||
msgcount=0,
|
||||
ext=ConnectionExtRosbag2(
|
||||
serialization_format=serialization_format,
|
||||
offered_qos_profiles=offered_qos_profiles,
|
||||
),
|
||||
)
|
||||
if connection in self.connections.values():
|
||||
raise WriterError(f'Connection can only be added once: {connection!r}.')
|
||||
for conn in self.connections.values():
|
||||
if (
|
||||
conn.topic == connection.topic and conn.msgtype == connection.msgtype and
|
||||
conn.ext == connection.ext
|
||||
):
|
||||
raise WriterError(f'Connection can only be added once: {connection!r}.')
|
||||
|
||||
self.connections[connection.id] = connection
|
||||
self.counts[connection.id] = 0
|
||||
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
|
||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||
return connection
|
||||
@@ -191,7 +201,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
|
||||
(connection.id, timestamp, data),
|
||||
)
|
||||
connection.count += 1
|
||||
self.counts[connection.id] += 1
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2 after writing.
|
||||
@@ -237,11 +247,11 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
'topic_metadata': {
|
||||
'name': x.topic,
|
||||
'type': x.msgtype,
|
||||
'serialization_format': x.serialization_format,
|
||||
'offered_qos_profiles': x.offered_qos_profiles,
|
||||
'serialization_format': x.ext.serialization_format,
|
||||
'offered_qos_profiles': x.ext.offered_qos_profiles,
|
||||
},
|
||||
'message_count': x.count,
|
||||
} for x in self.connections.values()
|
||||
'message_count': self.counts[x.id],
|
||||
} for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
|
||||
],
|
||||
'compression_format': self.compression_format,
|
||||
'compression_mode': self.compression_mode,
|
||||
|
||||
Reference in New Issue
Block a user