Change to connection oriented reader API

This commit is contained in:
Marko Durkovic 2021-08-01 18:22:36 +02:00 committed by Florian Friesdorf
parent ebf357a0c6
commit f33e65b14a
13 changed files with 290 additions and 172 deletions

View File

@ -41,8 +41,8 @@ Read and deserialize rosbag2 messages:
# create reader instance and open for reading # create reader instance and open for reading
with Reader('/home/ros/rosbag_2020_03_24') as reader: with Reader('/home/ros/rosbag_2020_03_24') as reader:
for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): for connection, timestamp, rawdata in reader.messages(['/imu_raw/Imu']):
msg = deserialize_cdr(rawdata, msgtype) msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id) print(msg.header.frame_id)

View File

@ -13,15 +13,16 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically
# create reader instance # create reader instance
with Reader('/home/ros/rosbag_2020_03_24.bag') as reader: with Reader('/home/ros/rosbag_2020_03_24.bag') as reader:
# topic and msgtype information is available on .topics dictionary # topic and msgtype information is available on .connections dictionary
for topic, info in reader.topics.items(): for connection in reader.connections.values():
print(topic, info) print(connection.topic, connection.msgtype)
# iterate over messages # iterate over messages
for topic, msgtype, timestamp, rawdata in reader.messages(): for connection, timestamp, rawdata in reader.messages():
if topic == '/imu_raw/Imu': if connection.topic == '/imu_raw/Imu':
print(timestamp) print(timestamp)
# messages() accepts topic filters # messages() accepts connection filters
for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu']
for connection, timestamp, rawdata in reader.messages(connections=connections):
print(timestamp) print(timestamp)

View File

@ -27,16 +27,18 @@ Instances of the :py:class:`Writer <rosbags.rosbag2.Writer>` class can create an
from rosbags.rosbag2 import Writer from rosbags.rosbag2 import Writer
from rosbags.serde import serialize_cdr from rosbags.serde import serialize_cdr
from rosbags.typesys.types import std_msgs__msg__String as String
# create writer instance and open for writing # create writer instance and open for writing
with Writer('/home/ros/rosbag_2020_03_24') as writer: with Writer('/home/ros/rosbag_2020_03_24') as writer:
# add new topic # add new connection
topic = '/imu_raw/Imu' topic = '/chatter'
msgtype = 'sensor_msgs/msg/Imu' msgtype = String.__msgtype__
writer.add_topic(topic, msgtype, 'cdr') connection = writer.add_connection(topic, msgtype, 'cdr', '')
# serialize and write message # serialize and write message
writer.write(topic, timestamp, serialize_cdr(message, msgtype)) message = String('hello world')
writer.write(connection, timestamp, serialize_cdr(message, msgtype))
Reading rosbag2 Reading rosbag2
--------------- ---------------
@ -49,17 +51,18 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are used to r
# create reader instance and open for reading # create reader instance and open for reading
with Reader('/home/ros/rosbag_2020_03_24') as reader: with Reader('/home/ros/rosbag_2020_03_24') as reader:
# topic and msgtype information is available on .topics dict # topic and msgtype information is available on .connections dict
for topic, msgtype in reader.topics.items(): for connection in reader.connections.values():
print(topic, msgtype) print(connection.topic, connection.msgtype)
# iterate over messages # iterate over messages
for topic, msgtype, timestamp, rawdata in reader.messages(): for connection, timestamp, rawdata in reader.messages():
if topic == '/imu_raw/Imu': if connection.topic == '/imu_raw/Imu':
msg = deserialize_cdr(rawdata, msgtype) msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id) print(msg.header.frame_id)
# messages() accepts topic filters # messages() accepts connection filters
for topic, msgtype, timestamp, rawdata in reader.messages(['/imu_raw/Imu']): connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu']
msg = deserialize_cdr(rawdata, msgtype) for connection, timestamp, rawdata in reader.messages(connections=connections):
msg = deserialize_cdr(rawdata, connection.msgtype)
print(msg.header.frame_id) print(msg.header.frame_id)

View File

