From f7d69e35d5833eda608058e75f253ff321631cd4 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Wed, 13 Apr 2022 13:12:18 +0200 Subject: [PATCH] Add all-in-one reader --- docs/api/rosbags.highlevel.rst | 6 + docs/api/rosbags.rst | 1 + src/rosbags/highlevel/__init__.py | 10 ++ src/rosbags/highlevel/anyreader.py | 258 +++++++++++++++++++++++++++++ src/rosbags/highlevel/py.typed | 0 tests/test_highlevel.py | 202 ++++++++++++++++++++++ 6 files changed, 477 insertions(+) create mode 100644 docs/api/rosbags.highlevel.rst create mode 100644 src/rosbags/highlevel/__init__.py create mode 100644 src/rosbags/highlevel/anyreader.py create mode 100644 src/rosbags/highlevel/py.typed create mode 100644 tests/test_highlevel.py diff --git a/docs/api/rosbags.highlevel.rst b/docs/api/rosbags.highlevel.rst new file mode 100644 index 00000000..66c7ba3f --- /dev/null +++ b/docs/api/rosbags.highlevel.rst @@ -0,0 +1,6 @@ +rosbags.highlevel +================= + +.. automodule:: rosbags.highlevel + :members: + :show-inheritance: diff --git a/docs/api/rosbags.rst b/docs/api/rosbags.rst index e8761f70..cedf169a 100644 --- a/docs/api/rosbags.rst +++ b/docs/api/rosbags.rst @@ -5,6 +5,7 @@ Rosbags namespace :maxdepth: 4 rosbags.convert + rosbags.highlevel rosbags.rosbag1 rosbags.rosbag2 rosbags.serde diff --git a/src/rosbags/highlevel/__init__.py b/src/rosbags/highlevel/__init__.py new file mode 100644 index 00000000..3ce7382d --- /dev/null +++ b/src/rosbags/highlevel/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020-2022 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Highlevel interfaces for rosbags.""" + +from .anyreader import AnyReader, AnyReaderError + +__all__ = [ + 'AnyReader', + 'AnyReaderError', +] diff --git a/src/rosbags/highlevel/anyreader.py b/src/rosbags/highlevel/anyreader.py new file mode 100644 index 00000000..1a635a17 --- /dev/null +++ b/src/rosbags/highlevel/anyreader.py @@ -0,0 +1,258 @@ +# Copyright 2020-2022 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Tools for reading all rosbag versions with unified api.""" + +from __future__ import annotations + +from contextlib import suppress +from dataclasses import dataclass +from heapq import merge +from itertools import groupby +from typing import TYPE_CHECKING + +from rosbags.interfaces import TopicInfo +from rosbags.rosbag1 import Reader as Reader1 +from rosbags.rosbag1 import ReaderError as ReaderError1 +from rosbags.rosbag2 import Reader as Reader2 +from rosbags.rosbag2 import ReaderError as ReaderError2 +from rosbags.serde import deserialize_cdr, ros1_to_cdr +from rosbags.typesys import get_types_from_msg, register_types, types + +if TYPE_CHECKING: + import sys + from pathlib import Path + from types import TracebackType + from typing import Any, Generator, Iterable, Literal, Optional, Sequence, Type, Union + + from rosbags.interfaces import Connection + from rosbags.typesys.base import Typesdict + from rosbags.typesys.register import Typestore + + if sys.version_info < (3, 10): + from typing_extensions import TypeGuard + else: + from typing import TypeGuard + + +class AnyReaderError(Exception): + """Reader error.""" + + +ReaderErrors = (ReaderError1, ReaderError2) + + +def is_reader1(val: Union[Sequence[Reader1], Sequence[Reader2]]) -> TypeGuard[Sequence[Reader1]]: + """Determine wether all items are Reader1 instances.""" + return all(isinstance(x, Reader1) for x in val) + + +@dataclass +class SimpleTypeStore: + """Simple type store implementation.""" + + FIELDDEFS: Typesdict # pylint: disable=invalid-name + + def __hash__(self) -> int: + """Create hash.""" + return id(self) + + +class AnyReader: + """Unified rosbag1 and rosbag2 reader.""" + + readers: Union[Sequence[Reader1], Sequence[Reader2]] + typestore: Typestore + + def __init__(self, paths: Sequence[Path]): + """Initialize RosbagReader. + + Opens one or multiple rosbag1 recordings or a single rosbag2 recording. + + Args: + paths: Paths to multiple rosbag1 files or single rosbag2 directory. + + Raises: + AnyReaderError: If paths do not exist or multiple rosbag2 files are given. + + """ + if not paths: + raise AnyReaderError('Must call with at least one path.') + + if len(paths) > 1 and any((x / 'metadata.yaml').exists() for x in paths): + raise AnyReaderError('Opening of multiple rosbag2 recordings is not supported.') + + if missing := [x for x in paths if not x.exists()]: + raise AnyReaderError(f'The following paths are missing: {missing!r}') + + self.paths = paths + self.is2 = (paths[0] / 'metadata.yaml').exists() + self.isopen = False + self.connections: list[Connection] = [] + + try: + if self.is2: + self.readers = [Reader2(x) for x in paths] + else: + self.readers = [Reader1(x) for x in paths] + except ReaderErrors as err: + raise AnyReaderError(*err.args) from err + + self.typestore = SimpleTypeStore({}) + + def _deser_ros1(self, rawdata: bytes, typ: str) -> object: + """Deserialize ROS1 message.""" + return deserialize_cdr(ros1_to_cdr(rawdata, typ, self.typestore), typ, self.typestore) + + def _deser_ros2(self, rawdata: bytes, typ: str) -> object: + """Deserialize CDR message.""" + return deserialize_cdr(rawdata, typ, self.typestore) + + def deserialize(self, rawdata: bytes, typ: str) -> object: + """Deserialize message with appropriate helper.""" + return self._deser_ros2(rawdata, typ) if self.is2 else self._deser_ros1(rawdata, typ) + + def open(self) -> None: + """Open rosbags.""" + assert not self.isopen + rollback = [] + try: + for reader in self.readers: + reader.open() + rollback.append(reader) + except ReaderErrors as err: + for reader in rollback: + with suppress(*ReaderErrors): + reader.close() + raise AnyReaderError(*err.args) from err + + if self.is2: + for key, value in types.FIELDDEFS.items(): + self.typestore.FIELDDEFS[key] = value + attr = key.replace('/', '__') + setattr(self.typestore, attr, getattr(types, attr)) + else: + for key in [ + 'builtin_interfaces/msg/Time', + 'builtin_interfaces/msg/Duration', + 'std_msgs/msg/Header', + ]: + self.typestore.FIELDDEFS[key] = types.FIELDDEFS[key] + attr = key.replace('/', '__') + setattr(self.typestore, attr, getattr(types, attr)) + + typs: dict[str, Any] = {} + for reader in self.readers: + assert isinstance(reader, Reader1) + for connection in reader.connections.values(): + typs.update(get_types_from_msg(connection.msgdef, connection.msgtype)) + register_types(typs, self.typestore) + + self.connections = [y for x in self.readers for y in x.connections.values()] + self.isopen = True + + def close(self) -> None: + """Close rosbag.""" + assert self.isopen + for reader in self.readers: + with suppress(*ReaderErrors): + reader.close() + self.isopen = False + + def __enter__(self) -> AnyReader: + """Open rosbags when entering contextmanager.""" + self.open() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: + """Close rosbags when exiting contextmanager.""" + self.close() + return False + + @property + def duration(self) -> int: + """Duration in nanoseconds between earliest and latest messages.""" + return self.end_time - self.start_time + + @property + def start_time(self) -> int: + """Timestamp in nanoseconds of the earliest message.""" + return min(x.start_time for x in self.readers) + + @property + def end_time(self) -> int: + """Timestamp in nanoseconds after the latest message.""" + return max(x.end_time for x in self.readers) + + @property + def message_count(self) -> int: + """Total message count.""" + return sum(x.message_count for x in self.readers) + + @property + def topics(self) -> dict[str, TopicInfo]: + """Topics stored in the rosbags.""" + assert self.isopen + + if self.is2: + assert isinstance(self.readers[0], Reader2) + return self.readers[0].topics + + assert is_reader1(self.readers) + + def summarize(names_infos: Iterable[tuple[str, TopicInfo]]) -> TopicInfo: + """Summarize topic infos.""" + infos = [x[1] for x in names_infos] + return TopicInfo( + msgtypes.pop() if len(msgtypes := {x.msgtype for x in infos}) == 1 else None, + msgdefs.pop() if len(msgdefs := {x.msgdef for x in infos}) == 1 else None, + sum(x.msgcount for x in infos), + sum((x.connections for x in infos), []), + ) + + return { + name: summarize(infos) for name, infos in groupby( + sorted( + (x for reader in self.readers for x in reader.topics.items()), + key=lambda x: x[0], + ), + key=lambda x: x[0], + ) + } + + def messages( + self, + connections: Iterable[Any] = (), + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> Generator[tuple[Any, int, bytes], None, None]: + """Read messages from bags. + + Args: + connections: Iterable with connections to filter for. An empty + iterable disables filtering on connections. + start: Yield only messages at or after this timestamp (ns). + stop: Yield only messages before this timestamp (ns). + + Yields: + Tuples of connection, timestamp (ns), and rawdata. + + """ + assert self.isopen + + def get_owner(connection: Connection) -> Union[Reader1, Reader2]: + assert isinstance(connection.owner, (Reader1, Reader2)) + return connection.owner + + if connections: + generators = [ + reader.messages(connections=list(conns), start=start, stop=stop) for reader, conns + in groupby(sorted(connections, key=lambda x: id(get_owner(x))), key=get_owner) + ] + else: + generators = [reader.messages(start=start, stop=stop) for reader in self.readers] + yield from merge(*generators, key=lambda x: x[1]) diff --git a/src/rosbags/highlevel/py.typed b/src/rosbags/highlevel/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py new file mode 100644 index 00000000..17fcaa7c --- /dev/null +++ b/tests/test_highlevel.py @@ -0,0 +1,202 @@ +# Copyright 2020-2022 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Reader tests.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from rosbags.highlevel import AnyReader, AnyReaderError +from rosbags.rosbag1 import Writer as Writer1 +from rosbags.rosbag2 import Writer as Writer2 + +if TYPE_CHECKING: + from pathlib import Path + from typing import Sequence + +HEADER = b'\x00\x01\x00\x00' + + +@pytest.fixture() +def bags1(tmp_path: Path) -> list[Path]: + """Test data fixture.""" + paths = [ + tmp_path / 'ros1_1.bag', + tmp_path / 'ros1_2.bag', + tmp_path / 'ros1_3.bag', + tmp_path / 'bad.bag', + ] + with (Writer1(paths[0])) as writer: + topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8') + topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16') + writer.write(topic1, 1, b'\x01') + writer.write(topic2, 2, b'\x02\x00') + writer.write(topic1, 9, b'\x09') + with (Writer1(paths[1])) as writer: + topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8') + writer.write(topic1, 5, b'\x05') + with (Writer1(paths[2])) as writer: + topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16') + writer.write(topic2, 15, b'\x15\x00') + + paths[3].touch() + + return paths + + +@pytest.fixture() +def bags2(tmp_path: Path) -> list[Path]: + """Test data fixture.""" + paths = [ + tmp_path / 'ros2_1', + tmp_path / 'bad', + ] + with (Writer2(paths[0])) as writer: + topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8') + topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16') + writer.write(topic1, 1, HEADER + b'\x01') + writer.write(topic2, 2, HEADER + b'\x02\x00') + writer.write(topic1, 9, HEADER + b'\x09') + writer.write(topic1, 5, HEADER + b'\x05') + writer.write(topic2, 15, HEADER + b'\x15\x00') + + paths[1].mkdir() + (paths[1] / 'metadata.yaml').write_text(':') + + return paths + + +def test_anyreader1(bags1: Sequence[Path]) -> None: # pylint: disable=redefined-outer-name + """Test AnyReader on rosbag1.""" + # pylint: disable=too-many-statements + with pytest.raises(AnyReaderError, match='at least one'): + AnyReader([]) + + with pytest.raises(AnyReaderError, match='missing'): + AnyReader([bags1[0] / 'badname']) + + reader = AnyReader(bags1) + with pytest.raises(AssertionError): + assert reader.topics + + with pytest.raises(AssertionError): + next(reader.messages()) + + reader = AnyReader(bags1) + with pytest.raises(AnyReaderError, match='seems to be empty'): + reader.open() + assert all(not x.bio for x in reader.readers) + + with AnyReader(bags1[:3]) as reader: + assert reader.duration == 15 + assert reader.start_time == 1 + assert reader.end_time == 16 + assert reader.message_count == 5 + assert list(reader.topics.keys()) == ['/topic1', '/topic2'] + assert len(reader.topics['/topic1'].connections) == 2 + assert reader.topics['/topic1'].msgcount == 3 + assert len(reader.topics['/topic2'].connections) == 2 + assert reader.topics['/topic2'].msgcount == 2 + + gen = reader.messages() + + nxt = next(gen) + assert nxt[0].topic == '/topic1' + assert nxt[1:] == (1, b'\x01') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 1 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic2' + assert nxt[1:] == (2, b'\x02\x00') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 2 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic1' + assert nxt[1:] == (5, b'\x05') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 5 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic1' + assert nxt[1:] == (9, b'\x09') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 9 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic2' + assert nxt[1:] == (15, b'\x15\x00') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 21 # type: ignore + with pytest.raises(StopIteration): + next(gen) + + gen = reader.messages(connections=reader.topics['/topic1'].connections) + nxt = next(gen) + assert nxt[0].topic == '/topic1' + nxt = next(gen) + assert nxt[0].topic == '/topic1' + nxt = next(gen) + assert nxt[0].topic == '/topic1' + with pytest.raises(StopIteration): + next(gen) + + +def test_anyreader2(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name + """Test AnyReader on rosbag2.""" + # pylint: disable=too-many-statements + with pytest.raises(AnyReaderError, match='multiple rosbag2'): + AnyReader(bags2) + + with pytest.raises(AnyReaderError, match='YAML'): + AnyReader([bags2[1]]) + + with AnyReader([bags2[0]]) as reader: + assert reader.duration == 15 + assert reader.start_time == 1 + assert reader.end_time == 16 + assert reader.message_count == 5 + assert list(reader.topics.keys()) == ['/topic1', '/topic2'] + assert len(reader.topics['/topic1'].connections) == 1 + assert reader.topics['/topic1'].msgcount == 3 + assert len(reader.topics['/topic2'].connections) == 1 + assert reader.topics['/topic2'].msgcount == 2 + + gen = reader.messages() + + nxt = next(gen) + assert nxt[0].topic == '/topic1' + assert nxt[1:] == (1, HEADER + b'\x01') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 1 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic2' + assert nxt[1:] == (2, HEADER + b'\x02\x00') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 2 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic1' + assert nxt[1:] == (5, HEADER + b'\x05') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 5 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic1' + assert nxt[1:] == (9, HEADER + b'\x09') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 9 # type: ignore + nxt = next(gen) + assert nxt[0].topic == '/topic2' + assert nxt[1:] == (15, HEADER + b'\x15\x00') + msg = reader.deserialize(nxt[2], nxt[0].msgtype) + assert msg.data == 21 # type: ignore + with pytest.raises(StopIteration): + next(gen) + + gen = reader.messages(connections=reader.topics['/topic1'].connections) + nxt = next(gen) + assert nxt[0].topic == '/topic1' + nxt = next(gen) + assert nxt[0].topic == '/topic1' + nxt = next(gen) + assert nxt[0].topic == '/topic1' + with pytest.raises(StopIteration): + next(gen)