Change to connection oriented reader API

This commit is contained in:
Marko Durkovic 2021-08-01 18:22:36 +02:00 committed by Florian Friesdorf
parent ebf357a0c6
commit f33e65b14a
13 changed files with 290 additions and 172 deletions

View File

@ -41,8 +41,8 @@ Read and deserialize rosbag2 messages:
# create reader instance and open for reading
with Reader('/home/ros/rosbag_2020_03_24') as reader:
for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']):
msg = deserialize_cdr(rawdata, msgtype)
for connection, timestamp, rawdata in reader.messages(['/imu_raw/Imu']):
msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id)

View File

@ -13,15 +13,16 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically
# create reader instance
with Reader('/home/ros/rosbag_2020_03_24.bag') as reader:
# topic and msgtype information is available on .topics dictionary
for topic, info in reader.topics.items():
print(topic, info)
# topic and msgtype information is available on .connections dictionary
for connection in reader.connections.values():
print(connection.topic, connection.msgtype)
# iterate over messages
for topic, msgtype, timestamp, rawdata in reader.messages():
if topic == '/imu_raw/Imu':
for connection, timestamp, rawdata in reader.messages():
if connection.topic == '/imu_raw/Imu':
print(timestamp)
# messages() accepts topic filters
for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']):
# messages() accepts connection filters
connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu']
for connection, timestamp, rawdata in reader.messages(connections=connections):
print(timestamp)

View File

@ -27,16 +27,18 @@ Instances of the :py:class:`Writer <rosbags.rosbag2.Writer>` class can create an
from rosbags.rosbag2 import Writer
from rosbags.serde import serialize_cdr
from rosbags.typesys.types import std_msgs__msg__String as String
# create writer instance and open for writing
with Writer('/home/ros/rosbag_2020_03_24') as writer:
# add new topic
topic = '/imu_raw/Imu'
msgtype = 'sensor_msgs/msg/Imu'
writer.add_topic(topic, msgtype, 'cdr')
# add new connection
topic = '/chatter'
msgtype = String.__msgtype__
connection = writer.add_connection(topic, msgtype, 'cdr', '')
# serialize and write message
writer.write(topic, timestamp, serialize_cdr(message, msgtype))
message = String('hello world')
writer.write(connection, timestamp, serialize_cdr(message, msgtype))
Reading rosbag2
---------------
@ -49,17 +51,18 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are used to r
# create reader instance and open for reading
with Reader('/home/ros/rosbag_2020_03_24') as reader:
# topic and msgtype information is available on .topics dict
for topic, msgtype in reader.topics.items():
print(topic, msgtype)
# topic and msgtype information is available on .connections dict
for connection in reader.connections.values():
print(connection.topic, connection.msgtype)
# iterate over messages
for topic, msgtype, timestamp, rawdata in reader.messages():
if topic == '/imu_raw/Imu':
msg = deserialize_cdr(rawdata, msgtype)
for connection, timestamp, rawdata in reader.messages():
if connection.topic == '/imu_raw/Imu':
msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id)
# messages() accepts topic filters
for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']):
msg = deserialize_cdr(rawdata, msgtype)
# messages() accepts connection filters
connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu']
for connection, timestamp, rawdata in reader.messages(connections=connections):
msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id)

View File

