git-subtree-dir: rosbags git-subtree-mainline: 48df1fbdf4490f3cbfa3267c998d1a0fc98378ca git-subtree-split: c80625df279c154c6ec069cbac30faa319755e47
174 lines
5.6 KiB
Python
174 lines
5.6 KiB
Python
# Copyright 2020-2023 Ternaris.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
"""Tool checking if contents of two rosbags are equal."""
|
|
|
|
# pylint: disable=import-error
|
|
|
|
from __future__ import annotations
|
|
|
|
import array
|
|
import math
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
from unittest.mock import Mock
|
|
|
|
import genpy # type: ignore
|
|
import numpy
|
|
import rosgraph_msgs.msg # type: ignore
|
|
from rclpy.serialization import deserialize_message # type: ignore
|
|
from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions # type: ignore
|
|
from rosidl_runtime_py.utilities import get_message # type: ignore
|
|
|
|
rosgraph_msgs.msg.Log = Mock()
|
|
rosgraph_msgs.msg.TopicStatistics = Mock()
|
|
|
|
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
|
|
|
|
# NativeMSG is decorated with runtime_checkable so that it can be used in
# isinstance() checks. That only works if the class (and therefore Protocol and
# runtime_checkable) exists at runtime, so none of them may be hidden behind
# TYPE_CHECKING.
from typing import Protocol, runtime_checkable  # noqa: E402 pylint: disable=wrong-import-position

if TYPE_CHECKING:
    from typing import Generator, List, Union


@runtime_checkable
class NativeMSG(Protocol):  # pylint: disable=too-few-public-methods
    """Minimal native ROS message interface used for benchmark.

    Any object providing a ``get_fields_and_field_types`` attribute satisfies
    this protocol in ``isinstance`` checks.

    """

    def get_fields_and_field_types(self) -> dict[str, str]:
        """Introspect message type.

        Returns:
            Mapping with one entry per message field; iterating it yields the
            field names.

        """
        raise NotImplementedError
|
|
|
|
|
|
class Reader:  # pylint: disable=too-few-public-methods
    """Minimal shim using rosbag2_py to emulate rosbags API."""

    def __init__(self, path: Union[str, Path]):
        """Initialize reader shim.

        Opens the bag with the ``sqlite3`` storage id and caches a mapping from
        topic name to message type name.

        Args:
            path: Location of the rosbag2 to read.

        """
        self.reader = SequentialReader()
        # NOTE(review): StorageOptions comes from a compiled binding that is
        # expected to take the uri as a plain str; convert explicitly so that
        # pathlib.Path arguments (as passed by the script entry point) work.
        self.reader.open(StorageOptions(str(path), 'sqlite3'), ConverterOptions('', ''))
        self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}

    def messages(self) -> Generator[tuple[str, int, object], None, None]:
        """Expose rosbag2 like generator behavior.

        Yields:
            Tuples of topic name, timestamp, and deserialized message object.

        """
        while self.reader.has_next():
            topic, data, timestamp = self.reader.read_next()
            # Resolve the Python message class from the type name recorded for the topic.
            pytype = get_message(self.typemap[topic])
            yield topic, timestamp, deserialize_message(data, pytype)
|
|
|
|
|
|
def fixup_ros1(conns: List[rosbag.bag._Connection_Info]) -> None:
    """Monkeypatch ROS2 fieldnames onto ROS1 objects.

    Adds read-only ``sec`` and ``nanosec`` properties (forwarding to ``secs``
    and ``nsecs``) to genpy Time and Duration. If one of the connections has
    datatype ``sensor_msgs/CameraInfo``, its message class additionally gets
    lowercase ``d``, ``k``, ``r`` and ``p`` read/write aliases for the
    uppercase attributes.

    Args:
        conns: Rosbag1 connections.

    """
    for stamp_cls in (genpy.Time, genpy.Duration):
        stamp_cls.sec = property(lambda self: self.secs)
        stamp_cls.nanosec = property(lambda self: self.nsecs)

    camera_conn = next(
        (item for item in conns if item.datatype == 'sensor_msgs/CameraInfo'),
        None,
    )
    if not camera_conn:
        return

    def alias(field: str) -> property:
        """Build a read/write property forwarding to the attribute named field."""
        return property(
            lambda self: getattr(self, field),
            lambda self, value: setattr(self, field, value),
        )

    print('Patching CameraInfo')  # noqa: T201
    msgtype = rosbag.bag._get_message_type(camera_conn)  # pylint: disable=protected-access
    for upper in ('D', 'K', 'R', 'P'):
        setattr(msgtype, upper.lower(), alias(upper))