@ -4,10 +4,12 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import asdict
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from rosbags.rosbag1 import Reader, ReaderError from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag2 import Writer, WriterError from rosbags.rosbag2 import Writer, WriterError
from rosbags.rosbag2.connection import Connection as WConnection
from rosbags.serde import ros1_to_cdr from rosbags.serde import ros1_to_cdr
from rosbags.typesys import get_types_from_msg, register_types from rosbags.typesys import get_types_from_msg, register_types
@ -15,6 +17,8 @@ if TYPE_CHECKING:
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from rosbags.rosbag1.reader import Connection as RConnection
LATCH = """ LATCH = """
- history: 3 - history: 3
depth: 0 depth: 0
@ -38,6 +42,26 @@ class ConverterError(Exception):
"""Converter Error.""" """Converter Error."""
def convert_connection(rconn: RConnection) -> WConnection:
"""Convert rosbag1 connection to rosbag2 connection.
Args:
rconn: Rosbag1 connection.
Returns:
Rosbag2 connection.
"""
return WConnection(
-1,
0,
rconn.topic,
rconn.msgtype,
'cdr',
LATCH if rconn.latching else '',
)
def convert(src: Path, dst: Optional[Path]) -> None: def convert(src: Path, dst: Optional[Path]) -> None:
"""Convert Rosbag1 to Rosbag2. """Convert Rosbag1 to Rosbag2.
@ -57,21 +81,19 @@ def convert(src: Path, dst: Optional[Path]) -> None:
try: try:
with Reader(src) as reader, Writer(dst) as writer: with Reader(src) as reader, Writer(dst) as writer:
typs: Dict[str, Any] = {} typs: Dict[str, Any] = {}
for name, topic in reader.topics.items(): connmap: Dict[int, WConnection] = {}
connection = next( # pragma: no branch
x for x in reader.connections.values() if x.topic == name for rconn in reader.connections.values():
) candidate = convert_connection(rconn)
writer.add_topic( existing = next((x for x in writer.connections.values() if x == candidate), None)
name, wconn = existing if existing else writer.add_connection(**asdict(candidate))
topic.msgtype, connmap[rconn.cid] = wconn
offered_qos_profiles=LATCH if connection.latching else '', typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
)
typs.update(get_types_from_msg(topic.msgdef, topic.msgtype))
register_types(typs) register_types(typs)
for topic, msgtype, timestamp, data in reader.messages(): for rconn, timestamp, data in reader.messages():
data = ros1_to_cdr(data, msgtype) data = ros1_to_cdr(data, rconn.msgtype)
writer.write(topic, timestamp, data) writer.write(connmap[rconn.cid], timestamp, data)
except ReaderError as err: except ReaderError as err:
raise ConverterError(f'Reading source bag: {err}') from err raise ConverterError(f'Reading source bag: {err}') from err
except WriterError as err: except WriterError as err:

View File

@ -249,7 +249,7 @@ class Header(dict):
raise ReaderError(f'Could not read time field {name!r}.') from err raise ReaderError(f'Could not read time field {name!r}.') from err
@classmethod @classmethod
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> 'Header': def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
"""Read header from file handle. """Read header from file handle.
Args: Args:
@ -588,7 +588,7 @@ class Reader:
topics: Optional[Iterable[str]] = None, topics: Optional[Iterable[str]] = None,
start: Optional[int] = None, start: Optional[int] = None,
stop: Optional[int] = None, stop: Optional[int] = None,
) -> Generator[Tuple[str, str, int, bytes], None, None]: ) -> Generator[Tuple[Connection, int, bytes], None, None]:
"""Read messages from bag. """Read messages from bag.
Args: Args:
@ -598,7 +598,7 @@ class Reader:
stop: Yield only messages before this timestamp (ns). stop: Yield only messages before this timestamp (ns).
Yields: Yields:
Tuples of topic name, type, timestamp (ns), and rawdata. Tuples of connection, timestamp (ns), and rawdata.
Raises: Raises:
ReaderError: Bag not open or data corrupt. ReaderError: Bag not open or data corrupt.
@ -635,13 +635,10 @@ class Reader:
if have != RecordType.MSGDATA: if have != RecordType.MSGDATA:
raise ReaderError('Expected to find message data.') raise ReaderError('Expected to find message data.')
connection = self.connections[header.get_uint32('conn')]
time = header.get_time('time')
data = read_bytes(chunk, read_uint32(chunk)) data = read_bytes(chunk, read_uint32(chunk))
connection = self.connections[header.get_uint32('conn')]
assert entry.time == time assert entry.time == header.get_time('time')
yield connection.topic, connection.msgtype, time, data yield connection, entry.time, data
def __enter__(self) -> Reader: def __enter__(self) -> Reader:
"""Open rosbag1 when entering contextmanager.""" """Open rosbag1 when entering contextmanager."""

