Add all-in-one reader

This commit is contained in:
Marko Durkovic 2022-04-13 13:12:18 +02:00
parent 657032ce9f
commit f7d69e35d5
6 changed files with 477 additions and 0 deletions

View File

@ -0,0 +1,6 @@
rosbags.highlevel
=================
.. automodule:: rosbags.highlevel
:members:
:show-inheritance:

View File

@ -5,6 +5,7 @@ Rosbags namespace
:maxdepth: 4
rosbags.convert
rosbags.highlevel
rosbags.rosbag1
rosbags.rosbag2
rosbags.serde

View File

@ -0,0 +1,10 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Highlevel interfaces for rosbags."""
from .anyreader import AnyReader, AnyReaderError
__all__ = [
'AnyReader',
'AnyReaderError',
]

View File

@ -0,0 +1,258 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Tools for reading all rosbag versions with unified api."""
from __future__ import annotations
from contextlib import suppress
from dataclasses import dataclass
from heapq import merge
from itertools import groupby
from typing import TYPE_CHECKING
from rosbags.interfaces import TopicInfo
from rosbags.rosbag1 import Reader as Reader1
from rosbags.rosbag1 import ReaderError as ReaderError1
from rosbags.rosbag2 import Reader as Reader2
from rosbags.rosbag2 import ReaderError as ReaderError2
from rosbags.serde import deserialize_cdr, ros1_to_cdr
from rosbags.typesys import get_types_from_msg, register_types, types
if TYPE_CHECKING:
import sys
from pathlib import Path
from types import TracebackType
from typing import Any, Generator, Iterable, Literal, Optional, Sequence, Type, Union
from rosbags.interfaces import Connection
from rosbags.typesys.base import Typesdict
from rosbags.typesys.register import Typestore
if sys.version_info < (3, 10):
from typing_extensions import TypeGuard
else:
from typing import TypeGuard
class AnyReaderError(Exception):
"""Reader error."""
ReaderErrors = (ReaderError1, ReaderError2)
def is_reader1(val: Union[Sequence[Reader1], Sequence[Reader2]]) -> TypeGuard[Sequence[Reader1]]:
"""Determine wether all items are Reader1 instances."""
return all(isinstance(x, Reader1) for x in val)
@dataclass
class SimpleTypeStore:
"""Simple type store implementation."""
FIELDDEFS: Typesdict # pylint: disable=invalid-name
def __hash__(self) -> int:
"""Create hash."""
return id(self)
class AnyReader:
"""Unified rosbag1 and rosbag2 reader."""
readers: Union[Sequence[Reader1], Sequence[Reader2]]
typestore: Typestore
def __init__(self, paths: Sequence[Path]):
"""Initialize RosbagReader.
Opens one or multiple rosbag1 recordings or a single rosbag2 recording.
Args:
paths: Paths to multiple rosbag1 files or single rosbag2 directory.
Raises:
AnyReaderError: If paths do not exist or multiple rosbag2 files are given.
"""
if not paths:
raise AnyReaderError('Must call with at least one path.')
if len(paths) > 1 and any((x / 'metadata.yaml').exists() for x in paths):
raise AnyReaderError('Opening of multiple rosbag2 recordings is not supported.')
if missing := [x for x in paths if not x.exists()]:
raise AnyReaderError(f'The following paths are missing: {missing!r}')
self.paths = paths
self.is2 = (paths[0] / 'metadata.yaml').exists()
self.isopen = False
self.connections: list[Connection] = []
try:
if self.is2:
self.readers = [Reader2(x) for x in paths]
else:
self.readers = [Reader1(x) for x in paths]
except ReaderErrors as err:
raise AnyReaderError(*err.args) from err
self.typestore = SimpleTypeStore({})
def _deser_ros1(self, rawdata: bytes, typ: str) -> object:
"""Deserialize ROS1 message."""
return deserialize_cdr(ros1_to_cdr(rawdata, typ, self.typestore), typ, self.typestore)
def _deser_ros2(self, rawdata: bytes, typ: str) -> object:
"""Deserialize CDR message."""
return deserialize_cdr(rawdata, typ, self.typestore)
def deserialize(self, rawdata: bytes, typ: str) -> object:
"""Deserialize message with appropriate helper."""
return self._deser_ros2(rawdata, typ) if self.is2 else self._deser_ros1(rawdata, typ)
def open(self) -> None:
"""Open rosbags."""
assert not self.isopen
rollback = []
try:
for reader in self.readers:
reader.open()
rollback.append(reader)
except ReaderErrors as err:
for reader in rollback:
with suppress(*ReaderErrors):
reader.close()
raise AnyReaderError(*err.args) from err
if self.is2:
for key, value in types.FIELDDEFS.items():
self.typestore.FIELDDEFS[key] = value
attr = key.replace('/', '__')
setattr(self.typestore, attr, getattr(types, attr))
else:
for key in [
'builtin_interfaces/msg/Time',
'builtin_interfaces/msg/Duration',
'std_msgs/msg/Header',
]:
self.typestore.FIELDDEFS[key] = types.FIELDDEFS[key]
attr = key.replace('/', '__')
setattr(self.typestore, attr, getattr(types, attr))
typs: dict[str, Any] = {}
for reader in self.readers:
assert isinstance(reader, Reader1)
for connection in reader.connections.values():
typs.update(get_types_from_msg(connection.msgdef, connection.msgtype))
register_types(typs, self.typestore)
self.connections = [y for x in self.readers for y in x.connections.values()]
self.isopen = True
def close(self) -> None:
"""Close rosbag."""
assert self.isopen
for reader in self.readers:
with suppress(*ReaderErrors):
reader.close()
self.isopen = False
def __enter__(self) -> AnyReader:
"""Open rosbags when entering contextmanager."""
self.open()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
"""Close rosbags when exiting contextmanager."""
self.close()
return False
@property
def duration(self) -> int:
"""Duration in nanoseconds between earliest and latest messages."""
return self.end_time - self.start_time
@property
def start_time(self) -> int:
"""Timestamp in nanoseconds of the earliest message."""
return min(x.start_time for x in self.readers)
@property
def end_time(self) -> int:
"""Timestamp in nanoseconds after the latest message."""
return max(x.end_time for x in self.readers)
@property
def message_count(self) -> int:
"""Total message count."""
return sum(x.message_count for x in self.readers)
@property
def topics(self) -> dict[str, TopicInfo]:
"""Topics stored in the rosbags."""
assert self.isopen
if self.is2:
assert isinstance(self.readers[0], Reader2)
return self.readers[0].topics
assert is_reader1(self.readers)
def summarize(names_infos: Iterable[tuple[str, TopicInfo]]) -> TopicInfo:
"""Summarize topic infos."""
infos = [x[1] for x in names_infos]
return TopicInfo(
msgtypes.pop() if len(msgtypes := {x.msgtype for x in infos}) == 1 else None,
msgdefs.pop() if len(msgdefs := {x.msgdef for x in infos}) == 1 else None,
sum(x.msgcount for x in infos),
sum((x.connections for x in infos), []),
)
return {
name: summarize(infos) for name, infos in groupby(
sorted(
(x for reader in self.readers for x in reader.topics.items()),
key=lambda x: x[0],
),
key=lambda x: x[0],
)
}
def messages(
self,
connections: Iterable[Any] = (),
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[tuple[Any, int, bytes], None, None]:
"""Read messages from bags.
Args:
connections: Iterable with connections to filter for. An empty
iterable disables filtering on connections.
start: Yield only messages at or after this timestamp (ns).
stop: Yield only messages before this timestamp (ns).
Yields:
Tuples of connection, timestamp (ns), and rawdata.
"""
assert self.isopen
def get_owner(connection: Connection) -> Union[Reader1, Reader2]:
assert isinstance(connection.owner, (Reader1, Reader2))
return connection.owner
if connections:
generators = [
reader.messages(connections=list(conns), start=start, stop=stop) for reader, conns
in groupby(sorted(connections, key=lambda x: id(get_owner(x))), key=get_owner)
]
else:
generators = [reader.messages(start=start, stop=stop) for reader in self.readers]
yield from merge(*generators, key=lambda x: x[1])

View File

202
tests/test_highlevel.py Normal file
View File

@ -0,0 +1,202 @@
# Copyright 2020-2022 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from rosbags.highlevel import AnyReader, AnyReaderError
from rosbags.rosbag1 import Writer as Writer1
from rosbags.rosbag2 import Writer as Writer2
if TYPE_CHECKING:
from pathlib import Path
from typing import Sequence
HEADER = b'\x00\x01\x00\x00'
@pytest.fixture()
def bags1(tmp_path: Path) -> list[Path]:
"""Test data fixture."""
paths = [
tmp_path / 'ros1_1.bag',
tmp_path / 'ros1_2.bag',
tmp_path / 'ros1_3.bag',
tmp_path / 'bad.bag',
]
with (Writer1(paths[0])) as writer:
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
writer.write(topic1, 1, b'\x01')
writer.write(topic2, 2, b'\x02\x00')
writer.write(topic1, 9, b'\x09')
with (Writer1(paths[1])) as writer:
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
writer.write(topic1, 5, b'\x05')
with (Writer1(paths[2])) as writer:
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
writer.write(topic2, 15, b'\x15\x00')
paths[3].touch()
return paths
@pytest.fixture()
def bags2(tmp_path: Path) -> list[Path]:
"""Test data fixture."""
paths = [
tmp_path / 'ros2_1',
tmp_path / 'bad',
]
with (Writer2(paths[0])) as writer:
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
writer.write(topic1, 1, HEADER + b'\x01')
writer.write(topic2, 2, HEADER + b'\x02\x00')
writer.write(topic1, 9, HEADER + b'\x09')
writer.write(topic1, 5, HEADER + b'\x05')
writer.write(topic2, 15, HEADER + b'\x15\x00')
paths[1].mkdir()
(paths[1] / 'metadata.yaml').write_text(':')
return paths
def test_anyreader1(bags1: Sequence[Path]) -> None: # pylint: disable=redefined-outer-name
"""Test AnyReader on rosbag1."""
# pylint: disable=too-many-statements
with pytest.raises(AnyReaderError, match='at least one'):
AnyReader([])
with pytest.raises(AnyReaderError, match='missing'):
AnyReader([bags1[0] / 'badname'])
reader = AnyReader(bags1)
with pytest.raises(AssertionError):
assert reader.topics
with pytest.raises(AssertionError):
next(reader.messages())
reader = AnyReader(bags1)
with pytest.raises(AnyReaderError, match='seems to be empty'):
reader.open()
assert all(not x.bio for x in reader.readers)
with AnyReader(bags1[:3]) as reader:
assert reader.duration == 15
assert reader.start_time == 1
assert reader.end_time == 16
assert reader.message_count == 5
assert list(reader.topics.keys()) == ['/topic1', '/topic2']
assert len(reader.topics['/topic1'].connections) == 2
assert reader.topics['/topic1'].msgcount == 3
assert len(reader.topics['/topic2'].connections) == 2
assert reader.topics['/topic2'].msgcount == 2
gen = reader.messages()
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (1, b'\x01')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 1 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (2, b'\x02\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 2 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (5, b'\x05')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 5 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (9, b'\x09')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 9 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (15, b'\x15\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 21 # type: ignore
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(connections=reader.topics['/topic1'].connections)
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
with pytest.raises(StopIteration):
next(gen)
def test_anyreader2(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name
"""Test AnyReader on rosbag2."""
# pylint: disable=too-many-statements
with pytest.raises(AnyReaderError, match='multiple rosbag2'):
AnyReader(bags2)
with pytest.raises(AnyReaderError, match='YAML'):
AnyReader([bags2[1]])
with AnyReader([bags2[0]]) as reader:
assert reader.duration == 15
assert reader.start_time == 1
assert reader.end_time == 16
assert reader.message_count == 5
assert list(reader.topics.keys()) == ['/topic1', '/topic2']
assert len(reader.topics['/topic1'].connections) == 1
assert reader.topics['/topic1'].msgcount == 3
assert len(reader.topics['/topic2'].connections) == 1
assert reader.topics['/topic2'].msgcount == 2
gen = reader.messages()
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (1, HEADER + b'\x01')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 1 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (2, HEADER + b'\x02\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 2 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (5, HEADER + b'\x05')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 5 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (9, HEADER + b'\x09')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 9 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (15, HEADER + b'\x15\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 21 # type: ignore
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(connections=reader.topics['/topic1'].connections)
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
with pytest.raises(StopIteration):
next(gen)