|
|
|
|
|
|
def compare(ref: object, msg: object) -> None:
    """Compare message to its reference.

    Walks the converted message recursively: message objects are compared
    field by field, primitive arrays elementwise, lists item by item, and
    everything else by equality. NaN is treated as equal to NaN.

    Args:
        ref: Reference ROS1 message.
        msg: Converted ROS2 message.

    Raises:
        AssertionError: If msg differs from ref.

    """
    # Detect message objects by their introspection method rather than by an
    # isinstance() check, so no typing-only class is needed at runtime.
    if hasattr(msg, 'get_fields_and_field_types'):
        for name in msg.get_fields_and_field_types():
            compare(getattr(ref, name), getattr(msg, name))

    elif isinstance(msg, (array.array, numpy.ndarray)):
        if isinstance(ref, bytes):
            assert msg.tobytes() == ref
        else:
            refarr = numpy.asarray(ref)
            msgarr = numpy.asarray(msg)
            # Check shapes first so the elementwise comparison below cannot
            # broadcast or collapse to a single bool.
            assert msgarr.shape == refarr.shape
            same = msgarr == refarr
            if msgarr.dtype.kind == 'f' and refarr.dtype.kind == 'f':
                # Mirror the scalar float branch: NaN matches NaN.
                same |= numpy.isnan(msgarr) & numpy.isnan(refarr)
            assert same.all()

    elif isinstance(msg, list):
        assert isinstance(ref, (list, numpy.ndarray))
        assert len(msg) == len(ref)
        for refitem, msgitem in zip(ref, msg):
            compare(refitem, msgitem)

    elif isinstance(msg, str):
        assert msg == ref

    elif isinstance(msg, float) and math.isnan(msg):
        # NaN never compares equal to itself, so it needs an explicit check.
        assert isinstance(ref, float)
        assert math.isnan(ref)

    else:
        assert ref == msg
|
|
|
|
|
|
def main_bag1_bag1(path1: Path, path2: Path) -> None:
    """Compare rosbag1 to rosbag1 message by message.

    Messages are read raw from both bags in lockstep; connection header, raw
    message tuple (without its last two entries), timestamp and topic must
    match pairwise, and both bags must hold the same number of messages.

    Args:
        path1: Rosbag1 filename.
        path2: Rosbag1 filename.

    Raises:
        AssertionError: If the bags differ.

    """
    # Context managers make sure both bag files are closed, also on mismatch.
    with rosbag.bag.Bag(path1) as reader1, rosbag.bag.Bag(path2) as reader2:
        src1 = reader1.read_messages(raw=True, return_connection_header=True)
        src2 = reader2.read_messages(raw=True, return_connection_header=True)

        # Advance both iterators explicitly instead of using zip(): zip() takes
        # an item from src1 before it notices src2 is exhausted and drops it,
        # which would hide a single surplus message at the end of the first bag.
        while True:
            msg1 = next(src1, None)
            msg2 = next(src2, None)
            if msg1 is None or msg2 is None:
                break
            assert msg1.connection_header == msg2.connection_header
            assert msg1.message[:-2] == msg2.message[:-2]
            assert msg1.timestamp == msg2.timestamp
            assert msg1.topic == msg2.topic

        # Both sources must be exhausted at the same time.
        assert msg1 is None
        assert msg2 is None

    print('Bags are identical.')  # noqa: T201
|
|
|
|
|
|
def main_bag1_bag2(path1: Path, path2: Path) -> None:
    """Compare rosbag1 to rosbag2 message by message.

    Messages are read from both bags in lockstep; topic, timestamp in
    nanoseconds and message content must match pairwise, and both bags must
    hold the same number of messages.

    Args:
        path1: Rosbag1 filename.
        path2: Rosbag2 filename.

    Raises:
        AssertionError: If the bags differ.

    """
    # Context manager makes sure the rosbag1 file is closed, also on mismatch.
    with rosbag.bag.Bag(path1) as reader1:
        src1 = reader1.read_messages()
        src2 = Reader(path2).messages()

        # Give ROS1 message classes the attribute names used by ROS2 before
        # any message is compared.
        fixup_ros1(reader1._connections.values())  # pylint: disable=protected-access

        # Advance both iterators explicitly instead of using zip(): zip() takes
        # an item from src1 before it notices src2 is exhausted and drops it,
        # which would hide a single surplus message at the end of the first bag.
        while True:
            msg1 = next(src1, None)
            msg2 = next(src2, None)
            if msg1 is None or msg2 is None:
                break
            assert msg1.topic == msg2[0]
            assert msg1.timestamp.to_nsec() == msg2[1]
            compare(msg1.message, msg2[2])

        # Both sources must be exhausted at the same time.
        assert msg1 is None
        assert msg2 is None

    print('Bags are identical.')  # noqa: T201
|
|
|
|
|
|
if __name__ == '__main__':
    if len(sys.argv) != 3:
        # Show only the program name, not the whole argument list.
        print(f'Usage: {sys.argv[0]} [rosbag1] [rosbag2]')  # noqa: T201
        sys.exit(1)
    arg1 = Path(sys.argv[1])
    arg2 = Path(sys.argv[2])
    # A directory as second argument selects the rosbag1-vs-rosbag2
    # comparison; anything else is compared as a second rosbag1 file.
    main = main_bag1_bag2 if arg2.is_dir() else main_bag1_bag1
    main(arg1, arg2)
|