@ -4,10 +4,12 @@
from __future__ import annotations
from dataclasses import asdict
from typing import TYPE_CHECKING
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag2 import Writer, WriterError
from rosbags.rosbag2.connection import Connection as WConnection
from rosbags.serde import ros1_to_cdr
from rosbags.typesys import get_types_from_msg, register_types
@ -15,6 +17,8 @@ if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Dict, Optional
from rosbags.rosbag1.reader import Connection as RConnection
LATCH = """
- history: 3
depth: 0
@ -38,6 +42,26 @@ class ConverterError(Exception):
"""Converter Error."""
def convert_connection(rconn: RConnection) -> WConnection:
"""Convert rosbag1 connection to rosbag2 connection.
Args:
rconn: Rosbag1 connection.
Returns:
Rosbag2 connection.
"""
return WConnection(
-1,
0,
rconn.topic,
rconn.msgtype,
'cdr',
LATCH if rconn.latching else '',
)
def convert(src: Path, dst: Optional[Path]) -> None:
"""Convert Rosbag1 to Rosbag2.
@ -57,21 +81,19 @@ def convert(src: Path, dst: Optional[Path]) -> None:
try:
with Reader(src) as reader, Writer(dst) as writer:
typs: Dict[str, Any] = {}
for name, topic in reader.topics.items():
connection = next( # pragma: no branch
x for x in reader.connections.values() if x.topic == name
)
writer.add_topic(
name,
topic.msgtype,
offered_qos_profiles=LATCH if connection.latching else '',
)
typs.update(get_types_from_msg(topic.msgdef, topic.msgtype))
connmap: Dict[int, WConnection] = {}
for rconn in reader.connections.values():
candidate = convert_connection(rconn)
existing = next((x for x in writer.connections.values() if x == candidate), None)
wconn = existing if existing else writer.add_connection(**asdict(candidate))
connmap[rconn.cid] = wconn
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
register_types(typs)
for topic, msgtype, timestamp, data in reader.messages():
data = ros1_to_cdr(data, msgtype)
writer.write(topic, timestamp, data)
for rconn, timestamp, data in reader.messages():
data = ros1_to_cdr(data, rconn.msgtype)
writer.write(connmap[rconn.cid], timestamp, data)
except ReaderError as err:
raise ConverterError(f'Reading source bag: {err}') from err
except WriterError as err:

View File

@ -249,7 +249,7 @@ class Header(dict):
raise ReaderError(f'Could not read time field {name!r}.') from err
@classmethod
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> 'Header':
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
"""Read header from file handle.
Args:
@ -588,7 +588,7 @@ class Reader:
topics: Optional[Iterable[str]] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[Tuple[str, str, int, bytes], None, None]:
) -> Generator[Tuple[Connection, int, bytes], None, None]:
"""Read messages from bag.
Args:
@ -598,7 +598,7 @@ class Reader:
stop: Yield only messages before this timestamp (ns).
Yields:
Tuples of topic name, type, timestamp (ns), and rawdata.
Tuples of connection, timestamp (ns), and rawdata.
Raises:
ReaderError: Bag not open or data corrupt.
@ -635,13 +635,10 @@ class Reader:
if have != RecordType.MSGDATA:
raise ReaderError('Expected to find message data.')
connection = self.connections[header.get_uint32('conn')]
time = header.get_time('time')
data = read_bytes(chunk, read_uint32(chunk))
assert entry.time == time
yield connection.topic, connection.msgtype, time, data
connection = self.connections[header.get_uint32('conn')]
assert entry.time == header.get_time('time')
yield connection, entry.time, data
def __enter__(self) -> Reader:
"""Open rosbag1 when entering contextmanager."""

View File

@ -0,0 +1,18 @@
# Copyright 2020-2021 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 connection."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class Connection:
"""Connection metadata."""
id: int = field(compare=False) # pylint: disable=invalid-name
count: int = field(compare=False)
topic: str
msgtype: str
serialization_format: str
offered_qos_profiles: str

View File