View File

@ -0,0 +1,18 @@
# Copyright 2020-2021 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 connection."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class Connection:
"""Connection metadata."""
id: int = field(compare=False) # pylint: disable=invalid-name
count: int = field(compare=False)
topic: str
msgtype: str
serialization_format: str
offered_qos_profiles: str

View File

@ -13,9 +13,11 @@ from typing import TYPE_CHECKING
import zstandard import zstandard
from ruamel.yaml import YAML, YAMLError from ruamel.yaml import YAML, YAMLError
from .connection import Connection
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
from typing import Any, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
class ReaderError(Exception): class ReaderError(Exception):
@ -96,11 +98,21 @@ class Reader:
if missing: if missing:
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}') raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
topics = [x['topic_metadata'] for x in self.metadata['topics_with_message_count']] self.connections = {
noncdr = {y for x in topics if (y := x['serialization_format']) != 'cdr'} idx + 1: Connection(
id=idx + 1,
count=x['message_count'],
topic=x['topic_metadata']['name'],
msgtype=x['topic_metadata']['type'],
serialization_format=x['topic_metadata']['serialization_format'],
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
}
noncdr = {
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr'
}
if noncdr: if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.') raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
self.topics = {x['name']: x['type'] for x in topics}
if self.compression_mode and (cfmt := self.compression_format) != 'zstd': if self.compression_mode and (cfmt := self.compression_format) != 'zstd':
raise ReaderError(f'Compression format {cfmt!r} is not supported.') raise ReaderError(f'Compression format {cfmt!r} is not supported.')
@ -149,22 +161,31 @@ class Reader:
mode = self.metadata.get('compression_mode', '').lower() mode = self.metadata.get('compression_mode', '').lower()
return mode if mode != 'none' else None return mode if mode != 'none' else None
@property
def topics(self) -> Dict[str, Connection]:
"""Topic information.
For the moment this a dictionary mapping topic names to connections.
"""
return {x.topic: x for x in self.connections.values()}
def messages( # pylint: disable=too-many-locals def messages( # pylint: disable=too-many-locals
self, self,
topics: Iterable[str] = (), connections: Iterable[Connection] = (),
start: Optional[int] = None, start: Optional[int] = None,
stop: Optional[int] = None, stop: Optional[int] = None,
) -> Generator[Tuple[str, str, int, bytes], None, None]: ) -> Generator[Tuple[Connection, int, bytes], None, None]:
"""Read messages from bag. """Read messages from bag.
Args: Args:
topics: Iterable with topic names to filter for. An empty iterable connections: Iterable with connections to filter for. An empty
yields all messages. iterable disables filtering on connections.
start: Yield only messages at or after this timestamp (ns). start: Yield only messages at or after this timestamp (ns).
stop: Yield only messages before this timestamp (ns). stop: Yield only messages before this timestamp (ns).
Yields: Yields:
Tuples of topic name, type, timestamp (ns), and rawdata. Tuples of connection, timestamp (ns), and rawdata.
Raises: Raises:
ReaderError: Bag not open. ReaderError: Bag not open.
@ -173,7 +194,32 @@ class Reader:
if not self.bio: if not self.bio:
raise ReaderError('Rosbag is not open.') raise ReaderError('Rosbag is not open.')
topics = tuple(topics) query = [
'SELECT topics.id,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: List[Any] = []
clause = 'WHERE'
if connections:
topics = {x.topic for x in connections}
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
args += topics
clause = 'AND'
if start is not None:
query.append(f'{clause} messages.timestamp >= ?')
args.append(start)
clause = 'AND'
if stop is not None:
query.append(f'{clause} messages.timestamp < ?')
args.append(stop)
clause = 'AND'
query.append('ORDER BY timestamp')
querystr = ' '.join(query)
for filepath in self.paths: for filepath in self.paths:
with decompress(filepath, self.compression_mode == 'file') as path: with decompress(filepath, self.compression_mode == 'file') as path:
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True) conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
@ -186,34 +232,16 @@ class Reader:
if cur.fetchone()[0] != 2: if cur.fetchone()[0] != 2:
raise ReaderError(f'Cannot open database {path} or database missing tables.') raise ReaderError(f'Cannot open database {path} or database missing tables.')
query = [ cur.execute(querystr, args)
'SELECT topics.name,topics.type,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: List[Any] = []
if topics:
query.append(f'WHERE topics.name IN ({",".join("?" for _ in topics)})')
args += topics
if start is not None:
query.append(f'{"AND" if args else "WHERE"} messages.timestamp >= ?')
args.append(start)
if stop is not None:
query.append(f'{"AND" if args else "WHERE"} messages.timestamp < ?')
args.append(stop)
query.append('ORDER BY timestamp')
cur.execute(' '.join(query), args)
if self.compression_mode == 'message': if self.compression_mode == 'message':
decomp = zstandard.ZstdDecompressor().decompress decomp = zstandard.ZstdDecompressor().decompress
for row in cur: for row in cur:
topic, msgtype, timestamp, data = row cid, timestamp, data = row
yield topic, msgtype, timestamp, decomp(data) yield self.connections[cid], timestamp, decomp(data)
else: else:
yield from cur for cid, timestamp, data in cur:
yield self.connections[cid], timestamp, data
def __enter__(self) -> Reader: def __enter__(self) -> Reader:
"""Open rosbag2 when entering contextmanager.""" """Open rosbag2 when entering contextmanager."""

View File

@ -1,6 +1,6 @@
# Copyright 2020-2021 Ternaris. # Copyright 2020-2021 Ternaris.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Rosbag2 reader.""" """Rosbag2 writer."""
from __future__ import annotations from __future__ import annotations
@ -12,6 +12,8 @@ from typing import TYPE_CHECKING
import zstandard import zstandard
from ruamel.yaml import YAML from ruamel.yaml import YAML
from .connection import Connection
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
from typing import Any, Dict, Literal, Optional, Type, Union from typing import Any, Dict, Literal, Optional, Type, Union
@ -77,10 +79,9 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_mode = '' self.compression_mode = ''
self.compression_format = '' self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None self.compressor: Optional[zstandard.ZstdCompressor] = None
self.topics: Dict[str, Any] = {} self.connections: Dict[int, Connection] = {}
self.conn = None self.conn = None
self.cursor: Optional[sqlite3.Cursor] = None self.cursor: Optional[sqlite3.Cursor] = None
self.topics = {}
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat): def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
"""Enable compression on bag. """Enable compression on bag.
@ -118,22 +119,27 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.conn.executescript(self.SQLITE_SCHEMA) self.conn.executescript(self.SQLITE_SCHEMA)
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
def add_topic( def add_connection(
self, self,
name: str, topic: str,
typ: str, msgtype: str,
serialization_format: str = 'cdr', serialization_format: str = 'cdr',
offered_qos_profiles: str = '', offered_qos_profiles: str = '',
): **_kw: Any,
"""Add a topic. ) -> Connection:
"""Add a connection.
This function can only be called after opening a bag. This function can only be called after opening a bag.
Args: Args:
name: Topic name. topic: Topic name.
typ: Message type. msgtype: Message type.
serialization_format: Serialization format. serialization_format: Serialization format.
offered_qos_profiles: QOS Profile. offered_qos_profiles: QOS Profile.
_kw: Ignored to allow consuming dicts from connection objects.
Returns:
Connection object.
Raises: Raises:
WriterError: Bag not open or topic previously registered. WriterError: Bag not open or topic previously registered.
@ -141,17 +147,28 @@ class Writer: # pylint: disable=too-many-instance-attributes
""" """
if not self.cursor: if not self.cursor:
raise WriterError('Bag was not opened.') raise WriterError('Bag was not opened.')
if name in self.topics:
raise WriterError(f'Topics can only be added once: {name!r}.')
meta = (len(self.topics) + 1, name, typ, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
self.topics[name] = [*meta, 0]
def write(self, topic: str, timestamp: int, data: bytes): connection = Connection(
id=len(self.connections.values()) + 1,
count=0,
topic=topic,
msgtype=msgtype,
serialization_format=serialization_format,
offered_qos_profiles=offered_qos_profiles,
)
if connection in self.connections.values():
raise WriterError(f'Connection can only be added once: {connection!r}.')
self.connections[connection.id] = connection
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
"""Write message to rosbag2. """Write message to rosbag2.
Args: Args:
topic: Topic message belongs to. connection: Connection to write message to.
timestamp: Message timestamp (ns). timestamp: Message timestamp (ns).
data: Serialized message data. data: Serialized message data.
@ -161,19 +178,18 @@ class Writer: # pylint: disable=too-many-instance-attributes
""" """
if not self.cursor: if not self.cursor:
raise WriterError('Bag was not opened.') raise WriterError('Bag was not opened.')
if topic not in self.topics: if connection not in self.connections.values():
raise WriterError(f'Tried to write to unknown topic {topic!r}.') raise WriterError(f'Tried to write to unknown connection {connection!r}.')
if self.compression_mode == 'message': if self.compression_mode == 'message':
assert self.compressor assert self.compressor
data = self.compressor.compress(data) data = self.compressor.compress(data)
tmeta = self.topics[topic]
self.cursor.execute( self.cursor.execute(
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)', 'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
(tmeta[0], timestamp, data), (connection.id, timestamp, data),
) )
tmeta[-1] += 1 connection.count += 1
def close(self): def close(self):
"""Close rosbag2 after writing. """Close rosbag2 after writing.
@ -214,13 +230,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
'topics_with_message_count': [ 'topics_with_message_count': [
{ {
'topic_metadata': { 'topic_metadata': {
'name': x[1], 'name': x.topic,
'type': x[2], 'type': x.msgtype,
'serialization_format': x[3], 'serialization_format': x.serialization_format,
'offered_qos_profiles': x[4], 'offered_qos_profiles': x.offered_qos_profiles,
}, },
'message_count': x[5], 'message_count': x.count,
} for x in self.topics.values() } for x in self.connections.values()
], ],
'compression_format': self.compression_format, 'compression_format': self.compression_format,
'compression_mode': self.compression_mode, 'compression_mode': self.compression_mode,

View File

@ -76,17 +76,29 @@ def test_convert(tmp_path: Path):
patch('rosbags.convert.converter.register_types') as register_types, \ patch('rosbags.convert.converter.register_types') as register_types, \
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr: patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
connections = [
Mock(topic='/topic', msgtype='typ', latching=False),
Mock(topic='/topic', msgtype='typ', latching=True),
]
wconnections = [
Mock(topic='/topic', msgtype='typ'),
Mock(topic='/topic', msgtype='typ'),
]
reader.return_value.__enter__.return_value.connections = { reader.return_value.__enter__.return_value.connections = {
0: Mock(topic='/topic', latching=False), 1: connections[0],
1: Mock(topic='/latched', latching=True), 2: connections[1],
}
reader.return_value.__enter__.return_value.topics = {
'/topic': Mock(msgtype='typ', msgdef='def'),
'/latched': Mock(msgtype='typ', msgdef='def'),
} }
reader.return_value.__enter__.return_value.messages.return_value = [ reader.return_value.__enter__.return_value.messages.return_value = [
('/topic', 'typ', 42, b'\x42'), (connections[0], 42, b'\x42'),
('/latched', 'typ', 43, b'\x43'), (connections[1], 43, b'\x43'),
]
writer.return_value.__enter__.return_value.add_connection.side_effect = [
wconnections[0],
wconnections[1],
] ]
ros1_to_cdr.return_value = b'666' ros1_to_cdr.return_value = b'666'
@ -97,14 +109,29 @@ def test_convert(tmp_path: Path):
reader.return_value.__enter__.return_value.messages.assert_called_with() reader.return_value.__enter__.return_value.messages.assert_called_with()
writer.assert_called_with(Path('foo')) writer.assert_called_with(Path('foo'))
writer.return_value.__enter__.return_value.add_topic.assert_has_calls( writer.return_value.__enter__.return_value.add_connection.assert_has_calls(
[ [
call('/topic', 'typ', offered_qos_profiles=''), call(
call('/latched', 'typ', offered_qos_profiles=LATCH), id=-1,
count=0,
topic='/topic',
msgtype='typ',
serialization_format='cdr',
offered_qos_profiles='',
),
call(
id=-1,
count=0,
topic='/topic',
msgtype='typ',
serialization_format='cdr',
offered_qos_profiles=LATCH,
),
], ],
) )
writer.return_value.__enter__.return_value.write.assert_has_calls( writer.return_value.__enter__.return_value.write.assert_has_calls(
[call('/topic', 42, b'666'), call('/latched', 43, b'666')], [call(wconnections[0], 42, b'666'),
call(wconnections[1], 43, b'666')],
) )
register_types.assert_called_with({'typ': 'def'}) register_types.assert_called_with({'typ': 'def'})

View File

@ -126,25 +126,26 @@ def test_reader(bag: Path):
assert reader.message_count == 4 assert reader.message_count == 4
if reader.compression_mode: if reader.compression_mode:
assert reader.compression_format == 'zstd' assert reader.compression_format == 'zstd'
assert [*reader.connections.keys()] == [1, 2, 3]
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
gen = reader.messages() gen = reader.messages()
topic, msgtype, timestamp, rawdata = next(gen) connection, timestamp, rawdata = next(gen)
assert topic == '/poly' assert connection.topic == '/poly'
assert msgtype == 'geometry_msgs/msg/Polygon' assert connection.msgtype == 'geometry_msgs/msg/Polygon'
assert timestamp == 666 assert timestamp == 666
assert rawdata == MSG_POLY[0] assert rawdata == MSG_POLY[0]
for idx in range(2): for idx in range(2):
topic, msgtype, timestamp, rawdata = next(gen) connection, timestamp, rawdata = next(gen)
assert topic == '/magn' assert connection.topic == '/magn'
assert msgtype == 'sensor_msgs/msg/MagneticField' assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
assert timestamp == 708 assert timestamp == 708
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0] assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
topic, msgtype, timestamp, rawdata = next(gen) connection, timestamp, rawdata = next(gen)
assert topic == '/joint' assert connection.topic == '/joint'
assert msgtype == 'trajectory_msgs/msg/JointTrajectory' assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)
@ -153,32 +154,32 @@ def test_reader(bag: Path):
def test_message_filters(bag: Path): def test_message_filters(bag: Path):
"""Test reader filters messages.""" """Test reader filters messages."""
with Reader(bag) as reader: with Reader(bag) as reader:
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
gen = reader.messages(['/magn']) gen = reader.messages(connections=magn_connections)
topic, _, _, _ = next(gen) connection, _, _ = next(gen)
assert topic == '/magn' assert connection.topic == '/magn'
topic, _, _, _ = next(gen) connection, _, _ = next(gen)
assert topic == '/magn' assert connection.topic == '/magn'
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)
gen = reader.messages(start=667) gen = reader.messages(start=667)
topic, _, _, _ = next(gen) connection, _, _ = next(gen)
assert topic == '/magn' assert connection.topic == '/magn'
topic, _, _, _ = next(gen) connection, _, _ = next(gen)
assert topic == '/magn' assert connection.topic == '/magn'
topic, _, _, _ = next(gen) connection, _, _ = next(gen)
assert topic == '/joint' assert connection.topic == '/joint'
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)
gen = reader.messages(stop=667) gen = reader.messages(stop=667)
topic, _, _, _ = next(gen) connection, _, _ = next(gen)
assert topic == '/poly' assert connection.topic == '/poly'
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)
gen = reader.messages(['/magn'], stop=667) gen = reader.messages(connections=magn_connections, stop=667)
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)

