Unify rosbag1 and rosbag2 connection class

This commit is contained in:
Marko Durkovic 2022-04-13 09:40:22 +02:00
parent dee7e9c2fc
commit 16d1758327
13 changed files with 301 additions and 150 deletions

View File

@ -2,8 +2,9 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
from rosbags.interfaces import ConnectionExtRosbag2
from rosbags.rosbag2 import Reader, Writer from rosbags.rosbag2 import Reader, Writer
from rosbags.serde import deserialize_cdr, serialize_cdr from rosbags.serde import deserialize_cdr, serialize_cdr
@ -23,11 +24,12 @@ def offset_timestamps(src: Path, dst: Path, offset: int) -> None:
with Reader(src) as reader, Writer(dst) as writer: with Reader(src) as reader, Writer(dst) as writer:
conn_map = {} conn_map = {}
for conn in reader.connections.values(): for conn in reader.connections.values():
ext = cast(ConnectionExtRosbag2, conn.ext)
conn_map[conn.id] = writer.add_connection( conn_map[conn.id] = writer.add_connection(
conn.topic, conn.topic,
conn.msgtype, conn.msgtype,
conn.serialization_format, ext.serialization_format,
conn.offered_qos_profiles, ext.offered_qos_profiles,
) )
for conn, timestamp, data in reader.messages(): for conn, timestamp, data in reader.messages():

View File

@ -2,8 +2,9 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
from rosbags.interfaces import ConnectionExtRosbag2
from rosbags.rosbag2 import Reader, Writer from rosbags.rosbag2 import Reader, Writer
if TYPE_CHECKING: if TYPE_CHECKING:
@ -24,11 +25,12 @@ def remove_topic(src: Path, dst: Path, topic: str) -> None:
for conn in reader.connections.values(): for conn in reader.connections.values():
if conn.topic == topic: if conn.topic == topic:
continue continue
ext = cast(ConnectionExtRosbag2, conn.ext)
conn_map[conn.id] = writer.add_connection( conn_map[conn.id] = writer.add_connection(
conn.topic, conn.topic,
conn.msgtype, conn.msgtype,
conn.serialization_format, ext.serialization_format,
conn.offered_qos_profiles, ext.offered_qos_profiles,
) )
rconns = [reader.connections[x] for x in conn_map] rconns = [reader.connections[x] for x in conn_map]

View File

