Add rosbag2 mcap storage reader
parent 9830c38fc7
commit dff38bdb60
@@ -4,7 +4,7 @@ The :py:mod:`rosbags.rosbag2` package provides a conformant implementation of ro

Supported Versions
------------------

All versions up to the current (ROS2 Foxy) version 4 are supported.
All versions up to the current (ROS2 Humble) version 6 are supported.

Supported Features
------------------
@@ -18,6 +18,7 @@ Rosbag2 is a flexible format that supports plugging different serialization meth

:Storages:
- sqlite3
- mcap

Writing rosbag2
---------------

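With mcap supported alongside sqlite3, reading an mcap-backed rosbag2 goes through the same Reader API as before. A minimal sketch (the bag path is a placeholder for a rosbag2 directory whose metadata.yaml names the mcap storage):

    from rosbags.rosbag2 import Reader
    from rosbags.serde import deserialize_cdr

    # 'path/to/bag' is a placeholder rosbag2 directory.
    with Reader('path/to/bag') as reader:
        for connection, timestamp, rawdata in reader.messages():
            # rawdata is the CDR-serialized payload; the message type is
            # recorded on the connection.
            msg = deserialize_cdr(rawdata, connection.msgtype)
            print(connection.topic, timestamp, type(msg).__name__)

The storage backend is transparent to callers; only the on-disk format behind the Reader changes.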
@@ -12,6 +12,7 @@ keywords =
    conversion
    deserialization
    idl
    mcap
    message
    msg
    reader

@@ -17,6 +17,7 @@ from rosbags.rosbag2 import Reader as Reader2
from rosbags.rosbag2 import ReaderError as ReaderError2
from rosbags.serde import deserialize_cdr, deserialize_ros1
from rosbags.typesys import get_types_from_msg, register_types, types
from rosbags.typesys.idl import get_types_from_idl

if TYPE_CHECKING:
    import sys
@@ -125,24 +126,34 @@ class AnyReader:
                    reader.close()
            raise AnyReaderError(*err.args) from err

        for key in [
            'builtin_interfaces/msg/Time',
            'builtin_interfaces/msg/Duration',
            'std_msgs/msg/Header',
        ]:
            self.typestore.FIELDDEFS[key] = types.FIELDDEFS[key]
            attr = key.replace('/', '__')
            setattr(self.typestore, attr, getattr(types, attr))
        typs: dict[str, Any] = {}
        if self.is2:
            for key, value in types.FIELDDEFS.items():
                self.typestore.FIELDDEFS[key] = value
                attr = key.replace('/', '__')
                setattr(self.typestore, attr, getattr(types, attr))
            reader = self.readers[0]
            assert isinstance(reader, Reader2)
            if reader.metadata['storage_identifier'] == 'mcap':
                for connection in reader.connections:
                    if connection.md5sum:
                        if connection.md5sum == 'idl':
                            typ = get_types_from_idl(connection.msgdef)
                        else:
                            typ = get_types_from_msg(connection.msgdef, connection.msgtype)
                        typs.update(typ)
                register_types(typs, self.typestore)
            else:
                for key, value in types.FIELDDEFS.items():
                    self.typestore.FIELDDEFS[key] = value
                    attr = key.replace('/', '__')
                    setattr(self.typestore, attr, getattr(types, attr))
        else:
            for key in [
                'builtin_interfaces/msg/Time',
                'builtin_interfaces/msg/Duration',
                'std_msgs/msg/Header',
            ]:
                self.typestore.FIELDDEFS[key] = types.FIELDDEFS[key]
                attr = key.replace('/', '__')
                setattr(self.typestore, attr, getattr(types, attr))

            typs: dict[str, Any] = {}
            for reader in self.readers:
                assert isinstance(reader, Reader1)
                for connection in reader.connections:
                    typs.update(get_types_from_msg(connection.msgdef, connection.msgtype))
            register_types(typs, self.typestore)

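The autoregistration above builds on the typesys helpers already imported in this hunk: 'msg'-encoded definitions go through get_types_from_msg, 'idl'-encoded ones through get_types_from_idl, and the parsed result is handed to register_types. A rough standalone sketch for one connection (the type name and definition below are made-up examples, not part of the commit):

    from rosbags.typesys import get_types_from_msg, register_types

    # Hypothetical connection data carrying a msg-encoded definition.
    msgdef = 'string foo'
    msgtype = 'example_msgs/msg/Foo'

    # Parse the definition into field descriptions and register them so the
    # deserializer can later resolve the type by name.
    typs = get_types_from_msg(msgdef, msgtype)
    register_types(typs)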
@@ -15,6 +15,7 @@ from ruamel.yaml.error import YAMLError
from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo

from .errors import ReaderError
from .storage_mcap import ReaderMcap
from .storage_sqlite3 import ReaderSqlite3

if TYPE_CHECKING:
@@ -29,19 +30,19 @@ class StorageProtocol(Protocol):

    def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]):
        """Initialize."""
        raise NotImplementedError
        raise NotImplementedError # pragma: no cover

    def open(self) -> None:
        """Open file."""
        raise NotImplementedError
        raise NotImplementedError # pragma: no cover

    def close(self) -> None:
        """Close file."""
        raise NotImplementedError
        raise NotImplementedError # pragma: no cover

    def get_definitions(self) -> dict[str, tuple[str, str]]:
        """Get message definitions."""
        raise NotImplementedError
        raise NotImplementedError # pragma: no cover

    def messages(
        self,
@@ -50,7 +51,7 @@ class StorageProtocol(Protocol):
        stop: Optional[int] = None,
    ) -> Generator[tuple[Connection, int, bytes], None, None]:
        """Get messages from file."""
        raise NotImplementedError
        raise NotImplementedError # pragma: no cover


class Reader:
@@ -73,6 +74,7 @@ class Reader:
    # pylint: disable=too-many-instance-attributes

    STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = {
        'mcap': ReaderMcap,
        'sqlite3': ReaderSqlite3,
    }

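The hunk does not show where the Reader picks a plugin, but given the STORAGE_PLUGINS mapping and the storage_identifier field from metadata.yaml, the dispatch presumably looks roughly like the following sketch (an illustration only; names other than STORAGE_PLUGINS, ReaderMcap, and ReaderSqlite3 are placeholders):

    # Illustrative sketch; the actual lookup lives elsewhere in reader.py.
    storage_identifier = metadata['storage_identifier']       # e.g. 'mcap' or 'sqlite3'
    storage_cls = Reader.STORAGE_PLUGINS[storage_identifier]  # ReaderMcap or ReaderSqlite3
    storage = storage_cls(paths, connections)                 # both satisfy StorageProtocol
    storage.open()

Both backends implement the same StorageProtocol, so the rest of the Reader never needs to know which storage format is on disk.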
571 src/rosbags/rosbag2/storage_mcap.py Normal file
@@ -0,0 +1,571 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Mcap storage."""

from __future__ import annotations

import heapq
from io import BytesIO
from struct import iter_unpack, unpack_from
from typing import TYPE_CHECKING, NamedTuple

import zstandard
from lz4.frame import decompress as lz4_decompress

from .errors import ReaderError

if TYPE_CHECKING:
    from pathlib import Path
    from typing import BinaryIO, Callable, Generator, Iterable, Optional

    from rosbags.interfaces import Connection


class Schema(NamedTuple):
    """Schema."""

    id: int
    name: str
    encoding: str
    data: str


class Channel(NamedTuple):
    """Channel."""

    id: int
    schema: str
    topic: str
    message_encoding: str
    metadata: bytes # dict[str, str]


class Chunk(NamedTuple):
    """Chunk."""

    start_time: int
    end_time: int
    size: int
    crc: int
    compression: str
    records: bytes


class ChunkInfo(NamedTuple):
    """Chunk."""

    message_start_time: int
    message_end_time: int
    chunk_start_offset: int
    chunk_length: int
    message_index_offsets: dict[int, int]
    message_index_length: int
    compression: str
    compressed_size: int
    uncompressed_size: int
    channel_count: dict[int, int]


class Statistics(NamedTuple):
    """Statistics."""

    message_count: int
    schema_count: int
    channel_count: int
    attachement_count: int
    metadata_count: int
    chunk_count: int
    start_time: int
    end_time: int
    channel_message_counts: bytes


class Msg(NamedTuple):
    """Message wrapper."""

    timestamp: int
    offset: int
    connection: Optional[Connection]
    data: Optional[bytes]


def read_sized(bio: BinaryIO) -> bytes:
    """Read one record."""
    return bio.read(unpack_from('<Q', bio.read(8))[0])


def skip_sized(bio: BinaryIO) -> None:
    """Skip one record."""
    bio.seek(unpack_from('<Q', bio.read(8))[0], 1)


def read_bytes(bio: BinaryIO) -> bytes:
    """Read bytes."""
    return bio.read(unpack_from('<I', bio.read(4))[0])


def read_string(bio: BinaryIO) -> str:
    """Read string."""
    return bio.read(unpack_from('<I', bio.read(4))[0]).decode()


DECOMPRESSORS: dict[str, Callable[[bytes, int], bytes]] = {
    '': lambda x, _: x,
    'lz4': lambda x, _: lz4_decompress(x), # type: ignore
    'zstd': zstandard.ZstdDecompressor().decompress,
}


def msgsrc(
    chunk: ChunkInfo,
    channel_map: dict[int, Connection],
    start: int,
    stop: int,
    bio: BinaryIO,
) -> Generator[Msg, None, None]:
    """Yield messages from chunk in time order."""
    yield Msg(chunk.message_start_time, 0, None, None)

    bio.seek(chunk.chunk_start_offset + 9 + 40 + len(chunk.compression))
    compressed_data = bio.read(chunk.compressed_size)
    subio = BytesIO(DECOMPRESSORS[chunk.compression](compressed_data, chunk.uncompressed_size))

    messages = []
    while (offset := subio.tell()) < chunk.uncompressed_size:
        op_ = ord(subio.read(1))
        if op_ == 0x05:
            recio = BytesIO(read_sized(subio))
            channel_id, _, log_time, _ = unpack_from(
                '<HIQQ',
                recio.read(22),
            )
            if start <= log_time < stop and channel_id in channel_map:
                messages.append(
                    Msg(
                        log_time,
                        chunk.chunk_start_offset + offset,
                        channel_map[channel_id],
                        recio.read(),
                    ),
                )
        else:
            skip_sized(subio)

    yield from sorted(messages, key=lambda x: x.timestamp)


class MCAPFile:
    """Mcap format reader."""

    # pylint: disable=too-many-instance-attributes

    def __init__(self, path: Path):
        """Initialize."""
        self.path = path
        self.bio: Optional[BinaryIO] = None
        self.data_start = 0
        self.data_end = 0
        self.schemas: dict[int, Schema] = {}
        self.channels: dict[int, Channel] = {}
        self.chunks: list[ChunkInfo] = []
        self.statistics: Optional[Statistics] = None

    def open(self) -> None:
        """Open MCAP."""
        try:
            self.bio = self.path.open('rb')
        except OSError as err:
            raise ReaderError(f'Could not open file {str(self.path)!r}: {err.strerror}.') from err

        magic = self.bio.read(8)
        if not magic:
            raise ReaderError(f'File {str(self.path)!r} seems to be empty.')

        if magic != b'\x89MCAP0\r\n':
            raise ReaderError('File magic is invalid.')

        op_ = ord(self.bio.read(1))
        if op_ != 0x01:
            raise ReaderError('Unexpected record.')

        recio = BytesIO(read_sized(self.bio))
        profile = read_string(recio)
        if profile != 'ros2':
            raise ReaderError('Profile is not ros2.')
        self.data_start = self.bio.tell()

        self.bio.seek(-37, 2)
        footer_start = self.bio.tell()
        data = self.bio.read()
        magic = data[-8:]
        if magic != b'\x89MCAP0\r\n':
            raise ReaderError('File end magic is invalid.')

        assert len(data) == 37
        assert data[0:9] == b'\x02\x14\x00\x00\x00\x00\x00\x00\x00', data[0:9]

        summary_start, = unpack_from('<Q', data, 9)
        if summary_start:
            self.data_end = summary_start
            self.read_index()
        else:
            self.data_end = footer_start

    def read_index(self) -> None:
        """Read index from file."""
        bio = self.bio
        assert bio

        schemas = self.schemas
        channels = self.channels
        chunks = self.chunks

        bio.seek(self.data_end)
        while True:
            op_ = ord(bio.read(1))

            if op_ in (0x02, 0x0e):
                break

            if op_ == 0x03:
                bio.seek(8, 1)
                key, = unpack_from('<H', bio.read(2))
                schemas[key] = Schema(
                    key,
                    read_string(bio),
                    read_string(bio),
                    read_string(bio),
                )

            elif op_ == 0x04:
                bio.seek(8, 1)
                key, = unpack_from('<H', bio.read(2))
                schema_name = schemas[unpack_from('<H', bio.read(2))[0]].name
                channels[key] = Channel(
                    key,
                    schema_name,
                    read_string(bio),
                    read_string(bio),
                    read_bytes(bio),
                )

            elif op_ == 0x08:
                bio.seek(8, 1)
                chunk = ChunkInfo( # type: ignore
                    *unpack_from('<QQQQ', bio.read(32), 0),
                    {
                        x[0]: x[1] for x in
                        iter_unpack('<HQ', bio.read(unpack_from('<I', bio.read(4))[0]))
                    },
                    *unpack_from('<Q', bio.read(8), 0),
                    read_string(bio),
                    *unpack_from('<QQ', bio.read(16), 0),
                    {},
                )
                offset_channel = sorted((v, k) for k, v in chunk.message_index_offsets.items())
                offsets = [
                    *[x[0] for x in offset_channel],
                    chunk.chunk_start_offset + chunk.chunk_length + chunk.message_index_length,
                ]
                chunk.channel_count.update(
                    {
                        x[1]: count // 16
                        for x, y, z in zip(offset_channel, offsets[1:], offsets)
                        if (count := y - z - 15)
                    },
                )
                chunks.append(chunk)

            elif op_ == 0x0a:
                skip_sized(bio)

            elif op_ == 0x0b:
                bio.seek(8, 1)
                self.statistics = Statistics(
                    *unpack_from(
                        '<QHIIIIQQ',
                        bio.read(42),
                        0,
                    ),
                    read_bytes(bio), # type: ignore
                )

            elif op_ == 0x0d:
                skip_sized(bio)

            else:
                skip_sized(bio)

    def close(self) -> None:
        """Close MCAP."""
        assert self.bio
        self.bio.close()
        self.bio = None

    def meta_scan(self) -> None:
        """Generate metadata by scanning through file."""
        assert self.bio
        bio = self.bio
        bio_size = self.data_end
        bio.seek(self.data_start)

        schemas = self.schemas
        channels = self.channels

        while bio.tell() < bio_size:
            op_ = ord(bio.read(1))

            if op_ == 0x03:
                bio.seek(8, 1)
                key, = unpack_from('<H', bio.read(2))
                schemas[key] = Schema(
                    key,
                    read_string(bio),
                    read_string(bio),
                    read_string(bio),
                )
            elif op_ == 0x04:
                bio.seek(8, 1)
                key, = unpack_from('<H', bio.read(2))
                schema_name = schemas[unpack_from('<H', bio.read(2))[0]].name
                channels[key] = Channel(
                    key,
                    schema_name,
                    read_string(bio),
                    read_string(bio),
                    read_bytes(bio),
                )
            elif op_ == 0x06:
                bio.seek(8, 1)
                _, _, uncompressed_size, _ = unpack_from('<QQQI', bio.read(28))
                compression = read_string(bio)
                compressed_size, = unpack_from('<Q', bio.read(8))
                bio = BytesIO(
                    DECOMPRESSORS[compression](bio.read(compressed_size), uncompressed_size),
                )
                bio_size = uncompressed_size
            else:
                skip_sized(bio)

            if bio.tell() == bio_size and bio != self.bio:
                bio = self.bio
                bio_size = self.data_end

    def get_schema_definitions(self) -> dict[str, tuple[str, str]]:
        """Get schema definition."""
        if not self.schemas:
            self.meta_scan()
        return {schema.name: (schema.encoding[4:], schema.data) for schema in self.schemas.values()}

    def messages_scan(
        self,
        connections: Iterable[Connection],
        start: Optional[int] = None,
        stop: Optional[int] = None,
    ) -> Generator[tuple[Connection, int, bytes], None, None]:
        """Read messages by scanning whole bag."""
        # pylint: disable=too-many-locals
        assert self.bio
        bio = self.bio
        bio_size = self.data_end
        bio.seek(self.data_start)

        schemas = self.schemas.copy()
        channels = self.channels.copy()

        if channels:
            read_meta = False
            channel_map = {
                cid: conn for conn in connections if (
                    cid := next(
                        (
                            cid for cid, x in self.channels.items()
                            if x.schema == conn.msgtype and x.topic == conn.topic
                        ),
                        None,
                    )
                )
            }
        else:
            read_meta = True
            channel_map = {}

        if start is None:
            start = 0
        if stop is None:
            stop = 2**63 - 1

        while bio.tell() < bio_size:
            op_ = ord(bio.read(1))

            if op_ == 0x03 and read_meta:
                bio.seek(8, 1)
                key, = unpack_from('<H', bio.read(2))
                schemas[key] = Schema(
                    key,
                    read_string(bio),
                    read_string(bio),
                    read_string(bio),
                )
            elif op_ == 0x04 and read_meta:
                bio.seek(8, 1)
                key, = unpack_from('<H', bio.read(2))
                schema_name = schemas[unpack_from('<H', bio.read(2))[0]].name
                channels[key] = Channel(
                    key,
                    schema_name,
                    read_string(bio),
                    read_string(bio),
                    read_bytes(bio),
                )
                conn = next(
                    (
                        x for x in connections
                        if x.topic == channels[key].topic and x.msgtype == schema_name
                    ),
                    None,
                )
                if conn:
                    channel_map[key] = conn
            elif op_ == 0x05:
                size, channel_id, _, timestamp, _ = unpack_from('<QHIQQ', bio.read(30))
                data = bio.read(size - 22)
                if start <= timestamp < stop and channel_id in channel_map:
                    yield channel_map[channel_id], timestamp, data
            elif op_ == 0x06:
                size, = unpack_from('<Q', bio.read(8))
                start_time, end_time, uncompressed_size, _ = unpack_from('<QQQI', bio.read(28))
                if read_meta or (start < end_time and start_time < stop):
                    compression = read_string(bio)
                    compressed_size, = unpack_from('<Q', bio.read(8))
                    bio = BytesIO(
                        DECOMPRESSORS[compression](bio.read(compressed_size), uncompressed_size),
                    )
                    bio_size = uncompressed_size
                else:
                    bio.seek(size - 28, 1)
            else:
                skip_sized(bio)

            if bio.tell() == bio_size and bio != self.bio:
                bio = self.bio
                bio_size = self.data_end

    def messages(
        self,
        connections: Iterable[Connection],
        start: Optional[int] = None,
        stop: Optional[int] = None,
    ) -> Generator[tuple[Connection, int, bytes], None, None]:
        """Read messages from bag.

        Args:
            connections: Iterable with connections to filter for.
            start: Yield only messages at or after this timestamp (ns).
            stop: Yield only messages before this timestamp (ns).

        Yields:
            tuples of connection, timestamp (ns), and rawdata.

        """
        assert self.bio

        if not self.chunks:
            yield from self.messages_scan(connections, start, stop)
            return

        channel_map = {
            cid: conn for conn in connections if (
                cid := next(
                    (
                        cid for cid, x in self.channels.items()
                        if x.schema == conn.msgtype and x.topic == conn.topic
                    ),
                    None,
                )
            )
        }

        chunks = [
            msgsrc(
                x,
                channel_map,
                start or x.message_start_time,
                stop or x.message_end_time + 1,
                self.bio,
            )
            for x in self.chunks
            if x.message_start_time != 0 and (start is None or start < x.message_end_time) and
            (stop is None or x.message_start_time < stop) and
            (any(x.channel_count.get(cid, 0) for cid in channel_map))
        ]

        for timestamp, offset, connection, data in heapq.merge(*chunks):
            if not offset:
                continue
            assert connection
            assert data
            yield connection, timestamp, data


class ReaderMcap:
    """Mcap storage reader."""

    def __init__(
        self,
        paths: Iterable[Path],
        connections: Iterable[Connection],
    ):
        """Set up storage reader.

        Args:
            paths: Paths of storage files.
            connections: List of connections.

        """
        self.paths = paths
        self.readers: list[MCAPFile] = []
        self.connections = connections

    def open(self) -> None:
        """Open rosbag2."""
        self.readers = [MCAPFile(x) for x in self.paths]
        for reader in self.readers:
            reader.open()

    def close(self) -> None:
        """Close rosbag2."""
        assert self.readers
        for reader in self.readers:
            reader.close()
        self.readers = []

    def get_definitions(self) -> dict[str, tuple[str, str]]:
        """Get message definitions."""
        res = {}
        for reader in self.readers:
            res.update(reader.get_schema_definitions())
        return res

    def messages(
        self,
        connections: Iterable[Connection] = (),
        start: Optional[int] = None,
        stop: Optional[int] = None,
    ) -> Generator[tuple[Connection, int, bytes], None, None]:
        """Read messages from bag.

        Args:
            connections: Iterable with connections to filter for. An empty
                iterable disables filtering on connections.
            start: Yield only messages at or after this timestamp (ns).
            stop: Yield only messages before this timestamp (ns).

        Yields:
            tuples of connection, timestamp (ns), and rawdata.

        """
        connections = list(connections) or list(self.connections)

        for reader in self.readers:
            yield from reader.messages(connections, start, stop)
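Both the reader above and the test fixture below lean on the same low-level MCAP framing: a one-byte opcode, a little-endian uint64 payload length, then the payload, with strings length-prefixed by a uint32. A small self-contained sketch of that framing (independent of rosbags; the real header record additionally carries a library string after the profile):

    from io import BytesIO
    from struct import pack, unpack_from

    # Build a record with opcode 0x01 whose payload is one length-prefixed
    # string, mirroring read_sized()/read_string() above.
    profile = b'ros2'
    payload = pack('<I', len(profile)) + profile
    record = bytes([0x01]) + pack('<Q', len(payload)) + payload

    bio = BytesIO(record)
    opcode = ord(bio.read(1))                 # 0x01
    size, = unpack_from('<Q', bio.read(8))    # payload length in bytes (8 here)
    strlen, = unpack_from('<I', bio.read(4))  # string length
    assert opcode == 0x01 and size == 8 and bio.read(strlen) == b'ros2'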
@@ -5,10 +5,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest

from rosbags.highlevel import AnyReader, AnyReaderError
from rosbags.interfaces import Connection
from rosbags.rosbag1 import Writer as Writer1
from rosbags.rosbag2 import Writer as Writer2

@@ -200,3 +202,61 @@ def test_anyreader2(bags2: list[Path]) -> None:  # pylint: disable=redefined-out
    assert nxt[0].topic == '/topic1'
    with pytest.raises(StopIteration):
        next(gen)


def test_anyreader2_autoregister(bags2: list[Path]) -> None:  # pylint: disable=redefined-outer-name
    """Test AnyReader on rosbag2."""

    class MockReader:
        """Mock reader."""

        # pylint: disable=too-few-public-methods

        def __init__(self, paths: list[Path]):
            """Initialize mock."""
            _ = paths
            self.metadata = {'storage_identifier': 'mcap'}
            self.connections = [
                Connection(
                    1,
                    '/foo',
                    'test_msg/msg/Foo',
                    'string foo',
                    'msg',
                    0,
                    None, # type: ignore
                    self,
                ),
                Connection(
                    2,
                    '/bar',
                    'test_msg/msg/Bar',
                    'module test_msgs { module msg { struct Bar {string bar;}; }; };',
                    'idl',
                    0,
                    None, # type: ignore
                    self,
                ),
                Connection(
                    3,
                    '/baz',
                    'test_msg/msg/Baz',
                    '',
                    '',
                    0,
                    None, # type: ignore
                    self,
                ),
            ]

        def open(self) -> None:
            """Unused."""

    with patch('rosbags.highlevel.anyreader.Reader2', MockReader), \
            patch('rosbags.highlevel.anyreader.register_types') as mock_register_types:
        AnyReader([bags2[0]]).open()
        mock_register_types.assert_called_once()
        assert mock_register_types.call_args[0][0] == {
            'test_msg/msg/Foo': ([], [('foo', (1, 'string'))]),
            'test_msgs/msg/Bar': ([], [('bar', (1, 'string'))]),
        }

@@ -7,6 +7,9 @@
from __future__ import annotations

import sqlite3
import struct
from io import BytesIO
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
@@ -19,6 +22,8 @@ from rosbags.rosbag2 import Reader, ReaderError, Writer
from .test_serde import MSG_JOINT, MSG_MAGN, MSG_MAGN_BIG, MSG_POLY

if TYPE_CHECKING:
    from typing import BinaryIO, Iterable

    from _pytest.fixtures import SubRequest

METADATA = """
@@ -320,3 +325,427 @@ def test_failure_cases(tmp_path: Path) -> None:
    with pytest.raises(ReaderError, match='not open database'), \
            Reader(tmp_path) as reader:
        next(reader.messages())


def write_record(bio: BinaryIO, opcode: int, records: Iterable[bytes]) -> None:
    """Write record."""
    data = b''.join(records)
    bio.write(bytes([opcode]) + struct.pack('<Q', len(data)) + data)


def make_string(text: str) -> bytes:
    """Serialize string."""
    data = text.encode()
    return struct.pack('<I', len(data)) + data


MCAP_HEADER = b'\x89MCAP0\r\n'

SCHEMAS = [
    (
        0x03,
        (
            struct.pack('<H', 1),
            make_string('geometry_msgs/msg/Polygon'),
            make_string('ros2msg'),
            make_string('string foo'),
        ),
    ),
    (
        0x03,
        (
            struct.pack('<H', 2),
            make_string('sensor_msgs/msg/MagneticField'),
            make_string('ros2msg'),
            make_string('string foo'),
        ),
    ),
    (
        0x03,
        (
            struct.pack('<H', 3),
            make_string('trajectory_msgs/msg/JointTrajectory'),
            make_string('ros2msg'),
            make_string('string foo'),
        ),
    ),
]

CHANNELS = [
    (
        0x04,
        (
            struct.pack('<H', 1),
            struct.pack('<H', 1),
            make_string('/poly'),
            make_string('cdr'),
            make_string(''),
        ),
    ),
    (
        0x04,
        (
            struct.pack('<H', 2),
            struct.pack('<H', 2),
            make_string('/magn'),
            make_string('cdr'),
            make_string(''),
        ),
    ),
    (
        0x04,
        (
            struct.pack('<H', 3),
            struct.pack('<H', 3),
            make_string('/joint'),
            make_string('cdr'),
            make_string(''),
        ),
    ),
]


@pytest.fixture(
    params=['unindexed', 'partially_indexed', 'indexed', 'chunked_unindexed', 'chunked_indexed'],
)
def bag_mcap(request: SubRequest, tmp_path: Path) -> Path:
    """Manually construct mcap bag."""
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-statements
    (tmp_path / 'metadata.yaml').write_text(
        METADATA.format(
            extension='.mcap',
            compression_format='""',
            compression_mode='""',
        ).replace('sqlite3', 'mcap'),
    )

    path = tmp_path / 'db.db3.mcap'
    bio: BinaryIO
    messages: list[tuple[int, int, int]] = []
    chunks = []
    with path.open('wb') as bio:
        realbio = bio
        bio.write(MCAP_HEADER)
        write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap')))

        if request.param.startswith('chunked'):
            bio = BytesIO()
            messages = []

        write_record(bio, *SCHEMAS[0])
        write_record(bio, *CHANNELS[0])
        messages.append((1, 666, bio.tell()))
        write_record(
            bio,
            0x05,
            (
                struct.pack('<H', 1),
                struct.pack('<I', 1),
                struct.pack('<Q', 666),
                struct.pack('<Q', 666),
                MSG_POLY[0],
            ),
        )

        if request.param.startswith('chunked'):
            assert isinstance(bio, BytesIO)
            chunk_start = realbio.tell()
            compression = make_string('')
            uncompressed_size = struct.pack('<Q', len(bio.getbuffer()))
            compressed_size = struct.pack('<Q', len(bio.getbuffer()))
            write_record(
                realbio,
                0x06,
                (
                    struct.pack('<Q', 666),
                    struct.pack('<Q', 666),
                    uncompressed_size,
                    struct.pack('<I', 0),
                    compression,
                    compressed_size,
                    bio.getbuffer(),
                ),
            )
            message_index_offsets = []
            message_index_start = realbio.tell()
            for channel_id, group in groupby(messages, key=lambda x: x[0]):
                message_index_offsets.append((channel_id, realbio.tell()))
                tpls = [y for x in group for y in x[1:]]
                write_record(
                    realbio,
                    0x07,
                    (
                        struct.pack('<H', channel_id),
                        struct.pack('<I', 8 * len(tpls)),
                        struct.pack('<' + 'Q' * len(tpls), *tpls),
                    ),
                )
            chunk = [
                struct.pack('<Q', 666),
                struct.pack('<Q', 666),
                struct.pack('<Q', chunk_start),
                struct.pack('<Q', message_index_start - chunk_start),
                struct.pack('<I', 10 * len(message_index_offsets)),
                *(struct.pack('<HQ', *x) for x in message_index_offsets),
                struct.pack('<Q',
                            realbio.tell() - message_index_start),
                compression,
                compressed_size,
                uncompressed_size,
            ]
            chunks.append(chunk)
            bio = BytesIO()
            messages = []

        write_record(bio, *SCHEMAS[1])
        write_record(bio, *CHANNELS[1])
        messages.append((2, 708, bio.tell()))
        write_record(
            bio,
            0x05,
            (
                struct.pack('<H', 2),
                struct.pack('<I', 1),
                struct.pack('<Q', 708),
                struct.pack('<Q', 708),
                MSG_MAGN[0],
            ),
        )
        messages.append((2, 708, bio.tell()))
        write_record(
            bio,
            0x05,
            (
                struct.pack('<H', 2),
                struct.pack('<I', 2),
                struct.pack('<Q', 708),
                struct.pack('<Q', 708),
                MSG_MAGN_BIG[0],
            ),
        )

        write_record(bio, *SCHEMAS[2])
        write_record(bio, *CHANNELS[2])
        messages.append((3, 708, bio.tell()))
        write_record(
            bio,
            0x05,
            (
                struct.pack('<H', 3),
                struct.pack('<I', 1),
                struct.pack('<Q', 708),
                struct.pack('<Q', 708),
                MSG_JOINT[0],
            ),
        )

        if request.param.startswith('chunked'):
            assert isinstance(bio, BytesIO)
            chunk_start = realbio.tell()
            compression = make_string('')
            uncompressed_size = struct.pack('<Q', len(bio.getbuffer()))
            compressed_size = struct.pack('<Q', len(bio.getbuffer()))
            write_record(
                realbio,
                0x06,
                (
                    struct.pack('<Q', 708),
                    struct.pack('<Q', 708),
                    uncompressed_size,
                    struct.pack('<I', 0),
                    compression,
                    compressed_size,
                    bio.getbuffer(),
                ),
            )
            message_index_offsets = []
            message_index_start = realbio.tell()
            for channel_id, group in groupby(messages, key=lambda x: x[0]):
                message_index_offsets.append((channel_id, realbio.tell()))
                tpls = [y for x in group for y in x[1:]]
                write_record(
                    realbio,
                    0x07,
                    (
                        struct.pack('<H', channel_id),
                        struct.pack('<I', 8 * len(tpls)),
                        struct.pack('<' + 'Q' * len(tpls), *tpls),
                    ),
                )
            chunk = [
                struct.pack('<Q', 708),
                struct.pack('<Q', 708),
                struct.pack('<Q', chunk_start),
                struct.pack('<Q', message_index_start - chunk_start),
                struct.pack('<I', 10 * len(message_index_offsets)),
                *(struct.pack('<HQ', *x) for x in message_index_offsets),
                struct.pack('<Q',
                            realbio.tell() - message_index_start),
                compression,
                compressed_size,
                uncompressed_size,
            ]
            chunks.append(chunk)
            bio = realbio
            messages = []

        if request.param in ['indexed', 'partially_indexed', 'chunked_indexed']:
            summary_start = bio.tell()
            for schema in SCHEMAS:
                write_record(bio, *schema)
            if request.param != 'partially_indexed':
                for channel in CHANNELS:
                    write_record(bio, *channel)
            if request.param == 'chunked_indexed':
                for chunk in chunks:
                    write_record(bio, 0x08, chunk)

            summary_offset_start = 0
            write_record(bio, 0x0a, (b'ignored',))
            write_record(
                bio,
                0x0b,
                (
                    struct.pack('<Q', 4),
                    struct.pack('<H', 3),
                    struct.pack('<I', 3),
                    struct.pack('<I', 0),
                    struct.pack('<I', 0),
                    struct.pack('<I', 0 if request.param == 'indexed' else 1),
                    struct.pack('<Q', 666),
                    struct.pack('<Q', 708),
                    struct.pack('<I', 0),
                ),
            )
            write_record(bio, 0x0d, (b'ignored',))
            write_record(bio, 0xff, (b'ignored',))
        else:
            summary_start = 0
            summary_offset_start = 0

        write_record(
            bio,
            0x02,
            (
                struct.pack('<Q', summary_start),
                struct.pack('<Q', summary_offset_start),
                struct.pack('<I', 0),
            ),
        )
        bio.write(MCAP_HEADER)

    return tmp_path


def test_reader_mcap(bag_mcap: Path) -> None:
    """Test reader and deserializer on simple bag."""
    with Reader(bag_mcap) as reader:
        assert reader.duration == 43
        assert reader.start_time == 666
        assert reader.end_time == 709
        assert reader.message_count == 4
        if reader.compression_mode:
            assert reader.compression_format == 'zstd'
        assert [x.id for x in reader.connections] == [1, 2, 3]
        assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
        gen = reader.messages()

        connection, timestamp, rawdata = next(gen)
        assert connection.topic == '/poly'
        assert connection.msgtype == 'geometry_msgs/msg/Polygon'
        assert timestamp == 666
        assert rawdata == MSG_POLY[0]

        for idx in range(2):
            connection, timestamp, rawdata = next(gen)
            assert connection.topic == '/magn'
            assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
            assert timestamp == 708
            assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]

        connection, timestamp, rawdata = next(gen)
        assert connection.topic == '/joint'
        assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'

        with pytest.raises(StopIteration):
            next(gen)


def test_message_filters_mcap(bag_mcap: Path) -> None:
    """Test reader filters messages."""
    with Reader(bag_mcap) as reader:
        magn_connections = [x for x in reader.connections if x.topic == '/magn']
        gen = reader.messages(connections=magn_connections)
        connection, _, _ = next(gen)
        assert connection.topic == '/magn'
        connection, _, _ = next(gen)
        assert connection.topic == '/magn'
        with pytest.raises(StopIteration):
            next(gen)

        gen = reader.messages(start=667)
        connection, _, _ = next(gen)
        assert connection.topic == '/magn'
        connection, _, _ = next(gen)
        assert connection.topic == '/magn'
        connection, _, _ = next(gen)
        assert connection.topic == '/joint'
        with pytest.raises(StopIteration):
            next(gen)

        gen = reader.messages(stop=667)
        connection, _, _ = next(gen)
        assert connection.topic == '/poly'
        with pytest.raises(StopIteration):
            next(gen)

        gen = reader.messages(connections=magn_connections, stop=667)
        with pytest.raises(StopIteration):
            next(gen)

        gen = reader.messages(start=666, stop=666)
        with pytest.raises(StopIteration):
            next(gen)


def test_bag_mcap_files(tmp_path: Path) -> None:
    """Test bad mcap files."""
    (tmp_path / 'metadata.yaml').write_text(
        METADATA.format(
            extension='.mcap',
            compression_format='""',
            compression_mode='""',
        ).replace('sqlite3', 'mcap'),
    )

    path = tmp_path / 'db.db3.mcap'
    path.touch()
    reader = Reader(tmp_path)
    path.unlink()
    with pytest.raises(ReaderError, match='Could not open'):
        reader.open()

    path.touch()
    with pytest.raises(ReaderError, match='seems to be empty'):
        Reader(tmp_path).open()

    path.write_bytes(b'xxxxxxxx')
    with pytest.raises(ReaderError, match='magic is invalid'):
        Reader(tmp_path).open()

    path.write_bytes(b'\x89MCAP0\r\n\xFF')
    with pytest.raises(ReaderError, match='Unexpected record'):
        Reader(tmp_path).open()

    with path.open('wb') as bio:
        bio.write(b'\x89MCAP0\r\n')
        write_record(bio, 0x01, (make_string('ros1'), make_string('test_mcap')))
    with pytest.raises(ReaderError, match='Profile is not'):
        Reader(tmp_path).open()

    with path.open('wb') as bio:
        bio.write(b'\x89MCAP0\r\n')
        write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap')))
    with pytest.raises(ReaderError, match='File end magic is invalid'):
        Reader(tmp_path).open()