greenhouse/rosbags/tests/test_reader1.py
apoorva 0c9504b343 Add 'rosbags/' from commit 'c80625df279c154c6ec069cbac30faa319755e47'
git-subtree-dir: rosbags
git-subtree-mainline: 48df1fbdf4490f3cbfa3267c998d1a0fc98378ca
git-subtree-split: c80625df279c154c6ec069cbac30faa319755e47
2023-03-28 18:21:08 +05:30

423 lines
13 KiB
Python

# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
from __future__ import annotations
from collections import defaultdict
from struct import pack
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag1.reader import IndexData
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Sequence, Union
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
    """Length-prefix a record header or record payload.

    When ``data`` is a dict, each ``key: value`` pair becomes a ``key=value``
    field (key encoded with ``str.encode()``, value used as raw bytes).
    Each field is prefixed with its little-endian uint32 length, and the
    fields are concatenated. The resulting (or given) byte string is then
    prefixed with its own little-endian uint32 length.
    """
    if not isinstance(data, dict):
        payload = data
    else:
        encoded = (key.encode() + b'=' + value for key, value in data.items())
        payload = b''.join(pack('<L', len(field)) + field for field in encoded)
    return pack('<L', len(payload)) + payload
def create_default_header() -> dict[str, bytes]:
    """Return a new bag-header field dict describing an empty bag.

    ``op`` is ``0x03``; ``conn_count`` and ``chunk_count`` are both packed
    as little-endian uint32 zero. A fresh dict is built on every call.
    """
    zero = pack('<L', 0)
    return dict(op=b'\x03', conn_count=zero, chunk_count=zero)
def create_connection(
    cid: int = 1,
    topic: int = 0,
    typ: int = 0,
) -> tuple[dict[str, bytes], dict[str, bytes]]:
    """Build a connection record as a ``(header, data)`` pair.

    Args:
        cid: Connection id, packed as a little-endian uint32 ``conn`` field.
        topic: Numeric suffix of the topic name ``/topic<topic>``.
        typ: Numeric suffix of the message type ``foo_msgs/msg/Foo<typ>``.

    Returns:
        Header fields (``op`` = ``0x07``, ``conn``, ``topic``) and data fields
        (``type`` plus fixed placeholder ``md5sum`` and ``message_definition``).
    """
    header = {'op': b'\x07', 'conn': pack('<L', cid)}
    header['topic'] = f'/topic{topic}'.encode()
    data = dict(
        type=f'foo_msgs/msg/Foo{typ}'.encode(),
        md5sum=b'AAAA',
        message_definition=b'MSGDEF',
    )
    return header, data
def create_message(
    cid: int = 1,
    time: int = 0,
    msg: int = 0,
) -> tuple[dict[str, Union[bytes, int]], bytes]:
    """Build a message-data record as a ``(header, payload)`` pair.

    ``conn`` and ``time`` are kept as plain ints in the header; they are
    packed only when the record is serialized (``time`` as whole seconds with
    a zero second word).

    Args:
        cid: Id of the connection the message belongs to.
        time: Message timestamp in seconds.
        msg: Numeric suffix of the payload ``MSGCONTENT<msg>``.
    """
    payload = f'MSGCONTENT{msg}'.encode()
    header = dict(op=b'\x02', conn=cid, time=time)
    return header, payload