@ -4,19 +4,17 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import asdict
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
from rosbags.rosbag1 import Reader as Reader1 from rosbags.rosbag1 import Reader as Reader1
from rosbags.rosbag1 import ReaderError as ReaderError1 from rosbags.rosbag1 import ReaderError as ReaderError1
from rosbags.rosbag1 import Writer as Writer1 from rosbags.rosbag1 import Writer as Writer1
from rosbags.rosbag1 import WriterError as WriterError1 from rosbags.rosbag1 import WriterError as WriterError1
from rosbags.rosbag1.reader import Connection as Connection1
from rosbags.rosbag2 import Reader as Reader2 from rosbags.rosbag2 import Reader as Reader2
from rosbags.rosbag2 import ReaderError as ReaderError2 from rosbags.rosbag2 import ReaderError as ReaderError2
from rosbags.rosbag2 import Writer as Writer2 from rosbags.rosbag2 import Writer as Writer2
from rosbags.rosbag2 import WriterError as WriterError2 from rosbags.rosbag2 import WriterError as WriterError2
from rosbags.rosbag2.connection import Connection as Connection2
from rosbags.serde import cdr_to_ros1, ros1_to_cdr from rosbags.serde import cdr_to_ros1, ros1_to_cdr
from rosbags.typesys import get_types_from_msg, register_types from rosbags.typesys import get_types_from_msg, register_types
from rosbags.typesys.msg import generate_msgdef from rosbags.typesys.msg import generate_msgdef
@ -48,7 +46,7 @@ class ConverterError(Exception):
"""Converter Error.""" """Converter Error."""
def upgrade_connection(rconn: Connection1) -> Connection2: def upgrade_connection(rconn: Connection) -> Connection:
"""Convert rosbag1 connection to rosbag2 connection. """Convert rosbag1 connection to rosbag2 connection.
Args: Args:
@ -58,17 +56,22 @@ def upgrade_connection(rconn: Connection1) -> Connection2:
Rosbag2 connection. Rosbag2 connection.
""" """
return Connection2( assert isinstance(rconn.ext, ConnectionExtRosbag1)
-1, return Connection(
0, rconn.id,
rconn.topic, rconn.topic,
rconn.msgtype, rconn.msgtype,
'cdr', '',
LATCH if rconn.latching else '', '',
0,
ConnectionExtRosbag2(
'cdr',
LATCH if rconn.ext.latching else '',
),
) )
def downgrade_connection(rconn: Connection2) -> Connection1: def downgrade_connection(rconn: Connection) -> Connection:
"""Convert rosbag2 connection to rosbag1 connection. """Convert rosbag2 connection to rosbag1 connection.
Args: Args:
@ -78,15 +81,19 @@ def downgrade_connection(rconn: Connection2) -> Connection1:
Rosbag1 connection. Rosbag1 connection.
""" """
assert isinstance(rconn.ext, ConnectionExtRosbag2)
msgdef, md5sum = generate_msgdef(rconn.msgtype) msgdef, md5sum = generate_msgdef(rconn.msgtype)
return Connection1( return Connection(
-1, rconn.id,
rconn.topic, rconn.topic,
rconn.msgtype, rconn.msgtype,
msgdef, msgdef,
md5sum, md5sum,
None, -1,
int('durability: 1' in rconn.offered_qos_profiles), ConnectionExtRosbag1(
None,
int('durability: 1' in rconn.ext.offered_qos_profiles),
),
) )
@ -100,13 +107,26 @@ def convert_1to2(src: Path, dst: Path) -> None:
""" """
with Reader1(src) as reader, Writer2(dst) as writer: with Reader1(src) as reader, Writer2(dst) as writer:
typs: dict[str, Any] = {} typs: dict[str, Any] = {}
connmap: dict[int, Connection2] = {} connmap: dict[int, Connection] = {}
for rconn in reader.connections.values(): for rconn in reader.connections.values():
candidate = upgrade_connection(rconn) candidate = upgrade_connection(rconn)
existing = next((x for x in writer.connections.values() if x == candidate), None) assert isinstance(candidate.ext, ConnectionExtRosbag2)
wconn = existing if existing else writer.add_connection(**asdict(candidate)) for conn in writer.connections.values():
connmap[rconn.id] = wconn assert isinstance(conn.ext, ConnectionExtRosbag2)
if (
conn.topic == candidate.topic and conn.msgtype == candidate.msgtype and
conn.ext == candidate.ext
):
break
else:
conn = writer.add_connection(
candidate.topic,
candidate.msgtype,
candidate.ext.serialization_format,
candidate.ext.offered_qos_profiles,
)
connmap[rconn.id] = conn
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype)) typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
register_types(typs) register_types(typs)
@ -124,22 +144,27 @@ def convert_2to1(src: Path, dst: Path) -> None:
""" """
with Reader2(src) as reader, Writer1(dst) as writer: with Reader2(src) as reader, Writer1(dst) as writer:
connmap: dict[int, Connection1] = {} connmap: dict[int, Connection] = {}
for rconn in reader.connections.values(): for rconn in reader.connections.values():
candidate = downgrade_connection(rconn) candidate = downgrade_connection(rconn)
# yapf: disable assert isinstance(candidate.ext, ConnectionExtRosbag1)
existing = next( for conn in writer.connections.values():
( assert isinstance(conn.ext, ConnectionExtRosbag1)
x if (
for x in writer.connections.values() conn.topic == candidate.topic and conn.md5sum == candidate.md5sum and
if x.topic == candidate.topic conn.ext.latching == candidate.ext.latching
if x.md5sum == candidate.md5sum ):
if x.latching == candidate.latching break
), else:
None, conn = writer.add_connection(
) candidate.topic,
# yapf: enable candidate.msgtype,
connmap[rconn.id] = existing if existing else writer.add_connection(*candidate[1:]) candidate.msgdef,
candidate.md5sum,
candidate.ext.callerid,
candidate.ext.latching,
)
connmap[rconn.id] = conn
for rconn, timestamp, data in reader.messages(): for rconn, timestamp, data in reader.messages():
data = cdr_to_ros1(data, rconn.msgtype) data = cdr_to_ros1(data, rconn.msgtype)

