apoorva 0c9504b343 Add 'rosbags/' from commit 'c80625df279c154c6ec069cbac30faa319755e47'
git-subtree-dir: rosbags
git-subtree-mainline: 48df1fbdf4490f3cbfa3267c998d1a0fc98378ca
git-subtree-split: c80625df279c154c6ec069cbac30faa319755e47
2023-03-28 18:21:08 +05:30

155 lines
5.2 KiB
Python

# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Check and benchmark rosbag2 read implementations."""
# pylint: disable=import-error
from __future__ import annotations

import sys
from itertools import zip_longest
from math import isnan
from pathlib import Path
from timeit import timeit
from typing import TYPE_CHECKING

import numpy
from rclpy.serialization import deserialize_message  # type: ignore
from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions  # type: ignore
from rosidl_runtime_py.utilities import get_message  # type: ignore

from rosbags.rosbag2 import Reader
from rosbags.serde import deserialize_cdr
if TYPE_CHECKING:
    from typing import Generator, Protocol

    class NativeMSG(Protocol):  # pylint: disable=too-few-public-methods
        """Structural type for the rclpy messages produced via rosbag2_py.

        Only the introspection hook that ``compare_msg`` relies on is declared;
        this class exists for static type checking only.
        """

        def get_fields_and_field_types(self) -> dict[str, str]:
            """Return the message's fields, keyed by field name."""
            raise NotImplementedError
class ReaderPy:  # pylint: disable=too-few-public-methods
    """Minimal shim around rosbag2_py that mimics the rosbags reader API."""

    def __init__(self, path: Path):
        """Open the sqlite3 bag at ``path`` and cache its topic-to-type map.

        Args:
            path: Location of the rosbag2 bag to open.

        """
        self.reader = SequentialReader()
        self.reader.open(
            StorageOptions(str(path), 'sqlite3'),
            # Empty serialization formats: no conversion, raw data is returned.
            ConverterOptions('', ''),
        )
        self.typemap = {
            meta.name: meta.type for meta in self.reader.get_all_topics_and_types()
        }

    def messages(self) -> Generator[tuple[str, str, int, bytes], None, None]:
        """Yield ``(topic, msgtype, timestamp, rawdata)`` until the bag is exhausted."""
        while self.reader.has_next():
            topic, rawdata, stamp = self.reader.read_next()
            yield topic, self.typemap[topic], stamp, rawdata
def deserialize_py(data: bytes, msgtype: str) -> NativeMSG:
    """Turn serialized ``data`` into a native rclpy message of type ``msgtype``.

    The Python message class is resolved through rosidl_runtime_py and the
    bytes are decoded with rclpy's deserializer.
    """
    return deserialize_message(data, get_message(msgtype))  # type: ignore
def compare_msg(lite: object, native: NativeMSG) -> None:
    """Compare rosbag2 (lite) vs rosbag2_py (native) message content.

    Every field declared by the native message is looked up on both objects
    and compared recursively.

    Args:
        lite: Message from rosbag2.
        native: Message from rosbag2_py.

    Raises:
        AssertionError: If messages are not identical.

    """
    for fieldname in native.get_fields_and_field_types():
        _compare_value(fieldname, getattr(lite, fieldname), getattr(native, fieldname))


def _compare_value(name: str, lite_val: object, native_val: object) -> None:
    """Assert that one lite value matches its native counterpart.

    Args:
        name: Field name (with list index, if any) used in assertion messages.
        lite_val: Value taken from the rosbags message.
        native_val: Value taken from the rclpy message.

    Raises:
        AssertionError: If the values differ.

    """
    if hasattr(lite_val, '__dataclass_fields__'):
        # Nested message: compare it field by field.
        compare_msg(lite_val, native_val)  # type: ignore[arg-type]
    elif isinstance(lite_val, numpy.ndarray):
        # The native side may hold a non-ndarray sequence; normalize it first.
        native_arr = numpy.asarray(native_val)
        # Check shapes explicitly so a length mismatch is an AssertionError
        # rather than a broadcasting error.
        assert native_arr.shape == lite_val.shape, f'{name} length mismatch'
        same = native_arr == lite_val
        if lite_val.dtype.kind in 'fc':
            # NaN never equals itself; treat NaNs in matching positions as
            # equal, consistent with the scalar NaN branch below.
            same |= numpy.isnan(native_arr) & numpy.isnan(lite_val)
        assert same.all(), f'{name}: {native_val} != {lite_val}'
    elif isinstance(lite_val, list):
        assert len(native_val) == len(lite_val), f'{name} length mismatch'  # type: ignore[arg-type]
        # Elements can be nested messages or plain values (e.g. string[]), so
        # dispatch each one through this helper instead of compare_msg.
        for idx, (sub_lite, sub_native) in enumerate(zip(lite_val, native_val)):  # type: ignore[call-overload]
            _compare_value(f'{name}[{idx}]', sub_lite, sub_native)
    elif isinstance(lite_val, float) and isnan(lite_val):
        assert isnan(native_val), f'{name}: {native_val} != {lite_val}'  # type: ignore[arg-type]
    else:
        assert native_val == lite_val, f'{name}: {native_val} != {lite_val}'
