Change to connection oriented reader API
This commit is contained in:
committed by
Florian Friesdorf
parent
ebf357a0c6
commit
f33e65b14a
+39
-12
@@ -76,17 +76,29 @@ def test_convert(tmp_path: Path):
|
||||
patch('rosbags.convert.converter.register_types') as register_types, \
|
||||
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
|
||||
|
||||
connections = [
|
||||
Mock(topic='/topic', msgtype='typ', latching=False),
|
||||
Mock(topic='/topic', msgtype='typ', latching=True),
|
||||
]
|
||||
|
||||
wconnections = [
|
||||
Mock(topic='/topic', msgtype='typ'),
|
||||
Mock(topic='/topic', msgtype='typ'),
|
||||
]
|
||||
|
||||
reader.return_value.__enter__.return_value.connections = {
|
||||
0: Mock(topic='/topic', latching=False),
|
||||
1: Mock(topic='/latched', latching=True),
|
||||
}
|
||||
reader.return_value.__enter__.return_value.topics = {
|
||||
'/topic': Mock(msgtype='typ', msgdef='def'),
|
||||
'/latched': Mock(msgtype='typ', msgdef='def'),
|
||||
1: connections[0],
|
||||
2: connections[1],
|
||||
}
|
||||
|
||||
reader.return_value.__enter__.return_value.messages.return_value = [
|
||||
('/topic', 'typ', 42, b'\x42'),
|
||||
('/latched', 'typ', 43, b'\x43'),
|
||||
(connections[0], 42, b'\x42'),
|
||||
(connections[1], 43, b'\x43'),
|
||||
]
|
||||
|
||||
writer.return_value.__enter__.return_value.add_connection.side_effect = [
|
||||
wconnections[0],
|
||||
wconnections[1],
|
||||
]
|
||||
|
||||
ros1_to_cdr.return_value = b'666'
|
||||
@@ -97,14 +109,29 @@ def test_convert(tmp_path: Path):
|
||||
reader.return_value.__enter__.return_value.messages.assert_called_with()
|
||||
|
||||
writer.assert_called_with(Path('foo'))
|
||||
writer.return_value.__enter__.return_value.add_topic.assert_has_calls(
|
||||
writer.return_value.__enter__.return_value.add_connection.assert_has_calls(
|
||||
[
|
||||
call('/topic', 'typ', offered_qos_profiles=''),
|
||||
call('/latched', 'typ', offered_qos_profiles=LATCH),
|
||||
call(
|
||||
id=-1,
|
||||
count=0,
|
||||
topic='/topic',
|
||||
msgtype='typ',
|
||||
serialization_format='cdr',
|
||||
offered_qos_profiles='',
|
||||
),
|
||||
call(
|
||||
id=-1,
|
||||
count=0,
|
||||
topic='/topic',
|
||||
msgtype='typ',
|
||||
serialization_format='cdr',
|
||||
offered_qos_profiles=LATCH,
|
||||
),
|
||||
],
|
||||
)
|
||||
writer.return_value.__enter__.return_value.write.assert_has_calls(
|
||||
[call('/topic', 42, b'666'), call('/latched', 43, b'666')],
|
||||
[call(wconnections[0], 42, b'666'),
|
||||
call(wconnections[1], 43, b'666')],
|
||||
)
|
||||
|
||||
register_types.assert_called_with({'typ': 'def'})
|
||||
|
||||
+26
-25
@@ -126,25 +126,26 @@ def test_reader(bag: Path):
|
||||
assert reader.message_count == 4
|
||||
if reader.compression_mode:
|
||||
assert reader.compression_format == 'zstd'
|
||||
|
||||
assert [*reader.connections.keys()] == [1, 2, 3]
|
||||
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
|
||||
gen = reader.messages()
|
||||
|
||||
topic, msgtype, timestamp, rawdata = next(gen)
|
||||
assert topic == '/poly'
|
||||
assert msgtype == 'geometry_msgs/msg/Polygon'
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/poly'
|
||||
assert connection.msgtype == 'geometry_msgs/msg/Polygon'
|
||||
assert timestamp == 666
|
||||
assert rawdata == MSG_POLY[0]
|
||||
|
||||
for idx in range(2):
|
||||
topic, msgtype, timestamp, rawdata = next(gen)
|
||||
assert topic == '/magn'
|
||||
assert msgtype == 'sensor_msgs/msg/MagneticField'
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
|
||||
assert timestamp == 708
|
||||
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
|
||||
|
||||
topic, msgtype, timestamp, rawdata = next(gen)
|
||||
assert topic == '/joint'
|
||||
assert msgtype == 'trajectory_msgs/msg/JointTrajectory'
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/joint'
|
||||
assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
@@ -153,32 +154,32 @@ def test_reader(bag: Path):
|
||||
def test_message_filters(bag: Path):
|
||||
"""Test reader filters messages."""
|
||||
with Reader(bag) as reader:
|
||||
|
||||
gen = reader.messages(['/magn'])
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
|
||||
gen = reader.messages(connections=magn_connections)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=667)
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/joint'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/joint'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(stop=667)
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/poly'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/poly'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(['/magn'], stop=667)
|
||||
gen = reader.messages(connections=magn_connections, stop=667)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
+10
-8
@@ -141,7 +141,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
||||
header['index_pos'] = pack('<Q', pos)
|
||||
|
||||
header = ser(header)
|
||||
header += b'\x00' * (4096 - len(header))
|
||||
header += b'\x20' * (4096 - len(header))
|
||||
|
||||
bag.write_bytes(b''.join([
|
||||
magic,
|
||||
@@ -227,8 +227,10 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
|
||||
assert reader.topics['/topic0'].msgcount == 2
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0][3] == b'MSGCONTENT5'
|
||||
assert msgs[1][3] == b'MSGCONTENT10'
|
||||
assert msgs[0][0].topic == '/topic0'
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
assert msgs[1][0].topic == '/topic0'
|
||||
assert msgs[1][2] == b'MSGCONTENT10'
|
||||
|
||||
# sorts by time on different topic
|
||||
write_bag(
|
||||
@@ -249,20 +251,20 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
|
||||
assert reader.topics['/topic2'].msgcount == 1
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0][3] == b'MSGCONTENT5'
|
||||
assert msgs[1][3] == b'MSGCONTENT10'
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
assert msgs[1][2] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(['/topic0']))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][3] == b'MSGCONTENT10'
|
||||
assert msgs[0][2] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(start=7 * 10**9))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][3] == b'MSGCONTENT10'
|
||||
assert msgs[0][2] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(stop=7 * 10**9))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][3] == b'MSGCONTENT5'
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
|
||||
|
||||
def test_user_errors(tmp_path):
|
||||
|
||||
@@ -29,14 +29,15 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
|
||||
wbag.set_compression(mode, wbag.CompressionFormat.ZSTD)
|
||||
with wbag:
|
||||
msgtype = 'std_msgs/msg/Float64'
|
||||
wbag.add_topic('/test', msgtype)
|
||||
wbag.write('/test', 42, serialize_cdr(Foo, msgtype))
|
||||
wconnection = wbag.add_connection('/test', msgtype)
|
||||
wbag.write(wconnection, 42, serialize_cdr(Foo, msgtype))
|
||||
|
||||
rbag = Reader(path)
|
||||
with rbag:
|
||||
gen = rbag.messages()
|
||||
_, msgtype, _, raw = next(gen)
|
||||
msg = deserialize_cdr(raw, msgtype)
|
||||
rconnection, _, raw = next(gen)
|
||||
assert rconnection == wconnection
|
||||
msg = deserialize_cdr(raw, rconnection.msgtype)
|
||||
assert msg.data == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
+19
-17
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag2 import Writer, WriterError
|
||||
from rosbags.rosbag2.connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
@@ -18,9 +19,9 @@ def test_writer(tmp_path: Path):
|
||||
"""Test Writer."""
|
||||
path = (tmp_path / 'rosbag2')
|
||||
with Writer(path) as bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'rosbag2.db3').exists()
|
||||
size = (path / 'rosbag2.db3').stat().st_size
|
||||
@@ -29,9 +30,9 @@ def test_writer(tmp_path: Path):
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'compress_none.db3').exists()
|
||||
assert size == (path / 'compress_none.db3').stat().st_size
|
||||
@@ -40,9 +41,9 @@ def test_writer(tmp_path: Path):
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert not (path / 'compress_file.db3').exists()
|
||||
assert (path / 'compress_file.db3.zstd').exists()
|
||||
@@ -51,9 +52,9 @@ def test_writer(tmp_path: Path):
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'compress_message.db3').exists()
|
||||
assert size > (path / 'compress_message.db3').stat().st_size
|
||||
@@ -76,7 +77,7 @@ def test_failure_cases(tmp_path: Path):
|
||||
|
||||
bag = Writer(tmp_path / 'topic')
|
||||
with pytest.raises(WriterError, match='was not opened'):
|
||||
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
|
||||
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
|
||||
|
||||
bag = Writer(tmp_path / 'write')
|
||||
with pytest.raises(WriterError, match='was not opened'):
|
||||
@@ -84,11 +85,12 @@ def test_failure_cases(tmp_path: Path):
|
||||
|
||||
bag = Writer(tmp_path / 'topic')
|
||||
bag.open()
|
||||
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
|
||||
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
|
||||
with pytest.raises(WriterError, match='only be added once'):
|
||||
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
|
||||
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
|
||||
|
||||
bag = Writer(tmp_path / 'notopic')
|
||||
bag.open()
|
||||
with pytest.raises(WriterError, match='unknown topic'):
|
||||
bag.write('/test', 42, b'\x00')
|
||||
connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '')
|
||||
with pytest.raises(WriterError, match='unknown connection'):
|
||||
bag.write(connection, 42, b'\x00')
|
||||
|
||||
Reference in New Issue
Block a user