View File

@ -0,0 +1,36 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Shared interfaces."""
from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Optional, Union
class ConnectionExtRosbag1(NamedTuple):
"""Rosbag1 specific connection extensions."""
callerid: Optional[str]
latching: Optional[int]
class ConnectionExtRosbag2(NamedTuple):
"""Rosbag2 specific connection extensions."""
serialization_format: str
offered_qos_profiles: str
class Connection(NamedTuple):
"""Connection information."""
id: int
topic: str
msgtype: str
msgdef: str
md5sum: str
msgcount: int
ext: Union[ConnectionExtRosbag1, ConnectionExtRosbag2]

View File

View File

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, NamedTuple
from lz4.frame import decompress as lz4_decompress from lz4.frame import decompress as lz4_decompress
from rosbags.interfaces import Connection, ConnectionExtRosbag1
from rosbags.typesys.msg import normalize_msgtype from rosbags.typesys.msg import normalize_msgtype
if TYPE_CHECKING: if TYPE_CHECKING:
@ -50,18 +51,6 @@ class RecordType(IntEnum):
CONNECTION = 7 CONNECTION = 7
class Connection(NamedTuple):
"""Connection information."""
id: int
topic: str
msgtype: str
msgdef: str
md5sum: str
callerid: Optional[str]
latching: Optional[int]
class ChunkInfo(NamedTuple): class ChunkInfo(NamedTuple):
"""Chunk information.""" """Chunk information."""
@ -427,6 +416,13 @@ class Reader:
} }
assert all(self.indexes[x] for x in self.connections) assert all(self.indexes[x] for x in self.connections)
for cid, connection in self.connections.items():
self.connections[cid] = Connection(
*connection[0:5],
len(self.indexes[cid]),
connection[6],
)
self.topics = {} self.topics = {}
for topic, group in groupby( for topic, group in groupby(
sorted(self.connections.values(), key=lambda x: x.topic), sorted(self.connections.values(), key=lambda x: x.topic),
@ -499,8 +495,11 @@ class Reader:
normalize_msgtype(typ), normalize_msgtype(typ),
msgdef, msgdef,
md5sum, md5sum,
callerid, 0,
latching, ConnectionExtRosbag1(
callerid,
latching,
),
) )
def read_chunk_info(self) -> ChunkInfo: def read_chunk_info(self) -> ChunkInfo:

View File

@ -15,9 +15,10 @@ from typing import TYPE_CHECKING, Any, Dict
from lz4.frame import compress as lz4_compress from lz4.frame import compress as lz4_compress
from rosbags.interfaces import Connection, ConnectionExtRosbag1
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
from .reader import Connection, RecordType from .reader import RecordType
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
@ -249,8 +250,11 @@ class Writer:
denormalize_msgtype(msgtype), denormalize_msgtype(msgtype),
msgdef, msgdef,
md5sum, md5sum,
callerid, -1,
latching, ConnectionExtRosbag1(
callerid,
latching,
),
) )
if any(x[1:] == connection[1:] for x in self.connections.values()): if any(x[1:] == connection[1:] for x in self.connections.values()):
@ -314,10 +318,11 @@ class Writer:
header.set_string('type', connection.msgtype) header.set_string('type', connection.msgtype)
header.set_string('md5sum', connection.md5sum) header.set_string('md5sum', connection.md5sum)
header.set_string('message_definition', connection.msgdef) header.set_string('message_definition', connection.msgdef)
if connection.callerid is not None: assert isinstance(connection.ext, ConnectionExtRosbag1)
header.set_string('callerid', connection.callerid) if connection.ext.callerid is not None:
if connection.latching is not None: header.set_string('callerid', connection.ext.callerid)
header.set_string('latching', str(connection.latching)) if connection.ext.latching is not None:
header.set_string('latching', str(connection.ext.latching))
header.write(bio) header.write(bio)
def write_chunk(self, chunk: WriteChunk) -> None: def write_chunk(self, chunk: WriteChunk) -> None:

View File

@ -1,18 +0,0 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 connection."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class Connection:
"""Connection metadata."""
id: int = field(compare=False) # pylint: disable=invalid-name
count: int = field(compare=False)
topic: str
msgtype: str
serialization_format: str
offered_qos_profiles: str

View File

@ -14,7 +14,7 @@ import zstandard
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError from ruamel.yaml.error import YAMLError
from .connection import Connection from rosbags.interfaces import Connection, ConnectionExtRosbag2
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
@ -139,15 +139,20 @@ class Reader:
self.connections = { self.connections = {
idx + 1: Connection( idx + 1: Connection(
id=idx + 1, id=idx + 1,
count=x['message_count'],
topic=x['topic_metadata']['name'], topic=x['topic_metadata']['name'],
msgtype=x['topic_metadata']['type'], msgtype=x['topic_metadata']['type'],
serialization_format=x['topic_metadata']['serialization_format'], msgdef='',
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''), md5sum='',
msgcount=x['message_count'],
ext=ConnectionExtRosbag2(
serialization_format=x['topic_metadata']['serialization_format'],
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
),
) for idx, x in enumerate(self.metadata['topics_with_message_count']) ) for idx, x in enumerate(self.metadata['topics_with_message_count'])
} }
noncdr = { noncdr = {
fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr' fmt for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
if (fmt := x.ext.serialization_format) != 'cdr'
} }
if noncdr: if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.') raise ReaderError(f'Serialization format {noncdr!r} is not supported.')

View File

@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
import zstandard import zstandard
from ruamel.yaml import YAML from ruamel.yaml import YAML
from .connection import Connection from rosbags.interfaces import Connection, ConnectionExtRosbag2
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
@ -82,6 +82,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = '' self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None self.compressor: Optional[zstandard.ZstdCompressor] = None
self.connections: dict[int, Connection] = {} self.connections: dict[int, Connection] = {}
self.counts: dict[int, int] = {}
self.conn: Optional[sqlite3.Connection] = None self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None self.cursor: Optional[sqlite3.Cursor] = None
@ -152,16 +153,25 @@ class Writer: # pylint: disable=too-many-instance-attributes
connection = Connection( connection = Connection(
id=len(self.connections.values()) + 1, id=len(self.connections.values()) + 1,
count=0,
topic=topic, topic=topic,
msgtype=msgtype, msgtype=msgtype,
serialization_format=serialization_format, msgdef='',
offered_qos_profiles=offered_qos_profiles, md5sum='',
msgcount=0,
ext=ConnectionExtRosbag2(
serialization_format=serialization_format,
offered_qos_profiles=offered_qos_profiles,
),
) )
if connection in self.connections.values(): for conn in self.connections.values():
raise WriterError(f'Connection can only be added once: {connection!r}.') if (
conn.topic == connection.topic and conn.msgtype == connection.msgtype and
conn.ext == connection.ext
):
raise WriterError(f'Connection can only be added once: {connection!r}.')
self.connections[connection.id] = connection self.connections[connection.id] = connection
self.counts[connection.id] = 0
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles) meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection return connection
@ -191,7 +201,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)', 'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
(connection.id, timestamp, data), (connection.id, timestamp, data),
) )
connection.count += 1 self.counts[connection.id] += 1
def close(self) -> None: def close(self) -> None:
"""Close rosbag2 after writing. """Close rosbag2 after writing.
@ -237,11 +247,11 @@ class Writer: # pylint: disable=too-many-instance-attributes
'topic_metadata': { 'topic_metadata': {
'name': x.topic, 'name': x.topic,
'type': x.msgtype, 'type': x.msgtype,
'serialization_format': x.serialization_format, 'serialization_format': x.ext.serialization_format,
'offered_qos_profiles': x.offered_qos_profiles, 'offered_qos_profiles': x.ext.offered_qos_profiles,
}, },
'message_count': x.count, 'message_count': self.counts[x.id],
} for x in self.connections.values() } for x in self.connections.values() if isinstance(x.ext, ConnectionExtRosbag2)
], ],
'compression_format': self.compression_format, 'compression_format': self.compression_format,
'compression_mode': self.compression_mode, 'compression_mode': self.compression_mode,

View File

@ -2,18 +2,25 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Rosbag1to2 converter tests.""" """Rosbag1to2 converter tests."""
from __future__ import annotations
import sys import sys
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, call, patch from typing import TYPE_CHECKING
from unittest.mock import call, patch
import pytest import pytest
from rosbags.convert import ConverterError, convert from rosbags.convert import ConverterError, convert
from rosbags.convert.__main__ import main from rosbags.convert.__main__ import main
from rosbags.convert.converter import LATCH from rosbags.convert.converter import LATCH
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
from rosbags.rosbag1 import ReaderError from rosbags.rosbag1 import ReaderError
from rosbags.rosbag2 import WriterError from rosbags.rosbag2 import WriterError
if TYPE_CHECKING:
from typing import Any
def test_cliwrapper(tmp_path: Path) -> None: def test_cliwrapper(tmp_path: Path) -> None:
"""Test cli wrapper.""" """Test cli wrapper."""
@ -114,71 +121,83 @@ def test_convert_1to2(tmp_path: Path) -> None:
patch('rosbags.convert.converter.register_types') as register_types, \ patch('rosbags.convert.converter.register_types') as register_types, \
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr: patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
readerinst = reader.return_value.__enter__.return_value
writerinst = writer.return_value.__enter__.return_value
connections = [ connections = [
Mock(topic='/topic', msgtype='typ', latching=False), Connection(1, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False)),
Mock(topic='/topic', msgtype='typ', latching=True), Connection(2, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, True)),
Connection(3, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False)),
Connection(4, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1('caller', False)),
] ]
wconnections = [ wconnections = [
Mock(topic='/topic', msgtype='typ'), Connection(1, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', '')),
Mock(topic='/topic', msgtype='typ'), Connection(2, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', LATCH)),
Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', '')),
] ]
reader.return_value.__enter__.return_value.connections = { readerinst.connections = {
1: connections[0], 1: connections[0],
2: connections[1], 2: connections[1],
3: connections[2],
4: connections[3],
} }
reader.return_value.__enter__.return_value.messages.return_value = [ readerinst.messages.return_value = [
(connections[0], 42, b'\x42'), (connections[0], 42, b'\x42'),
(connections[1], 43, b'\x43'), (connections[1], 43, b'\x43'),
(connections[2], 44, b'\x44'),
(connections[3], 45, b'\x45'),
] ]
writer.return_value.__enter__.return_value.add_connection.side_effect = [ writerinst.connections = {}
wconnections[0],
wconnections[1], def add_connection(*_: Any) -> Connection: # noqa: ANN401
] """Mock for Writer.add_connection."""
writerinst.connections = {
conn.id: conn
for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
}
return wconnections[len(writerinst.connections) - 1]
writerinst.add_connection.side_effect = add_connection
ros1_to_cdr.return_value = b'666' ros1_to_cdr.return_value = b'666'
convert(Path('foo.bag'), None) convert(Path('foo.bag'), None)
reader.assert_called_with(Path('foo.bag')) reader.assert_called_with(Path('foo.bag'))
reader.return_value.__enter__.return_value.messages.assert_called_with() readerinst.messages.assert_called_with()
writer.assert_called_with(Path('foo')) writer.assert_called_with(Path('foo'))
writer.return_value.__enter__.return_value.add_connection.assert_has_calls( writerinst.add_connection.assert_has_calls(
[ [
call( call('/topic', 'typ', 'cdr', ''),
id=-1, call('/topic', 'typ', 'cdr', LATCH),
count=0, call('/other', 'typ', 'cdr', ''),
topic='/topic',
msgtype='typ',
serialization_format='cdr',
offered_qos_profiles='',
),
call(
id=-1,
count=0,
topic='/topic',
msgtype='typ',
serialization_format='cdr',
offered_qos_profiles=LATCH,
),
], ],
) )
writer.return_value.__enter__.return_value.write.assert_has_calls( writerinst.write.assert_has_calls(
[call(wconnections[0], 42, b'666'), [
call(wconnections[1], 43, b'666')], call(wconnections[0], 42, b'666'),
call(wconnections[1], 43, b'666'),
call(wconnections[2], 44, b'666'),
call(wconnections[2], 45, b'666'),
],
) )
register_types.assert_called_with({'typ': 'def'}) register_types.assert_called_with({'typ': 'def'})
ros1_to_cdr.assert_has_calls([call(b'\x42', 'typ'), call(b'\x43', 'typ')]) ros1_to_cdr.assert_has_calls(
[
call(b'\x42', 'typ'),
call(b'\x43', 'typ'),
call(b'\x44', 'typ'),
call(b'\x45', 'typ'),
],
)
writer.return_value.__enter__.return_value.add_connection.side_effect = [ writerinst.connections.clear()
wconnections[0],
wconnections[1],
]
ros1_to_cdr.side_effect = KeyError('exc') ros1_to_cdr.side_effect = KeyError('exc')
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'): with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
convert(Path('foo.bag'), None) convert(Path('foo.bag'), None)
@ -204,30 +223,79 @@ def test_convert_2to1(tmp_path: Path) -> None:
patch('rosbags.convert.converter.Writer1') as writer, \ patch('rosbags.convert.converter.Writer1') as writer, \
patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1: patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1:
readerinst = reader.return_value.__enter__.return_value
writerinst = writer.return_value.__enter__.return_value
connections = [ connections = [
Mock(topic='/topic', msgtype='std_msgs/msg/Bool', offered_qos_profiles=''), Connection(1, '/topic', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '')),
Mock(topic='/topic', msgtype='std_msgs/msg/Bool', offered_qos_profiles=LATCH), Connection(
2,
'/topic',
'std_msgs/msg/Bool',
'',
'',
-1,
ConnectionExtRosbag2('', LATCH),
),
Connection(3, '/other', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '')),
Connection(4, '/other', 'std_msgs/msg/Bool', '', '', -1, ConnectionExtRosbag2('', '0')),
] ]
wconnections = [ wconnections = [
Mock(topic='/topic', msgtype='typ'), Connection(
Mock(topic='/topic', msgtype='typ'), 1,
'/topic',
'std_msgs/msg/Bool',
'',
'8b94c1b53db61fb6aed406028ad6332a',
-1,
ConnectionExtRosbag1(None, False),
),
Connection(
2,
'/topic',
'std_msgs/msg/Bool',
'',
'8b94c1b53db61fb6aed406028ad6332a',
-1,
ConnectionExtRosbag1(None, True),
),
Connection(
3,
'/other',
'std_msgs/msg/Bool',
'',
'8b94c1b53db61fb6aed406028ad6332a',
-1,
ConnectionExtRosbag1(None, False),
),
] ]
reader.return_value.__enter__.return_value.connections = { readerinst.connections = {
1: connections[0], 1: connections[0],
2: connections[1], 2: connections[1],
3: connections[2],
4: connections[3],
} }
reader.return_value.__enter__.return_value.messages.return_value = [ readerinst.messages.return_value = [
(connections[0], 42, b'\x42'), (connections[0], 42, b'\x42'),
(connections[1], 43, b'\x43'), (connections[1], 43, b'\x43'),
(connections[2], 44, b'\x44'),
(connections[3], 45, b'\x45'),
] ]
writer.return_value.__enter__.return_value.add_connection.side_effect = [ writerinst.connections = {}
wconnections[0],
wconnections[1], def add_connection(*_: Any) -> Connection: # noqa: ANN401
] """Mock for Writer.add_connection."""
writerinst.connections = {
conn.id: conn
for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
}
return wconnections[len(writerinst.connections) - 1]
writerinst.add_connection.side_effect = add_connection
cdr_to_ros1.return_value = b'666' cdr_to_ros1.return_value = b'666'
@ -255,24 +323,35 @@ def test_convert_2to1(tmp_path: Path) -> None:
None, None,
1, 1,
), ),
call(
'/other',
'std_msgs/msg/Bool',
'bool data\n',
'8b94c1b53db61fb6aed406028ad6332a',
None,
0,
),
], ],
) )
writer.return_value.__enter__.return_value.write.assert_has_calls( writer.return_value.__enter__.return_value.write.assert_has_calls(
[call(wconnections[0], 42, b'666'), [
call(wconnections[1], 43, b'666')], call(wconnections[0], 42, b'666'),
call(wconnections[1], 43, b'666'),
call(wconnections[2], 44, b'666'),
call(wconnections[2], 45, b'666'),
],
) )
cdr_to_ros1.assert_has_calls( cdr_to_ros1.assert_has_calls(
[ [
call(b'\x42', 'std_msgs/msg/Bool'), call(b'\x42', 'std_msgs/msg/Bool'),
call(b'\x43', 'std_msgs/msg/Bool'), call(b'\x43', 'std_msgs/msg/Bool'),
call(b'\x44', 'std_msgs/msg/Bool'),
call(b'\x45', 'std_msgs/msg/Bool'),
], ],
) )
writer.return_value.__enter__.return_value.add_connection.side_effect = [ writerinst.connections.clear()
wconnections[0],
wconnections[1],
]
cdr_to_ros1.side_effect = KeyError('exc') cdr_to_ros1.side_effect = KeyError('exc')
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'): with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
convert(Path('foo'), None) convert(Path('foo'), None)