def _serialize_chunk(  # pylint: disable=too-many-locals
    chunk: Sequence[Any],
    pos: int,
) -> tuple[bytes, bytes, int, bytes]:
    """Serialize one chunk together with its IDXDATA and CHUNK_INFO records.

    Args:
        chunk: ``(header, data)`` record pairs. Headers with ``op`` ``0x07``
            are connections, ``0x02`` are messages (with int ``conn`` and
            ``time``), anything else is written verbatim. Inputs are not
            modified.
        pos: Absolute file offset at which this chunk will be written.

    Returns:
        A tuple of four values:
        - the chunk record followed by its IDXDATA records;
        - the serialized connection records found in the chunk;
        - the number of those connections;
        - the CHUNK_INFO record for this chunk.
    """
    records = b''
    connections = b''
    conn_count = 0
    start_time = 2**32 - 1
    end_time = 0
    # Per-connection message counts and packed (time, 0, offset) index
    # entries; both keep first-seen connection order.
    counts: dict[int, int] = defaultdict(int)
    entries: dict[int, bytes] = defaultdict(bytes)
    # Index offsets point just past the previous message, so any non-message
    # records (connections, deliberately malformed records) that precede a
    # message lie between its index offset and the message itself.
    offset = 0
    for head, data in chunk:
        op = head.get('op')
        if op == b'\x02':
            time = head['time']
            conn = head['conn']
            # Serialize a copy so the caller's record keeps its int fields and
            # can be written again.
            head = dict(head)
            head['time'] = pack('<LL', time, 0)
            head['conn'] = pack('<L', conn)
            start_time = min(start_time, time)
            end_time = max(end_time, time)
            counts[conn] += 1
            entries[conn] += pack('<LLL', time, 0, offset)
            records += ser(head) + ser(data)
            offset = len(records)
        else:
            record = ser(head) + ser(data)
            records += record
            if op == b'\x07':
                conn_count += 1
                connections += record

    chunk_bytes = ser(
        {
            'op': b'\x05',
            'compression': b'none',
            'size': pack('<L', len(records)),
        },
    ) + ser(records)
    for conn, msgs in entries.items():
        chunk_bytes += ser(
            {
                'op': b'\x04',
                'ver': pack('<L', 1),
                'conn': pack('<L', conn),
                'count': pack('<L', counts[conn]),
            },
        ) + ser(msgs)

    chunkinfo = ser(
        {
            'op': b'\x06',
            'ver': pack('<L', 1),
            'chunk_pos': pack('<Q', pos),
            'start_time': pack('<LL', start_time, 0),
            'end_time': pack('<LL', end_time, 0),
            'count': pack('<L', len(counts)),
        },
    ) + ser(b''.join(pack('<LL', conn, count) for conn, count in counts.items()))
    return chunk_bytes, connections, conn_count, chunkinfo


def write_bag(  # pylint: disable=too-many-locals
    bag: Path,
    header: dict[str, bytes],
    chunks: Sequence[Any] = (),
) -> None:
    """Write a minimal uncompressed rosbag1 file for reader tests.

    Layout: magic, bag header padded to 4096 bytes, every chunk (each
    followed by its IDXDATA records), then all connection records, then one
    CHUNK_INFO record per chunk.

    Args:
        bag: Destination path; overwritten.
        header: Bag header fields. ``conn_count`` and ``chunk_count`` are
            always recomputed. ``index_pos`` is filled in only when absent, so
            tests can inject their own value. The dict is not modified.
        chunks: Sequence of chunks, each a sequence of ``(header, data)``
            record pairs. Records are not modified, so they can be reused.
    """
    magic = b'#ROSBAG V2.0\n'
    # Chunks start right after the magic and the 4096-byte padded header.
    pos = len(magic) + 4096
    conn_count = 0
    chunks_bytes = b''
    connections = b''
    chunkinfos = b''
    chunks = chunks or ()
    for chunk in chunks:
        chunk_bytes, chunk_conns, chunk_conn_count, chunkinfo = _serialize_chunk(
            chunk,
            pos,
        )
        chunks_bytes += chunk_bytes
        connections += chunk_conns
        conn_count += chunk_conn_count
        chunkinfos += chunkinfo
        pos += len(chunk_bytes)

    # Work on a copy so the caller's dict is left untouched.
    header = dict(header)
    header['conn_count'] = pack('<L', conn_count)
    header['chunk_count'] = pack('<L', len(chunks))
    header.setdefault('index_pos', pack('<Q', pos))
    header_bytes = ser(header)
    header_bytes += b'\x20' * (4096 - len(header_bytes))
    bag.write_bytes(b''.join([
        magic,
        header_bytes,
        chunks_bytes,
        connections,
        chunkinfos,
    ]))
def test_indexdata() -> None:
    """Test that IndexData orders and compares by its first field only.

    Instances sharing the first field compare equal even though their second
    field differs; a larger first field sorts strictly later.
    """
    early_a = IndexData(42, 1, 0)
    early_b = IndexData(42, 2, 0)
    later = IndexData(43, 3, 0)

    # Same first field: equal, neither strictly smaller nor larger.
    assert not early_a < early_b
    assert early_a <= early_b
    assert early_a == early_b
    assert not early_a != early_b  # noqa
    assert early_a >= early_b
    assert not early_a > early_b

    # Larger first field: strictly greater.
    assert early_a < later
    assert early_a <= later
    assert not early_a == later  # noqa
    assert early_a != later
    assert not early_a >= later
    assert not early_a > later
