Unify rosbag1 and rosbag2 connection class
This commit is contained in:
parent
dee7e9c2fc
commit
16d1758327
@ -2,8 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
|
from rosbags.interfaces import ConnectionExtRosbag2
|
||||||
from rosbags.rosbag2 import Reader, Writer
|
from rosbags.rosbag2 import Reader, Writer
|
||||||
from rosbags.serde import deserialize_cdr, serialize_cdr
|
from rosbags.serde import deserialize_cdr, serialize_cdr
|
||||||
|
|
||||||
@ -23,11 +24,12 @@ def offset_timestamps(src: Path, dst: Path, offset: int) -> None:
|
|||||||
with Reader(src) as reader, Writer(dst) as writer:
|
with Reader(src) as reader, Writer(dst) as writer:
|
||||||
conn_map = {}
|
conn_map = {}
|
||||||
for conn in reader.connections.values():
|
for conn in reader.connections.values():
|
||||||
|
ext = cast(ConnectionExtRosbag2, conn.ext)
|
||||||
conn_map[conn.id] = writer.add_connection(
|
conn_map[conn.id] = writer.add_connection(
|
||||||
conn.topic,
|
conn.topic,
|
||||||
conn.msgtype,
|
conn.msgtype,
|
||||||
conn.serialization_format,
|
ext.serialization_format,
|
||||||
conn.offered_qos_profiles,
|
ext.offered_qos_profiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
for conn, timestamp, data in reader.messages():
|
for conn, timestamp, data in reader.messages():
|
||||||
|
|||||||
@ -2,8 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
|
from rosbags.interfaces import ConnectionExtRosbag2
|
||||||
from rosbags.rosbag2 import Reader, Writer
|
from rosbags.rosbag2 import Reader, Writer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -24,11 +25,12 @@ def remove_topic(src: Path, dst: Path, topic: str) -> None:
|
|||||||
for conn in reader.connections.values():
|
for conn in reader.connections.values():
|
||||||
if conn.topic == topic:
|
if conn.topic == topic:
|
||||||
continue
|
continue
|
||||||
|
ext = cast(ConnectionExtRosbag2, conn.ext)
|
||||||
conn_map[conn.id] = writer.add_connection(
|
conn_map[conn.id] = writer.add_connection(
|
||||||
conn.topic,
|
conn.topic,
|
||||||
conn.msgtype,
|
conn.msgtype,
|
||||||
conn.serialization_format,
|
ext.serialization_format,
|
||||||
conn.offered_qos_profiles,
|
ext.offered_qos_profiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
rconns = [reader.connections[x] for x in conn_map]
|
rconns = [reader.connections[x] for x in conn_map]
|
||||||
|
|||||||
@ -4,19 +4,17 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import asdict
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
|
||||||
from rosbags.rosbag1 import Reader as Reader1
|
from rosbags.rosbag1 import Reader as Reader1
|
||||||
from rosbags.rosbag1 import ReaderError as ReaderError1
|
from rosbags.rosbag1 import ReaderError as ReaderError1
|
||||||
from rosbags.rosbag1 import Writer as Writer1
|
from rosbags.rosbag1 import Writer as Writer1
|
||||||
from rosbags.rosbag1 import WriterError as WriterError1
|
from rosbags.rosbag1 import WriterError as WriterError1
|
||||||
from rosbags.rosbag1.reader import Connection as Connection1
|
|
||||||
from rosbags.rosbag2 import Reader as Reader2
|
from rosbags.rosbag2 import Reader as Reader2
|
||||||
from rosbags.rosbag2 import ReaderError as ReaderError2
|
from rosbags.rosbag2 import ReaderError as ReaderError2
|
||||||
from rosbags.rosbag2 import Writer as Writer2
|
from rosbags.rosbag2 import Writer as Writer2
|
||||||
from rosbags.rosbag2 import WriterError as WriterError2
|
from rosbags.rosbag2 import WriterError as WriterError2
|
||||||
from rosbags.rosbag2.connection import Connection as Connection2
|
|
||||||
from rosbags.serde import cdr_to_ros1, ros1_to_cdr
|
from rosbags.serde import cdr_to_ros1, ros1_to_cdr
|
||||||
from rosbags.typesys import get_types_from_msg, register_types
|
from rosbags.typesys import get_types_from_msg, register_types
|
||||||
from rosbags.typesys.msg import generate_msgdef
|
from rosbags.typesys.msg import generate_msgdef
|
||||||
@ -48,7 +46,7 @@ class ConverterError(Exception):
|
|||||||
"""Converter Error."""
|
"""Converter Error."""
|
||||||
|
|
||||||
|
|
||||||
def upgrade_connection(rconn: Connection1) -> Connection2:
|
def upgrade_connection(rconn: Connection) -> Connection:
|
||||||
"""Convert rosbag1 connection to rosbag2 connection.
|
"""Convert rosbag1 connection to rosbag2 connection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -58,17 +56,22 @@ def upgrade_connection(rconn: Connection1) -> Connection2:
|
|||||||
Rosbag2 connection.
|
Rosbag2 connection.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return Connection2(
|
assert isinstance(rconn.ext, ConnectionExtRosbag1)
|
||||||
-1,
|
return Connection(
|
||||||
0,
|
rconn.id,
|
||||||
rconn.topic,
|
rconn.topic,
|
||||||
rconn.msgtype,
|
rconn.msgtype,
|
||||||
|
'',
|
||||||
|
'',
|
||||||
|
0,
|
||||||
|
ConnectionExtRosbag2(
|
||||||
'cdr',
|
'cdr',
|
||||||
LATCH if rconn.latching else '',
|
LATCH if rconn.ext.latching else '',
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade_connection(rconn: Connection2) -> Connection1:
|
def downgrade_connection(rconn: Connection) -> Connection:
|
||||||
"""Convert rosbag2 connection to rosbag1 connection.
|
"""Convert rosbag2 connection to rosbag1 connection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -78,15 +81,19 @@ def downgrade_connection(rconn: Connection2) -> Connection1:
|
|||||||
Rosbag1 connection.
|
Rosbag1 connection.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert isinstance(rconn.ext, ConnectionExtRosbag2)
|
||||||
msgdef, md5sum = generate_msgdef(rconn.msgtype)
|
msgdef, md5sum = generate_msgdef(rconn.msgtype)
|
||||||
return Connection1(
|
return Connection(
|
||||||
-1,
|
rconn.id,
|
||||||
rconn.topic,
|
rconn.topic,
|
||||||
rconn.msgtype,
|
rconn.msgtype,
|
||||||
msgdef,
|
msgdef,
|
||||||
md5sum,
|
md5sum,
|
||||||
|
-1,
|
||||||
|
ConnectionExtRosbag1(
|
||||||
None,
|
None,
|
||||||
int('durability: 1' in rconn.offered_qos_profiles),
|
int('durability: 1' in rconn.ext.offered_qos_profiles),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -100,13 +107,26 @@ def convert_1to2(src: Path, dst: Path) -> None:
|
|||||||
"""
|
"""
|
||||||
with Reader1(src) as reader, Writer2(dst) as writer:
|
with Reader1(src) as reader, Writer2(dst) as writer:
|
||||||
typs: dict[str, Any] = {}
|
typs: dict[str, Any] = {}
|
||||||
connmap: dict[int, Connection2] = {}
|
connmap: dict[int, Connection] = {}
|
||||||
|
|
||||||
for rconn in reader.connections.values():
|
for rconn in reader.connections.values():
|
||||||
candidate = upgrade_connection(rconn)
|
candidate = upgrade_connection(rconn)
|
||||||
existing = next((x for x in writer.connections.values() if x == candidate), None)
|
assert isinstance(candidate.ext, ConnectionExtRosbag2)
|
||||||
wconn = existing if existing else writer.add_connection(**asdict(candidate))
|
for conn in writer.connections.values():
|
||||||
connmap[rconn.id] = wconn
|
assert isinstance(conn.ext, ConnectionExtRosbag2)
|
||||||
|
if (
|
||||||
|
conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and
|
||||||
|
conn.ext == candidate.ext
|
||||||
|
):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
conn = writer.add_connection(
|
||||||
|
candidate.topic,
|
||||||
|
candidate.msgtype,
|
||||||
|
candidate.ext.serialization_format,
|
||||||
|
candidate.ext.offered_qos_profiles,
|
||||||
|
)
|
||||||
|
connmap[rconn.id] = conn
|
||||||
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
|
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
|
||||||
register_types(typs)
|
register_types(typs)
|
||||||
|
|
||||||
@ -124,22 +144,27 @@ def convert_2to1(src: Path, dst: Path) -> None:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
with Reader2(src) as reader, Writer1(dst) as writer:
|
with Reader2(src) as reader, Writer1(dst) as writer:
|
||||||
connmap: dict[int, Connection1] = {}
|
connmap: dict[int, Connection] = {}
|
||||||
for rconn in reader.connections.values():
|
for rconn in reader.connections.values():
|
||||||
candidate = downgrade_connection(rconn)
|
candidate = downgrade_connection(rconn)
|
||||||
# yapf: disable
|
assert isinstance(candidate.ext, ConnectionExtRosbag1)
|
||||||
existing = next(
|
for conn in writer.connections.values():
|
||||||
(
|
assert isinstance(conn.ext, ConnectionExtRosbag1)
|
||||||
x
|
if (
|
||||||
for x in writer.connections.values()
|
conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and
|
||||||
if x.topic == candidate.topic
|
conn.ext.latching == candidate.ext.latching
|
||||||
if x.md5sum == candidate.md5sum
|
):
|
||||||
if x.latching == candidate.latching
|
break
|
||||||
),
|
else:
|
||||||
None,
|
conn = writer.add_connection(
|
||||||
|
candidate.topic,
|
||||||
|
candidate.msgtype,
|
||||||
|
candidate.msgdef,
|
||||||
|
candidate.md5sum,
|
||||||
|
candidate.ext.callerid,
|
||||||
|
candidate.ext.latching,
|
||||||
)
|
)
|
||||||
# yapf: enable
|
connmap[rconn.id] = conn
|
||||||
connmap[rconn.id] = existing if existing else writer.add_connection(*candidate[1:])
|
|
||||||
|
|
||||||
for rconn, timestamp, data in reader.messages():
|
for rconn, timestamp, data in reader.messages():
|
||||||
data = cdr_to_ros1(data, rconn.msgtype)
|
data = cdr_to_ros1(data, rconn.msgtype)
|
||||||
|
|||||||
36
src/rosbags/interfaces/__init__.py
Normal file
36
src/rosbags/interfaces/__init__.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Copyright 2020-2022 Ternaris.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Shared interfaces."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, NamedTuple
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionExtRosbag1(NamedTuple):
|
||||||
|
"""Rosbag1 specific connection extensions."""
|
||||||
|
|
||||||
|
callerid: Optional[str]
|
||||||
|
latching: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionExtRosbag2(NamedTuple):
|
||||||
|
"""Rosbag2 specific connection extensions."""
|
||||||
|
|
||||||
|
serialization_format: str
|
||||||
|
offered_qos_profiles: str
|
||||||
|
|
||||||
|
|
||||||
|
class Connection(NamedTuple):
|
||||||
|
"""Connection information."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
topic: str
|
||||||
|
msgtype: str
|
||||||
|
msgdef: str
|
||||||
|
md5sum: str
|
||||||
|
msgcount: int
|
||||||
|
ext: Union[ConnectionExtRosbag1, ConnectionExtRosbag2]
|
||||||
0
src/rosbags/interfaces/py.typed
Normal file
0
src/rosbags/interfaces/py.typed
Normal file
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, NamedTuple
|
|||||||
|
|
||||||
from lz4.frame import decompress as lz4_decompress
|
from lz4.frame import decompress as lz4_decompress
|
||||||
|
|
||||||
|
from rosbags.interfaces import Connection, ConnectionExtRosbag1
|
||||||
from rosbags.typesys.msg import normalize_msgtype
|
from rosbags.typesys.msg import normalize_msgtype
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -50,18 +51,6 @@ class RecordType(IntEnum):
|
|||||||
CONNECTION = 7
|
CONNECTION = 7
|
||||||
|
|
||||||
|
|
||||||
class Connection(NamedTuple):
|
|
||||||
"""Connection information."""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
topic: str
|
|
||||||
msgtype: str
|
|
||||||
msgdef: str
|
|
||||||
md5sum: str
|
|
||||||
callerid: Optional[str]
|
|
||||||
latching: Optional[int]
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkInfo(NamedTuple):
|
class ChunkInfo(NamedTuple):
|
||||||
"""Chunk information."""
|
"""Chunk information."""
|
||||||
|
|
||||||
@ -427,6 +416,13 @@ class Reader:
|
|||||||
}
|
}
|
||||||
assert all(self.indexes[x] for x in self.connections)
|
assert all(self.indexes[x] for x in self.connections)
|
||||||
|
|
||||||
|
for cid, connection in self.connections.items():
|
||||||
|
self.connections[cid] = Connection(
|
||||||
|
*connection[0:5],
|
||||||
|
len(self.indexes[cid]),
|
||||||
|
connection[6],
|
||||||
|
)
|
||||||
|
|
||||||
self.topics = {}
|
self.topics = {}
|
||||||
for topic, group in groupby(
|
for topic, group in groupby(
|
||||||
sorted(self.connections.values(), key=lambda x: x.topic),
|
sorted(self.connections.values(), key=lambda x: x.topic),
|
||||||
@ -499,8 +495,11 @@ class Reader:
|
|||||||
normalize_msgtype(typ),
|
normalize_msgtype(typ),
|
||||||
msgdef,
|
msgdef,
|
||||||
md5sum,
|
md5sum,
|
||||||
|
0,
|
||||||
|
ConnectionExtRosbag1(
|
||||||
callerid,
|
callerid,
|
||||||
latching,
|
latching,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def read_chunk_info(self) -> ChunkInfo:
|
def read_chunk_info(self) -> ChunkInfo:
|
||||||
|
|||||||
@ -15,9 +15,10 @@ from typing import TYPE_CHECKING, Any, Dict
|
|||||||
|
|
||||||
from lz4.frame import compress as lz4_compress
|
from lz4.frame import compress as lz4_compress
|
||||||
|
|
||||||
|
from rosbags.interfaces import Connection, ConnectionExtRosbag1
|
||||||
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
||||||
|
|
||||||
from .reader import Connection, RecordType
|
from .reader import RecordType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
@ -249,8 +250,11 @@ class Writer:
|
|||||||
denormalize_msgtype(msgtype),
|
denormalize_msgtype(msgtype),
|
||||||
msgdef,
|
msgdef,
|
||||||
md5sum,
|
md5sum,
|
||||||
|
-1,
|
||||||
|
ConnectionExtRosbag1(
|
||||||
callerid,
|
callerid,
|
||||||
latching,
|
latching,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if any(x[1:] == connection[1:] for x in self.connections.values()):
|
if any(x[1:] == connection[1:] for x in self.connections.values()):
|
||||||
@ -314,10 +318,11 @@ class Writer:
|
|||||||
header.set_string('type', connection.msgtype)
|
header.set_string('type', connection.msgtype)
|
||||||
header.set_string('md5sum', connection.md5sum)
|
header.set_string('md5sum', connection.md5sum)
|
||||||
header.set_string('message_definition', connection.msgdef)
|
header.set_string('message_definition', connection.msgdef)
|
||||||
if connection.callerid is not None:
|
assert isinstance(connection.ext, ConnectionExtRosbag1)
|
||||||
header.set_string('callerid', connection.callerid)
|
if connection.ext.callerid is not None:
|
||||||
if connection.latching is not None:
|
header.set_string('callerid', connection.ext.callerid)
|
||||||
header.set_string('latching', str(connection.latching))
|
if connection.ext.latching is not None:
|
||||||
|
header.set_string('latching', str(connection.ext.latching))
|
||||||
header.write(bio)
|
header.write(bio)
|
||||||
|
|
||||||
def write_chunk(self, chunk: WriteChunk) -> None:
|
def write_chunk(self, chunk: WriteChunk) -> None:
|
||||||
|
|||||||
@ -1,18 +0,0 @@
|
|||||||
# Copyright 2020-2022 Ternaris.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
"""Rosbag2 connection."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Connection:
|
|
||||||
"""Connection metadata."""
|
|
||||||
id: int = field(compare=False) # pylint: disable=invalid-name
|
|
||||||
count: int = field(compare=False)
|
|
||||||
topic: str
|
|
||||||
msgtype: str
|
|
||||||
serialization_format: str
|
|
||||||
offered_qos_profiles: str
|
|
||||||
@ -14,7 +14,7 @@ import zstandard
|
|||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
from ruamel.yaml.error import YAMLError
|
from ruamel.yaml.error import YAMLError
|
||||||
|
|
||||||
from .connection import Connection
|
from rosbags.interfaces import Connection, ConnectionExtRosbag2
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
@ -139,15 +139,20 @@ class Reader:
|
|||||||
self.connections = {
|
self.connections = {
|
||||||
idx + 1: Connection(
|
idx + 1: Connection(
|
||||||
id=idx + 1,
|
id=idx + 1,
|
||||||
count=x['message_count'],
|
|
||||||
topic=x['topic_metadata']['name'],
|
topic=x['topic_metadata']['name'],
|
||||||
msgtype=x['topic_metadata']['type'],
|
msgtype=x['topic_metadata']['type'],
|
||||||
|
msgdef='',
|
||||||
|
md5sum='',
|
||||||
|
msgcount=x['message_count'],
|
||||||
|
ext=ConnectionExtRosbag2(
|
||||||
serialization_format=x['topic_metadata']['serialization_format'],
|
serialization_format=x['topic_metadata']['serialization_format'],
|
||||||
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
|
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
|
||||||
|
),
|
||||||
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
|
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
|
||||||
}
|
}
|
||||||
noncdr = {
|
noncdr = {
|
||||||
fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr'
|
fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
|
||||||
|
if (fmt := x.ext.serialization_format) != 'cdr'
|
||||||
}
|
}
|
||||||
if noncdr:
|
if noncdr:
|
||||||
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
|
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
|
|||||||
import zstandard
|
import zstandard
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
from .connection import Connection
|
from rosbags.interfaces import Connection, ConnectionExtRosbag2
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
@ -82,6 +82,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.compression_format = ''
|
self.compression_format = ''
|
||||||
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
||||||
self.connections: dict[int, Connection] = {}
|
self.connections: dict[int, Connection] = {}
|
||||||
|
self.counts: dict[int, int] = {}
|
||||||
self.conn: Optional[sqlite3.Connection] = None
|
self.conn: Optional[sqlite3.Connection] = None
|
||||||
self.cursor: Optional[sqlite3.Cursor] = None
|
self.cursor: Optional[sqlite3.Cursor] = None
|
||||||
|
|
||||||
@ -152,16 +153,25 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
connection = Connection(
|
connection = Connection(
|
||||||
id=len(self.connections.values()) + 1,
|
id=len(self.connections.values()) + 1,
|
||||||
count=0,
|
|
||||||
topic=topic,
|
topic=topic,
|
||||||
msgtype=msgtype,
|
msgtype=msgtype,
|
||||||
|
msgdef='',
|
||||||
|
md5sum='',
|
||||||
|
msgcount=0,
|
||||||
|
ext=ConnectionExtRosbag2(
|
||||||
serialization_format=serialization_format,
|
serialization_format=serialization_format,
|
||||||
offered_qos_profiles=offered_qos_profiles,
|
offered_qos_profiles=offered_qos_profiles,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if connection in self.connections.values():
|
for conn in self.connections.values():
|
||||||
|
if (
|
||||||
|
conn.topic == connection.topic and conn.msgtype == connection.msgtype and
|
||||||
|
conn.ext == connection.ext
|
||||||
|
):
|
||||||
raise WriterError(f'Connection can only be added once: {connection!r}.')
|
raise WriterError(f'Connection can only be added once: {connection!r}.')
|
||||||
|
|
||||||
self.connections[connection.id] = connection
|
self.connections[connection.id] = connection
|
||||||
|
self.counts[connection.id] = 0
|
||||||
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
|
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
|
||||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||||
return connection
|
return connection
|
||||||
@ -191,7 +201,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
|
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
|
||||||
(connection.id, timestamp, data),
|
(connection.id, timestamp, data),
|
||||||
)
|
)
|
||||||
connection.count += 1
|
self.counts[connection.id] += 1
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close rosbag2 after writing.
|
"""Close rosbag2 after writing.
|
||||||
@ -237,11 +247,11 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
'topic_metadata': {
|
'topic_metadata': {
|
||||||
'name': x.topic,
|
'name': x.topic,
|
||||||
'type': x.msgtype,
|
'type': x.msgtype,
|
||||||
'serialization_format': x.serialization_format,
|
'serialization_format': x.ext.serialization_format,
|
||||||
'offered_qos_profiles': x.offered_qos_profiles,
|
'offered_qos_profiles': x.ext.offered_qos_profiles,
|
||||||
},
|
},
|
||||||
'message_count': x.count,
|
'message_count': self.counts[x.id],
|
||||||
} for x in self.connections.values()
|
} for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
|
||||||
],
|
],
|
||||||
'compression_format': self.compression_format,
|
'compression_format': self.compression_format,
|
||||||
'compression_mode': self.compression_mode,
|
'compression_mode': self.compression_mode,
|
||||||
|
|||||||
@ -2,18 +2,25 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Rosbag1to2 converter tests."""
|
"""Rosbag1to2 converter tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import Mock, call, patch
|
from typing import TYPE_CHECKING
|
||||||
|
from unittest.mock import call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from rosbags.convert import ConverterError, convert
|
from rosbags.convert import ConverterError, convert
|
||||||
from rosbags.convert.__main__ import main
|
from rosbags.convert.__main__ import main
|
||||||
from rosbags.convert.converter import LATCH
|
from rosbags.convert.converter import LATCH
|
||||||
|
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
|
||||||
from rosbags.rosbag1 import ReaderError
|
from rosbags.rosbag1 import ReaderError
|
||||||
from rosbags.rosbag2 import WriterError
|
from rosbags.rosbag2 import WriterError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def test_cliwrapper(tmp_path: Path) -> None:
|
def test_cliwrapper(tmp_path: Path) -> None:
|
||||||
"""Test cli wrapper."""
|
"""Test cli wrapper."""
|
||||||
@ -114,71 +121,83 @@ def test_convert_1to2(tmp_path: Path) -> None:
|
|||||||
patch('rosbags.convert.converter.register_types') as register_types, \
|
patch('rosbags.convert.converter.register_types') as register_types, \
|
||||||
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
|
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
|
||||||
|
|
||||||
|
readerinst = reader.return_value.__enter__.return_value
|
||||||
|
writerinst = writer.return_value.__enter__.return_value
|
||||||
|
|
||||||
connections = [
|
connections = [
|
||||||
Mock(topic='/topic', msgtype='typ', latching=False),
|
Connection(1, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False)),
|
||||||
Mock(topic='/topic', msgtype='typ', latching=True),
|
Connection(2, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, True)),
|
||||||
|
Connection(3, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False)),
|
||||||
|
Connection(4, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1('caller', False)),
|
||||||
]
|
]
|
||||||
|
|
||||||
wconnections = [
|
wconnections = [
|
||||||
Mock(topic='/topic', msgtype='typ'),
|
Connection(1, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', '')),
|
||||||
Mock(topic='/topic', msgtype='typ'),
|
Connection(2, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', LATCH)),
|
||||||
|
Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', '')),
|
||||||
]
|
]
|
||||||
|
|
||||||
reader.return_value.__enter__.return_value.connections = {
|
readerinst.connections = {
|
||||||
1: connections[0],
|
1: connections[0],
|
||||||
2: connections[1],
|
2: connections[1],
|
||||||
|
3: connections[2],
|
||||||
|
4: connections[3],
|
||||||
}
|
}
|
||||||
|
|
||||||
reader.return_value.__enter__.return_value.messages.return_value = [
|
readerinst.messages.return_value = [
|
||||||
(connections[0], 42, b'\x42'),
|
(connections[0], 42, b'\x42'),
|
||||||
(connections[1], 43, b'\x43'),
|
(connections[1], 43, b'\x43'),
|
||||||
|
(connections[2], 44, b'\x44'),
|
||||||
|
(connections[3], 45, b'\x45'),
|
||||||
]
|
]
|
||||||
|
|
||||||
writer.return_value.__enter__.return_value.add_connection.side_effect = [
|
writerinst.connections = {}
|
||||||
wconnections[0],
|
|
||||||
wconnections[1],
|
def add_connection(*_: Any) -> Connection: # noqa: ANN401
|
||||||
]
|
"""Mock for Writer.add_connection."""
|
||||||
|
writerinst.connections = {
|
||||||
|
conn.id: conn
|
||||||
|
for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
|
||||||
|
}
|
||||||
|
return wconnections[len(writerinst.connections) - 1]
|
||||||
|
|
||||||
|
writerinst.add_connection.side_effect = add_connection
|
||||||
|
|
||||||
ros1_to_cdr.return_value = b'666'
|
ros1_to_cdr.return_value = b'666'
|
||||||
|
|
||||||
convert(Path('foo.bag'), None)
|
convert(Path('foo.bag'), None)
|
||||||
|
|
||||||
reader.assert_called_with(Path('foo.bag'))
|
reader.assert_called_with(Path('foo.bag'))
|
||||||
reader.return_value.__enter__.return_value.messages.assert_called_with()
|
readerinst.messages.assert_called_with()
|
||||||
|
|
||||||
writer.assert_called_with(Path('foo'))
|
writer.assert_called_with(Path('foo'))
|
||||||
writer.return_value.__enter__.return_value.add_connection.assert_has_calls(
|
writerinst.add_connection.assert_has_calls(
|
||||||
[
|
[
|
||||||
call(
|
call('/topic', 'typ', 'cdr', ''),
|
||||||
id=-1,
|
call('/topic', 'typ', 'cdr', LATCH),
|
||||||
count=0,
|
call('/other', 'typ', 'cdr', ''),
|
||||||
topic='/topic',
|
|
||||||
msgtype='typ',
|
|
||||||
serialization_format='cdr',
|
|
||||||
offered_qos_profiles='',
|
|
||||||
),
|
|
||||||
call(
|
|
||||||
id=-1,
|
|
||||||
count=0,
|
|
||||||
topic='/topic',
|
|
||||||
msgtype='typ',
|
|
||||||
serialization_format='cdr',
|
|
||||||
offered_qos_profiles=LATCH,
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
writer.return_value.__enter__.return_value.write.assert_has_calls(
|
writerinst.write.assert_has_calls(
|
||||||
[call(wconnections[0], 42, b'666'),
|
[
|
||||||
call(wconnections[1], 43, b'666')],
|
call(wconnections[0], 42, b'666'),
|
||||||
|
call(wconnections[1], 43, b'666'),
|
||||||
|
call(wconnections[2], 44, b'666'),
|
||||||
|
call(wconnections[2], 45, b'666'),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
register_types.assert_called_with({'typ': 'def'})
|
register_types.assert_called_with({'typ': 'def'})
|
||||||
ros1_to_cdr.assert_has_calls([call(b'\x42', 'typ'), call(b'\x43', 'typ')])
|
ros1_to_cdr.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(b'\x42', 'typ'),
|
||||||
|
call(b'\x43', 'typ'),
|
||||||
|
call(b'\x44', 'typ'),
|
||||||
|
call(b'\x45', 'typ'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
writer.return_value.__enter__.return_value.add_connection.side_effect = [
|
writerinst.connections.clear()
|
||||||
wconnections[0],
|
|
||||||
wconnections[1],
|
|
||||||
]
|
|
||||||
ros1_to_cdr.side_effect = KeyError('exc')
|
ros1_to_cdr.side_effect = KeyError('exc')
|
||||||
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
|
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
|
||||||
convert(Path('foo.bag'), None)
|
convert(Path('foo.bag'), None)
|
||||||
@ -204,30 +223,79 @@ def test_convert_2to1(tmp_path: Path) -> None:
|
|||||||
patch('rosbags.convert.converter.Writer1') as writer, \
|
patch('rosbags.convert.converter.Writer1') as writer, \
|
||||||
patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1:
|
patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1:
|
||||||
|
|
||||||
|
readerinst = reader.return_value.__enter__.return_value
|
||||||
|
writerinst = writer.return_value.__enter__.return_value
|
||||||
|
|
||||||
connections = [
|
connections = [
|
||||||
Mock(topic='/topic', msgtype='std_msgs/msg/Bool', offered_qos_profiles=''),
|
Connection(1, '/topic', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '')),
|
||||||
Mock(topic='/topic', msgtype='std_msgs/msg/Bool', offered_qos_profiles=LATCH),
|
Connection(
|
||||||
|
2,
|
||||||
|
'/topic',
|
||||||
|
'std_msgs/msg/Bool',
|
||||||
|
'',
|
||||||
|
'',
|
||||||
|
-1,
|
||||||
|
ConnectionExtRosbag2('', LATCH),
|
||||||
|
),
|
||||||
|
Connection(3, '/other', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '')),
|
||||||
|
Connection(4, '/other', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '0')),
|
||||||
]
|
]
|
||||||
|
|
||||||
wconnections = [
|
wconnections = [
|
||||||
Mock(topic='/topic', msgtype='typ'),
|
Connection(
|
||||||
Mock(topic='/topic', msgtype='typ'),
|
1,
|
||||||
|
'/topic',
|
||||||
|
'std_msgs/msg/Bool',
|
||||||
|
'',
|
||||||
|
'8b94c1b53db61fb6aed406028ad6332a',
|
||||||
|
-1,
|
||||||
|
ConnectionExtRosbag1(None, False),
|
||||||
|
),
|
||||||
|
Connection(
|
||||||
|
2,
|
||||||
|
'/topic',
|
||||||
|
'std_msgs/msg/Bool',
|
||||||
|
'',
|
||||||
|
'8b94c1b53db61fb6aed406028ad6332a',
|
||||||
|
-1,
|
||||||
|
ConnectionExtRosbag1(None, True),
|
||||||
|
),
|
||||||
|
Connection(
|
||||||
|
3,
|
||||||
|
'/other',
|
||||||
|
'std_msgs/msg/Bool',
|
||||||
|
'',
|
||||||
|
'8b94c1b53db61fb6aed406028ad6332a',
|
||||||
|
-1,
|
||||||
|
ConnectionExtRosbag1(None, False),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
reader.return_value.__enter__.return_value.connections = {
|
readerinst.connections = {
|
||||||
1: connections[0],
|
1: connections[0],
|
||||||
2: connections[1],
|
2: connections[1],
|
||||||
|
3: connections[2],
|
||||||
|
4: connections[3],
|
||||||
}
|
}
|
||||||
|
|
||||||
reader.return_value.__enter__.return_value.messages.return_value = [
|
readerinst.messages.return_value = [
|
||||||
(connections[0], 42, b'\x42'),
|
(connections[0], 42, b'\x42'),
|
||||||
(connections[1], 43, b'\x43'),
|
(connections[1], 43, b'\x43'),
|
||||||
|
(connections[2], 44, b'\x44'),
|
||||||
|
(connections[3], 45, b'\x45'),
|
||||||
]
|
]
|
||||||
|
|
||||||
writer.return_value.__enter__.return_value.add_connection.side_effect = [
|
writerinst.connections = {}
|
||||||
wconnections[0],
|
|
||||||
wconnections[1],
|
def add_connection(*_: Any) -> Connection: # noqa: ANN401
|
||||||
]
|
"""Mock for Writer.add_connection."""
|
||||||
|
writerinst.connections = {
|
||||||
|
conn.id: conn
|
||||||
|
for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
|
||||||
|
}
|
||||||
|
return wconnections[len(writerinst.connections) - 1]
|
||||||
|
|
||||||
|
writerinst.add_connection.side_effect = add_connection
|
||||||
|
|
||||||
cdr_to_ros1.return_value = b'666'
|
cdr_to_ros1.return_value = b'666'
|
||||||
|
|
||||||
@ -255,24 +323,35 @@ def test_convert_2to1(tmp_path: Path) -> None:
|
|||||||
None,
|
None,
|
||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
|
call(
|
||||||
|
'/other',
|
||||||
|
'std_msgs/msg/Bool',
|
||||||
|
'bool data\n',
|
||||||
|
'8b94c1b53db61fb6aed406028ad6332a',
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
writer.return_value.__enter__.return_value.write.assert_has_calls(
|
writer.return_value.__enter__.return_value.write.assert_has_calls(
|
||||||
[call(wconnections[0], 42, b'666'),
|
[
|
||||||
call(wconnections[1], 43, b'666')],
|
call(wconnections[0], 42, b'666'),
|
||||||
|
call(wconnections[1], 43, b'666'),
|
||||||
|
call(wconnections[2], 44, b'666'),
|
||||||
|
call(wconnections[2], 45, b'666'),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cdr_to_ros1.assert_has_calls(
|
cdr_to_ros1.assert_has_calls(
|
||||||
[
|
[
|
||||||
call(b'\x42', 'std_msgs/msg/Bool'),
|
call(b'\x42', 'std_msgs/msg/Bool'),
|
||||||
call(b'\x43', 'std_msgs/msg/Bool'),
|
call(b'\x43', 'std_msgs/msg/Bool'),
|
||||||
|
call(b'\x44', 'std_msgs/msg/Bool'),
|
||||||
|
call(b'\x45', 'std_msgs/msg/Bool'),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
writer.return_value.__enter__.return_value.add_connection.side_effect = [
|
writerinst.connections.clear()
|
||||||
wconnections[0],
|
|
||||||
wconnections[1],
|
|
||||||
]
|
|
||||||
cdr_to_ros1.side_effect = KeyError('exc')
|
cdr_to_ros1.side_effect = KeyError('exc')
|
||||||
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
|
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
|
||||||
convert(Path('foo'), None)
|
convert(Path('foo'), None)
|
||||||
|
|||||||
@ -36,7 +36,9 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
|
|||||||
with rbag:
|
with rbag:
|
||||||
gen = rbag.messages()
|
gen = rbag.messages()
|
||||||
rconnection, _, raw = next(gen)
|
rconnection, _, raw = next(gen)
|
||||||
assert rconnection == wconnection
|
assert rconnection.topic == wconnection.topic
|
||||||
|
assert rconnection.msgtype == wconnection.msgtype
|
||||||
|
assert rconnection.ext == wconnection.ext
|
||||||
msg = deserialize_cdr(raw, rconnection.msgtype)
|
msg = deserialize_cdr(raw, rconnection.msgtype)
|
||||||
assert getattr(msg, 'data', None) == Foo.data
|
assert getattr(msg, 'data', None) == Foo.data
|
||||||
with pytest.raises(StopIteration):
|
with pytest.raises(StopIteration):
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from rosbags.interfaces import Connection, ConnectionExtRosbag2
|
||||||
from rosbags.rosbag2 import Writer, WriterError
|
from rosbags.rosbag2 import Writer, WriterError
|
||||||
from rosbags.rosbag2.connection import Connection
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -81,7 +81,11 @@ def test_failure_cases(tmp_path: Path) -> None:
|
|||||||
|
|
||||||
bag = Writer(tmp_path / 'write')
|
bag = Writer(tmp_path / 'write')
|
||||||
with pytest.raises(WriterError, match='was not opened'):
|
with pytest.raises(WriterError, match='was not opened'):
|
||||||
bag.write(Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', ''), 0, b'')
|
bag.write(
|
||||||
|
Connection(1, '/tf', 'tf_msgs/msg/tf2', '', '', 0, ConnectionExtRosbag2('cdr', '')),
|
||||||
|
0,
|
||||||
|
b'',
|
||||||
|
)
|
||||||
|
|
||||||
bag = Writer(tmp_path / 'topic')
|
bag = Writer(tmp_path / 'topic')
|
||||||
bag.open()
|
bag.open()
|
||||||
@ -91,6 +95,6 @@ def test_failure_cases(tmp_path: Path) -> None:
|
|||||||
|
|
||||||
bag = Writer(tmp_path / 'notopic')
|
bag = Writer(tmp_path / 'notopic')
|
||||||
bag.open()
|
bag.open()
|
||||||
connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '')
|
connection = Connection(1, '/tf', 'tf_msgs/msg/tf2', '', '', 0, ConnectionExtRosbag2('cdr', ''))
|
||||||
with pytest.raises(WriterError, match='unknown connection'):
|
with pytest.raises(WriterError, match='unknown connection'):
|
||||||
bag.write(connection, 42, b'\x00')
|
bag.write(connection, 42, b'\x00')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user