def compare(path: Path) -> None:
    """Compare raw and deserialized messages of rosbags and rosbag2_py.

    Both readers are iterated in lockstep. For every message pair the topic,
    message type, timestamp, raw payload and deserialized content must agree,
    and both readers must yield the same number of messages.

    Args:
        path: Path of the rosbag2 bag to read.

    Raises:
        AssertionError: On the first mismatch, or if one reader yields more
            messages than the other.

    """
    missing = object()
    with Reader(path) as reader:
        # Plain zip() pulls an item from the first iterator before noticing the
        # second one is exhausted and silently drops it, which can hide a
        # one-message surplus. zip_longest makes any length mismatch explicit.
        pairs = zip_longest(reader.messages(), ReaderPy(path).messages(), fillvalue=missing)
        for item, item_py in pairs:
            assert item is not missing, 'rosbag2_py yielded more messages than rosbag2'
            assert item_py is not missing, 'rosbag2 yielded more messages than rosbag2_py'
            connection, timestamp, data = item
            topic_py, msgtype_py, timestamp_py, data_py = item_py
            assert connection.topic == topic_py, f'topic: {connection.topic} != {topic_py}'
            assert connection.msgtype == msgtype_py, (
                f'{topic_py} msgtype: {connection.msgtype} != {msgtype_py}'
            )
            assert timestamp == timestamp_py, (
                f'{topic_py} timestamp: {timestamp} != {timestamp_py}'
            )
            assert data == data_py, f'{topic_py} raw data differs at {timestamp}'
            msg_py = deserialize_py(data_py, msgtype_py)
            msg = deserialize_cdr(data, connection.msgtype)
            compare_msg(msg, msg_py)
def read_deser_rosbag2_py(path: Path) -> None:
    """Read and deserialize every message of a bag with rosbag2_py.

    This is the native side of the benchmark; the decoded messages are
    discarded.

    Args:
        path: Path of the sqlite3 rosbag2 bag to read.

    """
    native_reader = SequentialReader()
    native_reader.open(
        StorageOptions(str(path), 'sqlite3'),
        # Empty serialization formats: no conversion, raw data is returned.
        ConverterOptions('', ''),
    )
    topic_types = {
        meta.name: meta.type for meta in native_reader.get_all_topics_and_types()
    }
    while native_reader.has_next():
        topic, raw, _ = native_reader.read_next()
        msg_class = get_message(topic_types[topic])
        deserialize_message(raw, msg_class)
def read_deser_rosbag2(path: Path) -> None:
    """Read and deserialize every message of a bag with rosbags.

    This is the rosbags side of the benchmark; the decoded messages are
    discarded.

    Args:
        path: Path of the rosbag2 bag to read.

    """
    with Reader(path) as bag:
        for conn, _stamp, rawdata in bag.messages():
            deserialize_cdr(rawdata, conn.msgtype)
def main() -> None:
    """Benchmark rosbag2 against rosbag2_py.

    Expects the path of a rosbag2 bag as the first command line argument.
    Messages are compared first; timing only runs if the comparison passes.
    The process exits with a usage message if no path is given, and with
    status 1 if the comparison fails.
    """
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else 'bench.py'
        sys.exit(f'usage: {prog} ROSBAG2_PATH')
    path = Path(sys.argv[1])

    print('Comparing messages from rosbag2 and rosbag2_py.')  # noqa: T201
    try:
        compare(path)
    except AssertionError as err:
        print(f'Comparison failed {err!r}')  # noqa: T201
        sys.exit(1)

    print('Measuring execution times of rosbag2 and rosbag2_py.')  # noqa: T201
    time_py = timeit(lambda: read_deser_rosbag2_py(path), number=1)
    time_lite = timeit(lambda: read_deser_rosbag2(path), number=1)
    print(  # noqa: T201
        f'Processing times:\n'
        f'rosbag2_py {time_py:.3f}\n'
        f'rosbag2 {time_lite:.3f}\n'
        f'speedup {time_py / time_lite:.2f}\n',
    )


if __name__ == '__main__':
    main()