def test_reader(tmp_path: Path) -> None:  # pylint: disable=too-many-statements
    """Test the reader on small hand-built bags.

    Covers two empty bags (one with and one without an empty ``encryptor``
    field) and a single message. It also checks time ordering of messages on
    one topic and across topics, and filtering by connection and by
    start/stop time.
    """
    bag = tmp_path / 'test.bag'

    # Empty bag.
    write_bag(bag, create_default_header())
    with Reader(bag) as reader:
        assert reader.message_count == 0
        assert reader.start_time == 2**63 - 1
        assert reader.end_time == 0
        assert reader.duration == 0
        assert not list(reader.messages())

    # Empty bag carrying an explicit, empty encryptor field.
    write_bag(bag, {**create_default_header(), 'encryptor': b''})
    with Reader(bag) as reader:
        assert reader.message_count == 0

    # A single message at t=42s.
    write_bag(
        bag,
        create_default_header(),
        chunks=[[create_connection(), create_message(time=42)]],
    )
    with Reader(bag) as reader:
        assert reader.message_count == 1
        assert reader.duration == 1
        assert reader.start_time == 42 * 10**9
        assert reader.end_time == 42 * 10**9 + 1
        assert len(reader.topics.keys()) == 1
        assert reader.topics['/topic0'].msgcount == 1
        assert len(list(reader.messages())) == 1

    # Messages on one topic come back sorted by time, not by file order.
    write_bag(
        bag,
        create_default_header(),
        chunks=[
            [
                create_connection(),
                create_message(time=10, msg=10),
                create_message(time=5, msg=5),
            ],
        ],
    )
    with Reader(bag) as reader:
        assert reader.message_count == 2
        assert reader.duration == 5 * 10**9 + 1
        assert reader.start_time == 5 * 10**9
        assert reader.end_time == 10 * 10**9 + 1
        assert len(reader.topics.keys()) == 1
        assert reader.topics['/topic0'].msgcount == 2
        msgs = list(reader.messages())
        assert len(msgs) == 2
        assert [(msg[0].topic, msg[2]) for msg in msgs] == [
            ('/topic0', b'MSGCONTENT5'),
            ('/topic0', b'MSGCONTENT10'),
        ]

    # Messages across topics are merged by time; filters narrow the result.
    write_bag(
        bag,
        create_default_header(),
        chunks=[
            [
                create_connection(),
                create_message(time=10, msg=10),
                create_connection(cid=2, topic=2),
                create_message(cid=2, time=5, msg=5),
            ],
        ],
    )
    with Reader(bag) as reader:
        assert len(reader.topics.keys()) == 2
        assert reader.topics['/topic0'].msgcount == 1
        assert reader.topics['/topic2'].msgcount == 1
        assert [msg[2] for msg in reader.messages()] == [b'MSGCONTENT5', b'MSGCONTENT10']

        topic0 = [conn for conn in reader.connections if conn.topic == '/topic0']
        assert [msg[2] for msg in reader.messages(topic0)] == [b'MSGCONTENT10']
        assert [msg[2] for msg in reader.messages(start=7 * 10**9)] == [b'MSGCONTENT10']
        assert [msg[2] for msg in reader.messages(stop=7 * 10**9)] == [b'MSGCONTENT5']
def test_user_errors(tmp_path: Path) -> None:
    """Test that reading messages through a never-opened reader fails."""
    bag = tmp_path / 'test.bag'
    write_bag(
        bag,
        create_default_header(),
        chunks=[[create_connection(), create_message()]],
    )
    unopened = Reader(bag)
    with pytest.raises(ReaderError, match='is not open'):
        next(unopened.messages())