View File

@ -141,7 +141,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
header['index_pos'] = pack('<Q', pos) header['index_pos'] = pack('<Q', pos)
header = ser(header) header = ser(header)
header += b'\x00' * (4096 - len(header)) header += b'\x20' * (4096 - len(header))
bag.write_bytes(b''.join([ bag.write_bytes(b''.join([
magic, magic,
@ -227,8 +227,10 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert reader.topics['/topic0'].msgcount == 2 assert reader.topics['/topic0'].msgcount == 2
msgs = list(reader.messages()) msgs = list(reader.messages())
assert len(msgs) == 2 assert len(msgs) == 2
assert msgs[0][3] == b'MSGCONTENT5' assert msgs[0][0].topic == '/topic0'
assert msgs[1][3] == b'MSGCONTENT10' assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][0].topic == '/topic0'
assert msgs[1][2] == b'MSGCONTENT10'
# sorts by time on different topic # sorts by time on different topic
write_bag( write_bag(
@ -249,20 +251,20 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert reader.topics['/topic2'].msgcount == 1 assert reader.topics['/topic2'].msgcount == 1
msgs = list(reader.messages()) msgs = list(reader.messages())
assert len(msgs) == 2 assert len(msgs) == 2
assert msgs[0][3] == b'MSGCONTENT5' assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][3] == b'MSGCONTENT10' assert msgs[1][2] == b'MSGCONTENT10'
msgs = list(reader.messages(['/topic0'])) msgs = list(reader.messages(['/topic0']))
assert len(msgs) == 1 assert len(msgs) == 1
assert msgs[0][3] == b'MSGCONTENT10' assert msgs[0][2] == b'MSGCONTENT10'
msgs = list(reader.messages(start=7 * 10**9)) msgs = list(reader.messages(start=7 * 10**9))
assert len(msgs) == 1 assert len(msgs) == 1
assert msgs[0][3] == b'MSGCONTENT10' assert msgs[0][2] == b'MSGCONTENT10'
msgs = list(reader.messages(stop=7 * 10**9)) msgs = list(reader.messages(stop=7 * 10**9))
assert len(msgs) == 1 assert len(msgs) == 1
assert msgs[0][3] == b'MSGCONTENT5' assert msgs[0][2] == b'MSGCONTENT5'
def test_user_errors(tmp_path): def test_user_errors(tmp_path):

View File

@ -29,14 +29,15 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
wbag.set_compression(mode, wbag.CompressionFormat.ZSTD) wbag.set_compression(mode, wbag.CompressionFormat.ZSTD)
with wbag: with wbag:
msgtype = 'std_msgs/msg/Float64' msgtype = 'std_msgs/msg/Float64'
wbag.add_topic('/test', msgtype) wconnection = wbag.add_connection('/test', msgtype)
wbag.write('/test', 42, serialize_cdr(Foo, msgtype)) wbag.write(wconnection, 42, serialize_cdr(Foo, msgtype))
rbag = Reader(path) rbag = Reader(path)
with rbag: with rbag:
gen = rbag.messages() gen = rbag.messages()
_, msgtype, _, raw = next(gen) rconnection, _, raw = next(gen)
msg = deserialize_cdr(raw, msgtype) assert rconnection == wconnection
msg = deserialize_cdr(raw, rconnection.msgtype)
assert msg.data == Foo.data assert msg.data == Foo.data
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)

View File

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
import pytest import pytest
from rosbags.rosbag2 import Writer, WriterError from rosbags.rosbag2 import Writer, WriterError
from rosbags.rosbag2.connection import Connection
if TYPE_CHECKING: if TYPE_CHECKING:
from pathlib import Path from pathlib import Path
@ -18,9 +19,9 @@ def test_writer(tmp_path: Path):
"""Test Writer.""" """Test Writer."""
path = (tmp_path / 'rosbag2') path = (tmp_path / 'rosbag2')
with Writer(path) as bag: with Writer(path) as bag:
bag.add_topic('/test', 'std_msgs/msg/Int8') connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00') bag.write(connection, 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096) bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists() assert (path / 'metadata.yaml').exists()
assert (path / 'rosbag2.db3').exists() assert (path / 'rosbag2.db3').exists()
size = (path / 'rosbag2.db3').stat().st_size size = (path / 'rosbag2.db3').stat().st_size
@ -29,9 +30,9 @@ def test_writer(tmp_path: Path):
bag = Writer(path) bag = Writer(path)
bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD) bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD)
with bag: with bag:
bag.add_topic('/test', 'std_msgs/msg/Int8') connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00') bag.write(connection, 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096) bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists() assert (path / 'metadata.yaml').exists()
assert (path / 'compress_none.db3').exists() assert (path / 'compress_none.db3').exists()
assert size == (path / 'compress_none.db3').stat().st_size assert size == (path / 'compress_none.db3').stat().st_size
@ -40,9 +41,9 @@ def test_writer(tmp_path: Path):
bag = Writer(path) bag = Writer(path)
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD) bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
with bag: with bag:
bag.add_topic('/test', 'std_msgs/msg/Int8') connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00') bag.write(connection, 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096) bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists() assert (path / 'metadata.yaml').exists()
assert not (path / 'compress_file.db3').exists() assert not (path / 'compress_file.db3').exists()
assert (path / 'compress_file.db3.zstd').exists() assert (path / 'compress_file.db3.zstd').exists()
@ -51,9 +52,9 @@ def test_writer(tmp_path: Path):
bag = Writer(path) bag = Writer(path)
bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD) bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD)
with bag: with bag:
bag.add_topic('/test', 'std_msgs/msg/Int8') connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write('/test', 42, b'\x00') bag.write(connection, 42, b'\x00')
bag.write('/test', 666, b'\x01' * 4096) bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists() assert (path / 'metadata.yaml').exists()
assert (path / 'compress_message.db3').exists() assert (path / 'compress_message.db3').exists()
assert size > (path / 'compress_message.db3').stat().st_size assert size > (path / 'compress_message.db3').stat().st_size
@ -76,7 +77,7 @@ def test_failure_cases(tmp_path: Path):
bag = Writer(tmp_path / 'topic') bag = Writer(tmp_path / 'topic')
with pytest.raises(WriterError, match='was not opened'): with pytest.raises(WriterError, match='was not opened'):
bag.add_topic('/tf', 'tf_msgs/msg/tf2') bag.add_connection('/tf', 'tf_msgs/msg/tf2')
bag = Writer(tmp_path / 'write') bag = Writer(tmp_path / 'write')
with pytest.raises(WriterError, match='was not opened'): with pytest.raises(WriterError, match='was not opened'):
@ -84,11 +85,12 @@ def test_failure_cases(tmp_path: Path):
bag = Writer(tmp_path / 'topic') bag = Writer(tmp_path / 'topic')
bag.open() bag.open()
bag.add_topic('/tf', 'tf_msgs/msg/tf2') bag.add_connection('/tf', 'tf_msgs/msg/tf2')
with pytest.raises(WriterError, match='only be added once'): with pytest.raises(WriterError, match='only be added once'):
bag.add_topic('/tf', 'tf_msgs/msg/tf2') bag.add_connection('/tf', 'tf_msgs/msg/tf2')
bag = Writer(tmp_path / 'notopic') bag = Writer(tmp_path / 'notopic')
bag.open() bag.open()
with pytest.raises(WriterError, match='unknown topic'): connection = Connection(1, 0, '/tf', 'tf_msgs/msg/tf2', 'cdr', '')
bag.write('/test', 42, b'\x00') with pytest.raises(WriterError, match='unknown connection'):
bag.write(connection, 42, b'\x00')