View File

@ -36,7 +36,9 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
with rbag: with rbag:
gen = rbag.messages() gen = rbag.messages()
rconnection, _, raw = next(gen) rconnection, _, raw = next(gen)
assert rconnection == wconnection assert rconnection.topic == wconnection.topic
assert rconnection.msgtype == wconnection.msgtype
assert rconnection.ext == wconnection.ext
msg = deserialize_cdr(raw, rconnection.msgtype) msg = deserialize_cdr(raw, rconnection.msgtype)
assert getattr(msg, 'data', None) == Foo.data assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration): with pytest.raises(StopIteration):

View File

@ -8,8 +8,8 @@ from typing import TYPE_CHECKING
import pytest import pytest
from rosbags.interfaces import Connection, ConnectionExtRosbag2
from rosbags.rosbag2 import Writer, WriterError from rosbags.rosbag2 import Writer, WriterError
from rosbags.rosbag2.connection import Connection
if TYPE_CHECKING: if TYPE_CHECKING:
from pathlib import Path from pathlib import Path
@ -81,7 +81,11 @@ def test_failure_cases(tmp_path: Path) -> None:
bag = Writer(tmp_path / 'write') bag = Writer(tmp_path / 'write')
with pytest.raises(WriterError, match='was not opened'): with pytest.raises(WriterError, match='was not opened'):
bag.write(Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', ''), 0, b'') bag.write(
Connection(1, '/tf', 'tf_msgs/msg/tf2', '', '', 0, ConnectionExtRosbag2('cdr', '')),
0,
b'',
)
bag = Writer(tmp_path / 'topic') bag = Writer(tmp_path / 'topic')
bag.open() bag.open()
@ -91,6 +95,6 @@ def test_failure_cases(tmp_path: Path) -> None:
bag = Writer(tmp_path / 'notopic') bag = Writer(tmp_path / 'notopic')
bag.open() bag.open()
connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '') connection = Connection(1, '/tf', 'tf_msgs/msg/tf2', '', '', 0, ConnectionExtRosbag2('cdr', ''))
with pytest.raises(WriterError, match='unknown connection'): with pytest.raises(WriterError, match='unknown connection'):
bag.write(connection, 42, b'\x00') bag.write(connection, 42, b'\x00')