@ -13,9 +13,11 @@ from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML, YAMLError
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
class ReaderError(Exception):
@ -96,11 +98,21 @@ class Reader:
if missing:
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
topics = [x['topic_metadata'] for x in self.metadata['topics_with_message_count']]
noncdr = {y for x in topics if (y := x['serialization_format']) != 'cdr'}
self.connections = {
idx + 1: Connection(
id=idx + 1,
count=x['message_count'],
topic=x['topic_metadata']['name'],
msgtype=x['topic_metadata']['type'],
serialization_format=x['topic_metadata']['serialization_format'],
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
}
noncdr = {
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr'
}
if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
self.topics = {x['name']: x['type'] for x in topics}
if self.compression_mode and (cfmt := self.compression_format) != 'zstd':
raise ReaderError(f'Compression format {cfmt!r} is not supported.')
@ -149,22 +161,31 @@ class Reader:
mode = self.metadata.get('compression_mode', '').lower()
return mode if mode != 'none' else None
@property
def topics(self) -> Dict[str, Connection]:
"""Topic information.
For the moment this a dictionary mapping topic names to connections.
"""
return {x.topic: x for x in self.connections.values()}
def messages( # pylint: disable=too-many-locals
self,
topics: Iterable[str] = (),
connections: Iterable[Connection] = (),
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[Tuple[str, str, int, bytes], None, None]:
) -> Generator[Tuple[Connection, int, bytes], None, None]:
"""Read messages from bag.
Args:
topics: Iterable with topic names to filter for. An empty iterable
yields all messages.
connections: Iterable with connections to filter for. An empty
iterable disables filtering on connections.
start: Yield only messages at or after this timestamp (ns).
stop: Yield only messages before this timestamp (ns).
Yields:
Tuples of topic name, type, timestamp (ns), and rawdata.
Tuples of connection, timestamp (ns), and rawdata.
Raises:
ReaderError: Bag not open.
@ -173,7 +194,32 @@ class Reader:
if not self.bio:
raise ReaderError('Rosbag is not open.')
topics = tuple(topics)
query = [
'SELECT topics.id,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: List[Any] = []
clause = 'WHERE'
if connections:
topics = {x.topic for x in connections}
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
args += topics
clause = 'AND'
if start is not None:
query.append(f'{clause} messages.timestamp >= ?')
args.append(start)
clause = 'AND'
if stop is not None:
query.append(f'{clause} messages.timestamp < ?')
args.append(stop)
clause = 'AND'
query.append('ORDER BY timestamp')
querystr = ' '.join(query)
for filepath in self.paths:
with decompress(filepath, self.compression_mode == 'file') as path:
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
@ -186,34 +232,16 @@ class Reader:
if cur.fetchone()[0] != 2:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
query = [
'SELECT topics.name,topics.type,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: List[Any] = []
if topics:
query.append(f'WHERE topics.name IN ({",".join("?" for _ in topics)})')
args += topics
if start is not None:
query.append(f'{"AND" if args else "WHERE"} messages.timestamp >= ?')
args.append(start)
if stop is not None:
query.append(f'{"AND" if args else "WHERE"} messages.timestamp < ?')
args.append(stop)
query.append('ORDER BY timestamp')
cur.execute(' '.join(query), args)
cur.execute(querystr, args)
if self.compression_mode == 'message':
decomp = zstandard.ZstdDecompressor().decompress
for row in cur:
topic, msgtype, timestamp, data = row
yield topic, msgtype, timestamp, decomp(data)
cid, timestamp, data = row
yield self.connections[cid], timestamp, decomp(data)
else:
yield from cur
for cid, timestamp, data in cur:
yield self.connections[cid], timestamp, data
def __enter__(self) -> Reader:
"""Open rosbag2 when entering contextmanager."""

View File

@ -1,6 +1,6 @@
# Copyright 2020-2021 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 reader."""
"""Rosbag2 writer."""
from __future__ import annotations
@ -12,6 +12,8 @@ from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Dict, Literal, Optional, Type, Union
@ -77,10 +79,9 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_mode = ''
self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None
self.topics: Dict[str, Any] = {}
self.connections: Dict[int, Connection] = {}
self.conn = None
self.cursor: Optional[sqlite3.Cursor] = None
self.topics = {}
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
"""Enable compression on bag.
@ -118,22 +119,27 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.conn.executescript(self.SQLITE_SCHEMA)
self.cursor = self.conn.cursor()
def add_topic(
def add_connection(
self,
name: str,
typ: str,
topic: str,
msgtype: str,
serialization_format: str = 'cdr',
offered_qos_profiles: str = '',
):
"""Add a topic.
**_kw: Any,
) -> Connection:
"""Add a connection.
This function can only be called after opening a bag.
Args:
name: Topic name.
typ: Message type.
topic: Topic name.
msgtype: Message type.
serialization_format: Serialization format.
offered_qos_profiles: QOS Profile.
_kw: Ignored to allow consuming dicts from connection objects.
Returns:
Connection object.
Raises:
WriterError: Bag not open or topic previously registered.
@ -141,17 +147,28 @@ class Writer: # pylint: disable=too-many-instance-attributes
"""
if not self.cursor:
raise WriterError('Bag was not opened.')
if name in self.topics:
raise WriterError(f'Topics can only be added once: {name!r}.')
meta = (len(self.topics) + 1, name, typ, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
self.topics[name] = [*meta, 0]
def write(self, topic: str, timestamp: int, data: bytes):
connection = Connection(
id=len(self.connections.values()) + 1,
count=0,
topic=topic,
msgtype=msgtype,
serialization_format=serialization_format,
offered_qos_profiles=offered_qos_profiles,
)
if connection in self.connections.values():
raise WriterError(f'Connection can only be added once: {connection!r}.')
self.connections[connection.id] = connection
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
"""Write message to rosbag2.
Args:
topic: Topic message belongs to.
connection: Connection to write message to.
timestamp: Message timestamp (ns).
data: Serialized message data.
@ -161,19 +178,18 @@ class Writer: # pylint: disable=too-many-instance-attributes
"""
if not self.cursor:
raise WriterError('Bag was not opened.')
if topic not in self.topics:
raise WriterError(f'Tried to write to unknown topic {topic!r}.')
if connection not in self.connections.values():
raise WriterError(f'Tried to write to unknown connection {connection!r}.')
if self.compression_mode == 'message':
assert self.compressor
data = self.compressor.compress(data)
tmeta = self.topics[topic]
self.cursor.execute(
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
(tmeta[0], timestamp, data),
(connection.id, timestamp, data),
)
tmeta[-1] += 1
connection.count += 1
def close(self):
"""Close rosbag2 after writing.
@ -214,13 +230,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
'topics_with_message_count': [
{
'topic_metadata': {
'name': x[1],
'type': x[2],
'serialization_format': x[3],
'offered_qos_profiles': x[4],
'name': x.topic,
'type': x.msgtype,
'serialization_format': x.serialization_format,
'offered_qos_profiles': x.offered_qos_profiles,
},
'message_count': x[5],
} for x in self.topics.values()
'message_count': x.count,
} for x in self.connections.values()
],
'compression_format': self.compression_format,
'compression_mode': self.compression_mode,

View File

@ -76,17 +76,29 @@ def test_convert(tmp_path: Path):
patch('rosbags.convert.converter.register_types') as register_types, \
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
connections = [
Mock(topic='/topic', msgtype='typ', latching=False),
Mock(topic='/topic', msgtype='typ', latching=True),
]
wconnections = [
Mock(topic='/topic', msgtype='typ'),
Mock(topic='/topic', msgtype='typ'),
]
reader.return_value.__enter__.return_value.connections = {
0: Mock(topic='/topic', latching=False),
1: Mock(topic='/latched', latching=True),
}
reader.return_value.__enter__.return_value.topics = {
'/topic': Mock(msgtype='typ', msgdef='def'),
'/latched': Mock(msgtype='typ', msgdef='def'),
1: connections[0],
2: connections[1],
}
reader.return_value.__enter__.return_value.messages.return_value = [
('/topic', 'typ', 42, b'\x42'),
('/latched', 'typ', 43, b'\x43'),
(connections[0], 42, b'\x42'),
(connections[1], 43, b'\x43'),
]
writer.return_value.__enter__.return_value.add_connection.side_effect = [
wconnections[0],
wconnections[1],
]
ros1_to_cdr.return_value = b'666'
@ -97,14 +109,29 @@ def test_convert(tmp_path: Path):
reader.return_value.__enter__.return_value.messages.assert_called_with()
writer.assert_called_with(Path('foo'))
writer.return_value.__enter__.return_value.add_topic.assert_has_calls(
writer.return_value.__enter__.return_value.add_connection.assert_has_calls(
[
call('/topic', 'typ', offered_qos_profiles=''),
call('/latched', 'typ', offered_qos_profiles=LATCH),
call(
id=-1,
count=0,
topic='/topic',
msgtype='typ',
serialization_format='cdr',
offered_qos_profiles='',
),
call(
id=-1,
count=0,
topic='/topic',
msgtype='typ',
serialization_format='cdr',
offered_qos_profiles=LATCH,
),
],
)
writer.return_value.__enter__.return_value.write.assert_has_calls(
[call('/topic', 42, b'666'), call('/latched', 43, b'666')],
[call(wconnections[0], 42, b'666'),
call(wconnections[1], 43, b'666')],
)
register_types.assert_called_with({'typ': 'def'})

View File

@ -126,25 +126,26 @@ def test_reader(bag: Path):
assert reader.message_count == 4
if reader.compression_mode:
assert reader.compression_format == 'zstd'
assert [*reader.connections.keys()] == [1, 2, 3]
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
gen = reader.messages()
topic, msgtype, timestamp, rawdata = next(gen)
assert topic == '/poly'
assert msgtype == 'geometry_msgs/msg/Polygon'
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/poly'
assert connection.msgtype == 'geometry_msgs/msg/Polygon'
assert timestamp == 666
assert rawdata == MSG_POLY[0]
for idx in range(2):
topic, msgtype, timestamp, rawdata = next(gen)
assert topic == '/magn'
assert msgtype == 'sensor_msgs/msg/MagneticField'
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/magn'
assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
assert timestamp == 708
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
topic, msgtype, timestamp, rawdata = next(gen)
assert topic == '/joint'
assert msgtype == 'trajectory_msgs/msg/JointTrajectory'
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/joint'
assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
with pytest.raises(StopIteration):
next(gen)
@ -153,32 +154,32 @@ def test_reader(bag: Path):
def test_message_filters(bag: Path):
"""Test reader filters messages."""
with Reader(bag) as reader:
gen = reader.messages(['/magn'])
topic, _, _, _ = next(gen)
assert topic == '/magn'
topic, _, _, _ = next(gen)
assert topic == '/magn'
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
gen = reader.messages(connections=magn_connections)
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(start=667)
topic, _, _, _ = next(gen)
assert topic == '/magn'
topic, _, _, _ = next(gen)
assert topic == '/magn'
topic, _, _, _ = next(gen)
assert topic == '/joint'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/joint'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(stop=667)
topic, _, _, _ = next(gen)
assert topic == '/poly'
connection, _, _ = next(gen)
assert connection.topic == '/poly'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(['/magn'], stop=667)
gen = reader.messages(connections=magn_connections, stop=667)
with pytest.raises(StopIteration):
next(gen)

View File

@ -141,7 +141,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
header['index_pos'] = pack('<Q', pos)
header = ser(header)
header += b'\x00' * (4096 - len(header))
header += b'\x20' * (4096 - len(header))
bag.write_bytes(b''.join([
magic,
@ -227,8 +227,10 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert reader.topics['/topic0'].msgcount == 2
msgs = list(reader.messages())
assert len(msgs) == 2
assert msgs[0][3] == b'MSGCONTENT5'
assert msgs[1][3] == b'MSGCONTENT10'
assert msgs[0][0].topic == '/topic0'
assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][0].topic == '/topic0'
assert msgs[1][2] == b'MSGCONTENT10'
# sorts by time on different topic
write_bag(
@ -249,20 +251,20 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert reader.topics['/topic2'].msgcount == 1
msgs = list(reader.messages())
assert len(msgs) == 2
assert msgs[0][3] == b'MSGCONTENT5'
assert msgs[1][3] == b'MSGCONTENT10'
assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][2] == b'MSGCONTENT10'
msgs = list(reader.messages(['/topic0']))
assert len(msgs) == 1
assert msgs[0][3] == b'MSGCONTENT10'
assert msgs[0][2] == b'MSGCONTENT10'
msgs = list(reader.messages(start=7 * 10**9))
assert len(msgs) == 1
assert msgs[0][3] == b'MSGCONTENT10'
assert msgs[0][2] == b'MSGCONTENT10'
msgs = list(reader.messages(stop=7 * 10**9))
assert len(msgs) == 1
assert msgs[0][3] == b'MSGCONTENT5'
assert msgs[0][2] == b'MSGCONTENT5'
def test_user_errors(tmp_path):

View File

@ -29,14 +29,15 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
wbag.set_compression(mode, wbag.CompressionFormat.ZSTD)
with wbag:
msgtype = 'std_msgs/msg/Float64'
wbag.add_topic('/test', msgtype)
wbag.write('/test', 42, serialize_cdr(Foo, msgtype))
wconnection = wbag.add_connection('/test', msgtype)
wbag.write(wconnection, 42, serialize_cdr(Foo, msgtype))
rbag = Reader(path)
with rbag:
gen = rbag.messages()
_, msgtype, _, raw = next(gen)
msg = deserialize_cdr(raw, msgtype)
rconnection, _, raw = next(gen)
assert rconnection == wconnection
msg = deserialize_cdr(raw, rconnection.msgtype)
assert msg.data == Foo.data
with pytest.raises(StopIteration):
next(gen)

View File

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
import pytest
from rosbags.rosbag2 import Writer, WriterError
from rosbags.rosbag2.connection import Connection
if TYPE_CHECKING:
from pathlib import Path
@ -18,9 +19,9 @@ def test_writer(tmp_path: Path):
"""Test Writer."""
path = (tmp_path / 'rosbag2')
with Writer(path) as bag:
bag.add_topic('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096)
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert (path / 'rosbag2.db3').exists()
size = (path / 'rosbag2.db3').stat().st_size
@ -29,9 +30,9 @@ def test_writer(tmp_path: Path):
bag = Writer(path)
bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD)
with bag:
bag.add_topic('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096)
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert (path / 'compress_none.db3').exists()
assert size == (path / 'compress_none.db3').stat().st_size
@ -40,9 +41,9 @@ def test_writer(tmp_path: Path):
bag = Writer(path)
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
with bag:
bag.add_topic('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096)
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert not (path / 'compress_file.db3').exists()
assert (path / 'compress_file.db3.zstd').exists()
@ -51,9 +52,9 @@ def test_writer(tmp_path: Path):
bag = Writer(path)
bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD)
with bag:
bag.add_topic('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096)
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert (path / 'compress_message.db3').exists()
assert size > (path / 'compress_message.db3').stat().st_size
@ -76,7 +77,7 @@ def test_failure_cases(tmp_path: Path):
bag = Writer(tmp_path / 'topic')
with pytest.raises(WriterError, match='was not opened'):
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
bag = Writer(tmp_path / 'write')
with pytest.raises(WriterError, match='was not opened'):
@ -84,11 +85,12 @@ def test_failure_cases(tmp_path: Path):
bag = Writer(tmp_path / 'topic')
bag.open()
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
with pytest.raises(WriterError, match='only be added once'):
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
bag = Writer(tmp_path / 'notopic')
bag.open()
with pytest.raises(WriterError, match='unknown topic'):
bag.write('/test', 42, b'\x00')
connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '')
with pytest.raises(WriterError, match='unknown connection'):
bag.write(connection, 42, b'\x00')