# Copyright 2020-2021 Ternaris. # SPDX-License-Identifier: Apache-2.0 """Reader tests.""" from __future__ import annotations from collections import defaultdict from struct import pack from typing import TYPE_CHECKING from unittest.mock import patch import pytest from rosbags.rosbag1 import Reader, ReaderError from rosbags.rosbag1.reader import IndexData if TYPE_CHECKING: from pathlib import Path from typing import Any, Sequence, Union def ser(data: Union[dict[str, Any], bytes]) -> bytes: """Serialize record header.""" if isinstance(data, dict): fields = [] for key, value in data.items(): field = b'='.join([key.encode(), value]) fields.append(pack(' dict[str, bytes]: """Create empty rosbag header.""" return { 'op': b'\x03', 'conn_count': pack(' tuple[dict[str, bytes], dict[str, bytes]]: """Create connection record.""" return { 'op': b'\x07', 'conn': pack(' tuple[dict[str, Union[bytes, int]], bytes]: """Create message record.""" return { 'op': b'\x02', 'conn': cid, 'time': time, }, f'MSGCONTENT{msg}'.encode() def write_bag( # pylint: disable=too-many-locals,too-many-statements bag: Path, header: dict[str, bytes], chunks: Sequence[Any] = (), ) -> None: """Write bag file.""" magic = b'#ROSBAG V2.0\n' pos = 13 + 4096 conn_count = 0 chunk_count = len(chunks or []) chunks_bytes = b'' connections = b'' chunkinfos = b'' if chunks: for chunk in chunks: chunk_bytes = b'' start_time = 2**32 - 1 end_time = 0 counts: dict[int, int] = defaultdict(int) index = {} offset = 0 for head, data in chunk: if head.get('op') == b'\x07': conn_count += 1 add = ser(head) + ser(data) chunk_bytes += add connections += add elif head.get('op') == b'\x02': time = head['time'] head['time'] = pack(' None: """Test IndexData sort sorder.""" x42_1_0 = IndexData(42, 1, 0) x42_2_0 = IndexData(42, 2, 0) x43_3_0 = IndexData(43, 3, 0) # flake8: noqa # pylint: disable=unneeded-not assert not x42_1_0 < x42_2_0 assert x42_1_0 <= x42_2_0 assert x42_1_0 == x42_2_0 assert not x42_1_0 != x42_2_0 assert x42_1_0 >= x42_2_0 assert not x42_1_0 > x42_2_0 assert x42_1_0 < x43_3_0 assert x42_1_0 <= x43_3_0 assert not x42_1_0 == x43_3_0 assert x42_1_0 != x43_3_0 assert not x42_1_0 >= x43_3_0 assert not x42_1_0 > x43_3_0 def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements """Test reader and deserializer on simple bag.""" # empty bag bag = tmp_path / 'test.bag' write_bag(bag, create_default_header()) with Reader(bag) as reader: assert reader.message_count == 0 # empty bag, explicit encryptor bag = tmp_path / 'test.bag' write_bag(bag, {**create_default_header(), 'encryptor': b''}) with Reader(bag) as reader: assert reader.message_count == 0 # single message write_bag( bag, create_default_header(), chunks=[[ create_connection(), create_message(time=42), ]] ) with Reader(bag) as reader: assert reader.message_count == 1 assert reader.duration == 1 assert reader.start_time == 42 * 10**9 assert reader.end_time == 42 * 10**9 + 1 assert len(reader.topics.keys()) == 1 assert reader.topics['/topic0'].msgcount == 1 msgs = list(reader.messages()) assert len(msgs) == 1 # sorts by time on same topic write_bag( bag, create_default_header(), chunks=[ [ create_connection(), create_message(time=10, msg=10), create_message(time=5, msg=5), ] ] ) with Reader(bag) as reader: assert reader.message_count == 2 assert reader.duration == 5 * 10**9 + 1 assert reader.start_time == 5 * 10**9 assert reader.end_time == 10 * 10**9 + 1 assert len(reader.topics.keys()) == 1 assert reader.topics['/topic0'].msgcount == 2 msgs = list(reader.messages()) assert len(msgs) == 2 assert msgs[0][0].topic == '/topic0' assert msgs[0][2] == b'MSGCONTENT5' assert msgs[1][0].topic == '/topic0' assert msgs[1][2] == b'MSGCONTENT10' # sorts by time on different topic write_bag( bag, create_default_header(), chunks=[ [ create_connection(), create_message(time=10, msg=10), create_connection(cid=2, topic=2), create_message(cid=2, time=5, msg=5), ] ] ) with Reader(bag) as reader: assert len(reader.topics.keys()) == 2 assert reader.topics['/topic0'].msgcount == 1 assert reader.topics['/topic2'].msgcount == 1 msgs = list(reader.messages()) assert len(msgs) == 2 assert msgs[0][2] == b'MSGCONTENT5' assert msgs[1][2] == b'MSGCONTENT10' connections = [x for x in reader.connections.values() if x.topic == '/topic0'] msgs = list(reader.messages(connections)) assert len(msgs) == 1 assert msgs[0][2] == b'MSGCONTENT10' msgs = list(reader.messages(start=7 * 10**9)) assert len(msgs) == 1 assert msgs[0][2] == b'MSGCONTENT10' msgs = list(reader.messages(stop=7 * 10**9)) assert len(msgs) == 1 assert msgs[0][2] == b'MSGCONTENT5' def test_user_errors(tmp_path: Path) -> None: """Test user errors.""" bag = tmp_path / 'test.bag' write_bag(bag, create_default_header(), chunks=[[ create_connection(), create_message(), ]]) reader = Reader(bag) with pytest.raises(ReaderError, match='is not open'): next(reader.messages()) def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements """Test failure cases.""" bag = tmp_path / 'test.bag' with pytest.raises(ReaderError, match='does not exist'): Reader(bag).open() bag.write_text('') with patch('pathlib.Path.open', side_effect=IOError), \ pytest.raises(ReaderError, match='not open'): Reader(bag).open() with pytest.raises(ReaderError, match='empty'): Reader(bag).open() bag.write_text('#BADMAGIC') with pytest.raises(ReaderError, match='magic is invalid'): Reader(bag).open() bag.write_text('#ROSBAG V3.0\n') with pytest.raises(ReaderError, match='Bag version 300 is not supported.'): Reader(bag).open() bag.write_bytes(b'#ROSBAG V2.0\x0a\x00') with pytest.raises(ReaderError, match='Header could not be read from file.'): Reader(bag).open() bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00') with pytest.raises(ReaderError, match='Header could not be read from file.'): Reader(bag).open() bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00\x01') with pytest.raises(ReaderError, match='Header field size could not be read.'): Reader(bag).open() bag.write_bytes(b'#ROSBAG V2.0\x0a\x04\x00\x00\x00\x01\x00\x00\x00') with pytest.raises(ReaderError, match='Declared field size is too large for header.'): Reader(bag).open() bag.write_bytes(b'#ROSBAG V2.0\x0a\x05\x00\x00\x00\x01\x00\x00\x00x') with pytest.raises(ReaderError, match='Header field could not be parsed.'): Reader(bag).open() write_bag(bag, {'encryptor': b'enc', **create_default_header()}) with pytest.raises(ReaderError, match='is not supported'): Reader(bag).open() write_bag(bag, {**create_default_header(), 'index_pos': pack('