From dff38bdb6023909c64105b02eab8bf1161c45591 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Thu, 2 Mar 2023 12:55:12 +0100 Subject: [PATCH] Add rosbag2 mcap storage reader --- docs/topics/rosbag2.rst | 3 +- setup.cfg | 1 + src/rosbags/highlevel/anyreader.py | 41 +- src/rosbags/rosbag2/reader.py | 12 +- src/rosbags/rosbag2/storage_mcap.py | 571 ++++++++++++++++++++++++++++ tests/test_highlevel.py | 60 +++ tests/test_reader.py | 429 +++++++++++++++++++++ 7 files changed, 1096 insertions(+), 21 deletions(-) create mode 100644 src/rosbags/rosbag2/storage_mcap.py diff --git a/docs/topics/rosbag2.rst b/docs/topics/rosbag2.rst index dbee4d06..21fbecf4 100644 --- a/docs/topics/rosbag2.rst +++ b/docs/topics/rosbag2.rst @@ -4,7 +4,7 @@ The :py:mod:`rosbags.rosbag2` package provides a conformant implementation of ro Supported Versions ------------------ -All versions up to the current (ROS2 Foxy) version 4 are supported. +All versions up to the current (ROS2 Humble) version 6 are supported. Supported Features ------------------ @@ -18,6 +18,7 @@ Rosbag2 is a flexible format that supports plugging different serialization meth :Storages: - sqlite3 + - mcap Writing rosbag2 --------------- diff --git a/setup.cfg b/setup.cfg index 0b4ded51..41515ad5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,6 +12,7 @@ keywords = conversion deserialization idl + mcap message msg reader diff --git a/src/rosbags/highlevel/anyreader.py b/src/rosbags/highlevel/anyreader.py index d9562af0..1363fec7 100644 --- a/src/rosbags/highlevel/anyreader.py +++ b/src/rosbags/highlevel/anyreader.py @@ -17,6 +17,7 @@ from rosbags.rosbag2 import Reader as Reader2 from rosbags.rosbag2 import ReaderError as ReaderError2 from rosbags.serde import deserialize_cdr, deserialize_ros1 from rosbags.typesys import get_types_from_msg, register_types, types +from rosbags.typesys.idl import get_types_from_idl if TYPE_CHECKING: import sys @@ -125,24 +126,34 @@ class AnyReader: reader.close() raise AnyReaderError(*err.args) from err + for key in [ + 'builtin_interfaces/msg/Time', + 'builtin_interfaces/msg/Duration', + 'std_msgs/msg/Header', + ]: + self.typestore.FIELDDEFS[key] = types.FIELDDEFS[key] + attr = key.replace('/', '__') + setattr(self.typestore, attr, getattr(types, attr)) + typs: dict[str, Any] = {} if self.is2: - for key, value in types.FIELDDEFS.items(): - self.typestore.FIELDDEFS[key] = value - attr = key.replace('/', '__') - setattr(self.typestore, attr, getattr(types, attr)) + reader = self.readers[0] + assert isinstance(reader, Reader2) + if reader.metadata['storage_identifier'] == 'mcap': + for connection in reader.connections: + if connection.md5sum: + if connection.md5sum == 'idl': + typ = get_types_from_idl(connection.msgdef) + else: + typ = get_types_from_msg(connection.msgdef, connection.msgtype) + typs.update(typ) + register_types(typs, self.typestore) + else: + for key, value in types.FIELDDEFS.items(): + self.typestore.FIELDDEFS[key] = value + attr = key.replace('/', '__') + setattr(self.typestore, attr, getattr(types, attr)) else: - for key in [ - 'builtin_interfaces/msg/Time', - 'builtin_interfaces/msg/Duration', - 'std_msgs/msg/Header', - ]: - self.typestore.FIELDDEFS[key] = types.FIELDDEFS[key] - attr = key.replace('/', '__') - setattr(self.typestore, attr, getattr(types, attr)) - - typs: dict[str, Any] = {} for reader in self.readers: - assert isinstance(reader, Reader1) for connection in reader.connections: typs.update(get_types_from_msg(connection.msgdef, connection.msgtype)) register_types(typs, self.typestore) diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index 02733a4e..aeb3ceff 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -15,6 +15,7 @@ from ruamel.yaml.error import YAMLError from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo from .errors import ReaderError +from .storage_mcap import ReaderMcap from .storage_sqlite3 import ReaderSqlite3 if TYPE_CHECKING: @@ -29,19 +30,19 @@ class StorageProtocol(Protocol): def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]): """Initialize.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover def open(self) -> None: """Open file.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover def close(self) -> None: """Close file.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover def get_definitions(self) -> dict[str, tuple[str, str]]: """Get message definitions.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover def messages( self, @@ -50,7 +51,7 @@ class StorageProtocol(Protocol): stop: Optional[int] = None, ) -> Generator[tuple[Connection, int, bytes], None, None]: """Get messages from file.""" - raise NotImplementedError + raise NotImplementedError # pragma: no cover class Reader: @@ -73,6 +74,7 @@ class Reader: # pylint: disable=too-many-instance-attributes STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = { + 'mcap': ReaderMcap, 'sqlite3': ReaderSqlite3, } diff --git a/src/rosbags/rosbag2/storage_mcap.py b/src/rosbags/rosbag2/storage_mcap.py new file mode 100644 index 00000000..324ef686 --- /dev/null +++ b/src/rosbags/rosbag2/storage_mcap.py @@ -0,0 +1,571 @@ +# Copyright 2020-2023 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Mcap storage.""" + +from __future__ import annotations + +import heapq +from io import BytesIO +from struct import iter_unpack, unpack_from +from typing import TYPE_CHECKING, NamedTuple + +import zstandard +from lz4.frame import decompress as lz4_decompress + +from .errors import ReaderError + +if TYPE_CHECKING: + from pathlib import Path + from typing import BinaryIO, Callable, Generator, Iterable, Optional + + from rosbags.interfaces import Connection + + +class Schema(NamedTuple): + """Schema.""" + + id: int + name: str + encoding: str + data: str + + +class Channel(NamedTuple): + """Channel.""" + + id: int + schema: str + topic: str + message_encoding: str + metadata: bytes # dict[str, str] + + +class Chunk(NamedTuple): + """Chunk.""" + + start_time: int + end_time: int + size: int + crc: int + compression: str + records: bytes + + +class ChunkInfo(NamedTuple): + """Chunk.""" + + message_start_time: int + message_end_time: int + chunk_start_offset: int + chunk_length: int + message_index_offsets: dict[int, int] + message_index_length: int + compression: str + compressed_size: int + uncompressed_size: int + channel_count: dict[int, int] + + +class Statistics(NamedTuple): + """Statistics.""" + + message_count: int + schema_count: int + channel_count: int + attachement_count: int + metadata_count: int + chunk_count: int + start_time: int + end_time: int + channel_message_counts: bytes + + +class Msg(NamedTuple): + """Message wrapper.""" + + timestamp: int + offset: int + connection: Optional[Connection] + data: Optional[bytes] + + +def read_sized(bio: BinaryIO) -> bytes: + """Read one record.""" + return bio.read(unpack_from(' None: + """Read one record.""" + bio.seek(unpack_from(' bytes: + """Read string.""" + return bio.read(unpack_from(' str: + """Read string.""" + return bio.read(unpack_from(' Generator[Msg, None, None]: + """Yield messages from chunk in time order.""" + yield Msg(chunk.message_start_time, 0, None, None) + + bio.seek(chunk.chunk_start_offset + 9 + 40 + len(chunk.compression)) + compressed_data = bio.read(chunk.compressed_size) + subio = BytesIO(DECOMPRESSORS[chunk.compression](compressed_data, chunk.uncompressed_size)) + + messages = [] + while (offset := subio.tell()) < chunk.uncompressed_size: + op_ = ord(subio.read(1)) + if op_ == 0x05: + recio = BytesIO(read_sized(subio)) + channel_id, _, log_time, _ = unpack_from( + ' None: + """Open MCAP.""" + try: + self.bio = self.path.open('rb') + except OSError as err: + raise ReaderError(f'Could not open file {str(self.path)!r}: {err.strerror}.') from err + + magic = self.bio.read(8) + if not magic: + raise ReaderError(f'File {str(self.path)!r} seems to be empty.') + + if magic != b'\x89MCAP0\r\n': + raise ReaderError('File magic is invalid.') + + op_ = ord(self.bio.read(1)) + if op_ != 0x01: + raise ReaderError('Unexpected record.') + + recio = BytesIO(read_sized(self.bio)) + profile = read_string(recio) + if profile != 'ros2': + raise ReaderError('Profile is not ros2.') + self.data_start = self.bio.tell() + + self.bio.seek(-37, 2) + footer_start = self.bio.tell() + data = self.bio.read() + magic = data[-8:] + if magic != b'\x89MCAP0\r\n': + raise ReaderError('File end magic is invalid.') + + assert len(data) == 37 + assert data[0:9] == b'\x02\x14\x00\x00\x00\x00\x00\x00\x00', data[0:9] + + summary_start, = unpack_from(' None: + """Read index from file.""" + bio = self.bio + assert bio + + schemas = self.schemas + channels = self.channels + chunks = self.chunks + + bio.seek(self.data_end) + while True: + op_ = ord(bio.read(1)) + + if op_ in (0x02, 0x0e): + break + + if op_ == 0x03: + bio.seek(8, 1) + key, = unpack_from(' None: + """Close MCAP.""" + assert self.bio + self.bio.close() + self.bio = None + + def meta_scan(self) -> None: + """Generate metadata by scanning through file.""" + assert self.bio + bio = self.bio + bio_size = self.data_end + bio.seek(self.data_start) + + schemas = self.schemas + channels = self.channels + + while bio.tell() < bio_size: + op_ = ord(bio.read(1)) + + if op_ == 0x03: + bio.seek(8, 1) + key, = unpack_from(' dict[str, tuple[str, str]]: + """Get schema definition.""" + if not self.schemas: + self.meta_scan() + return {schema.name: (schema.encoding[4:], schema.data) for schema in self.schemas.values()} + + def messages_scan( + self, + connections: Iterable[Connection], + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> Generator[tuple[Connection, int, bytes], None, None]: + """Read messages by scanning whole bag.""" + # pylint: disable=too-many-locals + assert self.bio + bio = self.bio + bio_size = self.data_end + bio.seek(self.data_start) + + schemas = self.schemas.copy() + channels = self.channels.copy() + + if channels: + read_meta = False + channel_map = { + cid: conn for conn in connections if ( + cid := next( + ( + cid for cid, x in self.channels.items() + if x.schema == conn.msgtype and x.topic == conn.topic + ), + None, + ) + ) + } + else: + read_meta = True + channel_map = {} + + if start is None: + start = 0 + if stop is None: + stop = 2**63 - 1 + + while bio.tell() < bio_size: + op_ = ord(bio.read(1)) + + if op_ == 0x03 and read_meta: + bio.seek(8, 1) + key, = unpack_from(' Generator[tuple[Connection, int, bytes], None, None]: + """Read messages from bag. + + Args: + connections: Iterable with connections to filter for. + start: Yield only messages at or after this timestamp (ns). + stop: Yield only messages before this timestamp (ns). + + Yields: + tuples of connection, timestamp (ns), and rawdata. + + """ + assert self.bio + + if not self.chunks: + yield from self.messages_scan(connections, start, stop) + return + + channel_map = { + cid: conn for conn in connections if ( + cid := next( + ( + cid for cid, x in self.channels.items() + if x.schema == conn.msgtype and x.topic == conn.topic + ), + None, + ) + ) + } + + chunks = [ + msgsrc( + x, + channel_map, + start or x.message_start_time, + stop or x.message_end_time + 1, + self.bio, + ) + for x in self.chunks + if x.message_start_time != 0 and (start is None or start < x.message_end_time) and + (stop is None or x.message_start_time < stop) and + (any(x.channel_count.get(cid, 0) for cid in channel_map)) + ] + + for timestamp, offset, connection, data in heapq.merge(*chunks): + if not offset: + continue + assert connection + assert data + yield connection, timestamp, data + + +class ReaderMcap: + """Mcap storage reader.""" + + def __init__( + self, + paths: Iterable[Path], + connections: Iterable[Connection], + ): + """Set up storage reader. + + Args: + paths: Paths of storage files. + connections: List of connections. + + """ + self.paths = paths + self.readers: list[MCAPFile] = [] + self.connections = connections + + def open(self) -> None: + """Open rosbag2.""" + self.readers = [MCAPFile(x) for x in self.paths] + for reader in self.readers: + reader.open() + + def close(self) -> None: + """Close rosbag2.""" + assert self.readers + for reader in self.readers: + reader.close() + self.readers = [] + + def get_definitions(self) -> dict[str, tuple[str, str]]: + """Get message definitions.""" + res = {} + for reader in self.readers: + res.update(reader.get_schema_definitions()) + return res + + def messages( + self, + connections: Iterable[Connection] = (), + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> Generator[tuple[Connection, int, bytes], None, None]: + """Read messages from bag. + + Args: + connections: Iterable with connections to filter for. An empty + iterable disables filtering on connections. + start: Yield only messages at or after this timestamp (ns). + stop: Yield only messages before this timestamp (ns). + + Yields: + tuples of connection, timestamp (ns), and rawdata. + + """ + connections = list(connections) or list(self.connections) + + for reader in self.readers: + yield from reader.messages(connections, start, stop) diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py index defe1b1b..9a4d5cb2 100644 --- a/tests/test_highlevel.py +++ b/tests/test_highlevel.py @@ -5,10 +5,12 @@ from __future__ import annotations from typing import TYPE_CHECKING +from unittest.mock import patch import pytest from rosbags.highlevel import AnyReader, AnyReaderError +from rosbags.interfaces import Connection from rosbags.rosbag1 import Writer as Writer1 from rosbags.rosbag2 import Writer as Writer2 @@ -200,3 +202,61 @@ def test_anyreader2(bags2: list[Path]) -> None: # pylint: disable=redefined-out assert nxt[0].topic == '/topic1' with pytest.raises(StopIteration): next(gen) + + +def test_anyreader2_autoregister(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name + """Test AnyReader on rosbag2.""" + + class MockReader: + """Mock reader.""" + + # pylint: disable=too-few-public-methods + + def __init__(self, paths: list[Path]): + """Initialize mock.""" + _ = paths + self.metadata = {'storage_identifier': 'mcap'} + self.connections = [ + Connection( + 1, + '/foo', + 'test_msg/msg/Foo', + 'string foo', + 'msg', + 0, + None, # type: ignore + self, + ), + Connection( + 2, + '/bar', + 'test_msg/msg/Bar', + 'module test_msgs { module msg { struct Bar {string bar;}; }; };', + 'idl', + 0, + None, # type: ignore + self, + ), + Connection( + 3, + '/baz', + 'test_msg/msg/Baz', + '', + '', + 0, + None, # type: ignore + self, + ), + ] + + def open(self) -> None: + """Unused.""" + + with patch('rosbags.highlevel.anyreader.Reader2', MockReader), \ + patch('rosbags.highlevel.anyreader.register_types') as mock_register_types: + AnyReader([bags2[0]]).open() + mock_register_types.assert_called_once() + assert mock_register_types.call_args[0][0] == { + 'test_msg/msg/Foo': ([], [('foo', (1, 'string'))]), + 'test_msgs/msg/Bar': ([], [('bar', (1, 'string'))]), + } diff --git a/tests/test_reader.py b/tests/test_reader.py index 40583078..cb2a9b30 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -7,6 +7,9 @@ from __future__ import annotations import sqlite3 +import struct +from io import BytesIO +from itertools import groupby from pathlib import Path from typing import TYPE_CHECKING from unittest import mock @@ -19,6 +22,8 @@ from rosbags.rosbag2 import Reader, ReaderError, Writer from .test_serde import MSG_JOINT, MSG_MAGN, MSG_MAGN_BIG, MSG_POLY if TYPE_CHECKING: + from typing import BinaryIO, Iterable + from _pytest.fixtures import SubRequest METADATA = """ @@ -320,3 +325,427 @@ def test_failure_cases(tmp_path: Path) -> None: with pytest.raises(ReaderError, match='not open database'), \ Reader(tmp_path) as reader: next(reader.messages()) + + +def write_record(bio: BinaryIO, opcode: int, records: Iterable[bytes]) -> None: + """Write record.""" + data = b''.join(records) + bio.write(bytes([opcode]) + struct.pack(' bytes: + """Serialize string.""" + data = text.encode() + return struct.pack(' Path: + """Manually contruct mcap bag.""" + # pylint: disable=too-many-locals + # pylint: disable=too-many-statements + (tmp_path / 'metadata.yaml').write_text( + METADATA.format( + extension='.mcap', + compression_format='""', + compression_mode='""', + ).replace('sqlite3', 'mcap'), + ) + + path = tmp_path / 'db.db3.mcap' + bio: BinaryIO + messages: list[tuple[int, int, int]] = [] + chunks = [] + with path.open('wb') as bio: + realbio = bio + bio.write(MCAP_HEADER) + write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap'))) + + if request.param.startswith('chunked'): + bio = BytesIO() + messages = [] + + write_record(bio, *SCHEMAS[0]) + write_record(bio, *CHANNELS[0]) + messages.append((1, 666, bio.tell())) + write_record( + bio, + 0x05, + ( + struct.pack(' None: + """Test reader and deserializer on simple bag.""" + with Reader(bag_mcap) as reader: + assert reader.duration == 43 + assert reader.start_time == 666 + assert reader.end_time == 709 + assert reader.message_count == 4 + if reader.compression_mode: + assert reader.compression_format == 'zstd' + assert [x.id for x in reader.connections] == [1, 2, 3] + assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint'] + gen = reader.messages() + + connection, timestamp, rawdata = next(gen) + assert connection.topic == '/poly' + assert connection.msgtype == 'geometry_msgs/msg/Polygon' + assert timestamp == 666 + assert rawdata == MSG_POLY[0] + + for idx in range(2): + connection, timestamp, rawdata = next(gen) + assert connection.topic == '/magn' + assert connection.msgtype == 'sensor_msgs/msg/MagneticField' + assert timestamp == 708 + assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0] + + connection, timestamp, rawdata = next(gen) + assert connection.topic == '/joint' + assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory' + + with pytest.raises(StopIteration): + next(gen) + + +def test_message_filters_mcap(bag_mcap: Path) -> None: + """Test reader filters messages.""" + with Reader(bag_mcap) as reader: + magn_connections = [x for x in reader.connections if x.topic == '/magn'] + gen = reader.messages(connections=magn_connections) + connection, _, _ = next(gen) + assert connection.topic == '/magn' + connection, _, _ = next(gen) + assert connection.topic == '/magn' + with pytest.raises(StopIteration): + next(gen) + + gen = reader.messages(start=667) + connection, _, _ = next(gen) + assert connection.topic == '/magn' + connection, _, _ = next(gen) + assert connection.topic == '/magn' + connection, _, _ = next(gen) + assert connection.topic == '/joint' + with pytest.raises(StopIteration): + next(gen) + + gen = reader.messages(stop=667) + connection, _, _ = next(gen) + assert connection.topic == '/poly' + with pytest.raises(StopIteration): + next(gen) + + gen = reader.messages(connections=magn_connections, stop=667) + with pytest.raises(StopIteration): + next(gen) + + gen = reader.messages(start=666, stop=666) + with pytest.raises(StopIteration): + next(gen) + + +def test_bag_mcap_files(tmp_path: Path) -> None: + """Test bad mcap files.""" + (tmp_path / 'metadata.yaml').write_text( + METADATA.format( + extension='.mcap', + compression_format='""', + compression_mode='""', + ).replace('sqlite3', 'mcap'), + ) + + path = tmp_path / 'db.db3.mcap' + path.touch() + reader = Reader(tmp_path) + path.unlink() + with pytest.raises(ReaderError, match='Could not open'): + reader.open() + + path.touch() + with pytest.raises(ReaderError, match='seems to be empty'): + Reader(tmp_path).open() + + path.write_bytes(b'xxxxxxxx') + with pytest.raises(ReaderError, match='magic is invalid'): + Reader(tmp_path).open() + + path.write_bytes(b'\x89MCAP0\r\n\xFF') + with pytest.raises(ReaderError, match='Unexpected record'): + Reader(tmp_path).open() + + with path.open('wb') as bio: + bio.write(b'\x89MCAP0\r\n') + write_record(bio, 0x01, (make_string('ros1'), make_string('test_mcap'))) + with pytest.raises(ReaderError, match='Profile is not'): + Reader(tmp_path).open() + + with path.open('wb') as bio: + bio.write(b'\x89MCAP0\r\n') + write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap'))) + with pytest.raises(ReaderError, match='File end magic is invalid'): + Reader(tmp_path).open()