def test_failure_cases(tmp_path: Path) -> None:  # pylint: disable=too-many-statements
    """Test that missing, unreadable, malformed and unsupported bags are rejected.

    Each case (re)writes ``test.bag`` and expects a ``ReaderError`` whose
    message matches. The error comes either from opening the reader or from
    reading the first message.
    """
    bag = tmp_path / 'test.bag'

    def write_single_message_bag() -> None:
        # One chunk holding one connection and one message.
        write_bag(
            bag,
            create_default_header(),
            chunks=[[create_connection(), create_message()]],
        )

    def corrupt(old: bytes, new: bytes, count: int = -1) -> None:
        # Byte-level replace inside the written bag (all occurrences by default).
        bag.write_bytes(bag.read_bytes().replace(old, new, count))

    with pytest.raises(ReaderError, match='does not exist'):
        Reader(bag).open()

    bag.write_text('')
    with patch('pathlib.Path.open', side_effect=IOError), \
         pytest.raises(ReaderError, match='not open'):
        Reader(bag).open()
    with pytest.raises(ReaderError, match='empty'):
        Reader(bag).open()

    bag.write_text('#BADMAGIC')
    with pytest.raises(ReaderError, match='magic is invalid'):
        Reader(bag).open()

    bag.write_text('#ROSBAG V3.0\n')
    with pytest.raises(ReaderError, match='Bag version 300 is not supported.'):
        Reader(bag).open()

    # Truncated or inconsistent bag header records after a valid magic.
    magic = b'#ROSBAG V2.0\x0a'
    header_cases = (
        (b'\x00', 'Header could not be read from file.'),
        (b'\x01\x00\x00\x00', 'Header could not be read from file.'),
        (b'\x01\x00\x00\x00\x01', 'Header field size could not be read.'),
        (b'\x04\x00\x00\x00\x01\x00\x00\x00', 'Declared field size is too large for header.'),
        (b'\x05\x00\x00\x00\x01\x00\x00\x00x', 'Header field could not be parsed.'),
    )
    for tail, message in header_cases:
        bag.write_bytes(magic + tail)
        with pytest.raises(ReaderError, match=message):
            Reader(bag).open()

    write_bag(bag, {'encryptor': b'enc', **create_default_header()})
    with pytest.raises(ReaderError, match='is not supported'):
        Reader(bag).open()

    write_bag(bag, {**create_default_header(), 'index_pos': pack('<Q', 0)})
    with pytest.raises(ReaderError, match='Bag is not indexed'):
        Reader(bag).open()

    write_single_message_bag()
    corrupt(b'none', b'COMP')
    with pytest.raises(ReaderError, match='Compression \'COMP\' is not supported.'):
        Reader(bag).open()

    # Bumping every version field trips the CHUNK_INFO check.
    write_single_message_bag()
    corrupt(b'ver=\x01', b'ver=\x02')
    with pytest.raises(ReaderError, match='CHUNK_INFO version 2 is not supported.'):
        Reader(bag).open()

    # Only the first version field, which belongs to the IDXDATA record since
    # chunks are written before the CHUNK_INFO records.
    write_single_message_bag()
    corrupt(b'ver=\x01', b'ver=\x02', 1)
    with pytest.raises(ReaderError, match='IDXDATA version 2 is not supported.'):
        Reader(bag).open()

    # Turn the message record's op into an unknown value.
    write_single_message_bag()
    corrupt(b'op=\x02', b'op=\x00', 1)
    with Reader(bag) as reader, \
         pytest.raises(ReaderError, match='Expected to find message data.'):
        next(reader.messages())

    # Give the bag header (first record in the file) the message-data op.
    write_single_message_bag()
    corrupt(b'op=\x03', b'op=\x02', 1)
    with pytest.raises(ReaderError, match='Record of type \'MSGDATA\' is unexpected.'):
        Reader(bag).open()

    # A record without any fields, so its uint8 'op' field cannot be read.
    write_bag(
        bag,
        create_default_header(),
        chunks=[[({}, {}), create_connection(), create_message()]],
    )
    with Reader(bag) as reader, \
         pytest.raises(ReaderError, match='field \'op\''):
        next(reader.messages())

    # Garble the name of a uint32, a uint64 and a time field.
    for name in ('conn_count', 'chunk_pos', 'time'):
        write_single_message_bag()
        corrupt(name.encode(), b'x' * len(name), 1)
        if name == 'time':
            # The message 'time' field only surfaces once messages are read.
            with pytest.raises(ReaderError, match=f'field \'{name}\''), \
                 Reader(bag) as reader:
                next(reader.messages())
        else:
            with pytest.raises(ReaderError, match=f'field \'{name}\''):
                Reader(bag).open()