Add 'rosbags/' from commit 'c80625df279c154c6ec069cbac30faa319755e47'

git-subtree-dir: rosbags
git-subtree-mainline: 48df1fbdf4
git-subtree-split: c80625df27
This commit is contained in:
2023-03-28 18:21:08 +05:30
99 changed files with 16378 additions and 0 deletions
+3
View File
@@ -0,0 +1,3 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag tests."""
+445
View File
@@ -0,0 +1,445 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Reference CDR message serializer and deserializer."""
from __future__ import annotations
import sys
from struct import Struct, pack_into, unpack_from
from typing import TYPE_CHECKING, Dict, List, Union, cast
import numpy
from numpy.typing import NDArray
from rosbags.serde.messages import SerdeError, get_msgdef
from rosbags.serde.typing import Msgdef
from rosbags.serde.utils import SIZEMAP, Valtype
from rosbags.typesys import types
if TYPE_CHECKING:
from typing import Any, Tuple
from rosbags.serde.typing import Descriptor
Array = Union[List[Msgdef], List[str], numpy.ndarray]
BasetypeMap = Dict[str, Struct]
BASETYPEMAP_LE: BasetypeMap = {
'bool': Struct('?'),
'int8': Struct('b'),
'int16': Struct('<h'),
'int32': Struct('<i'),
'int64': Struct('<q'),
'uint8': Struct('B'),
'uint16': Struct('<H'),
'uint32': Struct('<I'),
'uint64': Struct('<Q'),
'float32': Struct('<f'),
'float64': Struct('<d'),
}
BASETYPEMAP_BE: BasetypeMap = {
'bool': Struct('?'),
'int8': Struct('b'),
'int16': Struct('>h'),
'int32': Struct('>i'),
'int64': Struct('>q'),
'uint8': Struct('B'),
'uint16': Struct('>H'),
'uint32': Struct('>I'),
'uint64': Struct('>Q'),
'float32': Struct('>f'),
'float64': Struct('>d'),
}
def deserialize_number(rawdata: bytes, bmap: BasetypeMap, pos: int, basetype: str) \
-> Tuple[Union[bool, float, int], int]:
"""Deserialize a single boolean, float, or int.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Read position.
basetype: Number type string.
Returns:
Deserialized number and new read position.
"""
dtype, size = bmap[basetype], SIZEMAP[basetype]
pos = (pos + size - 1) & -size
return dtype.unpack_from(rawdata, pos)[0], pos + size
def deserialize_string(rawdata: bytes, bmap: BasetypeMap, pos: int) \
-> Tuple[str, int]:
"""Deserialize a string value.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Read position.
Returns:
Deserialized string and new read position.
"""
pos = (pos + 4 - 1) & -4
length = bmap['int32'].unpack_from(rawdata, pos)[0]
val = bytes(rawdata[pos + 4:pos + 4 + length - 1])
return val.decode(), pos + 4 + length
def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, desc: Descriptor) \
-> Tuple[Array, int]:
"""Deserialize an array of items of same type.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Read position.
num: Number of elements.
desc: Element type descriptor.
Returns:
Deserialized array and new read position.
Raises:
SerdeError: Unexpected element type.
"""
if desc.valtype == Valtype.BASE:
if desc.args == 'string':
strs = []
while (num := num - 1) >= 0:
val, pos = deserialize_string(rawdata, bmap, pos)
strs.append(val)
return strs, pos
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
ndarr = ndarr.byteswap() # no inplace on readonly array
return ndarr, pos + num * SIZEMAP[desc.args]
if desc.valtype == Valtype.MESSAGE:
msgs = []
while (num := num - 1) >= 0:
msg, pos = deserialize_message(rawdata, bmap, pos, desc.args)
msgs.append(msg)
return msgs, pos
raise SerdeError(f'Nested arrays {desc!r} are not supported.')
def deserialize_message(rawdata: bytes, bmap: BasetypeMap, pos: int, msgdef: Msgdef) \
-> Tuple[Msgdef, int]:
"""Deserialize a message.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Read position.
msgdef: Message definition.
Returns:
Deserialized message and new read position.
"""
values: List[Any] = []
for _, desc in msgdef.fields:
if desc.valtype == Valtype.MESSAGE:
obj, pos = deserialize_message(rawdata, bmap, pos, desc.args)
values.append(obj)
elif desc.valtype == Valtype.BASE:
if desc.args == 'string':
val, pos = deserialize_string(rawdata, bmap, pos)
values.append(val)
else:
num, pos = deserialize_number(rawdata, bmap, pos, desc.args)
values.append(num)
elif desc.valtype == Valtype.ARRAY:
subdesc, length = desc.args
arr, pos = deserialize_array(rawdata, bmap, pos, length, subdesc)
values.append(arr)
elif desc.valtype == Valtype.SEQUENCE:
size, pos = deserialize_number(rawdata, bmap, pos, 'int32')
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args[0])
values.append(arr)
return msgdef.cls(*values), pos
def deserialize(rawdata: bytes, typename: str) -> Msgdef:
"""Deserialize raw data into a message object.
Args:
rawdata: Serialized data.
typename: Type to deserialize.
Returns:
Deserialized message object.
"""
_, little_endian = unpack_from('BB', rawdata, 0)
msgdef = get_msgdef(typename, types)
obj, _ = deserialize_message(
rawdata[4:],
BASETYPEMAP_LE if little_endian else BASETYPEMAP_BE,
0,
msgdef,
)
return obj
def serialize_number(
rawdata: memoryview,
bmap: BasetypeMap,
pos: int,
basetype: str,
val: Union[bool, float, int],
) -> int:
"""Serialize a single boolean, float, or int.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Write position.
basetype: Number type string.
val: Value to serialize.
Returns:
Next write position.
"""
dtype, size = bmap[basetype], SIZEMAP[basetype]
pos = (pos + size - 1) & -size
dtype.pack_into(rawdata, pos, val)
return pos + size
def serialize_string(rawdata: memoryview, bmap: BasetypeMap, pos: int, val: str) \
-> int:
"""Deserialize a string value.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Write position.
val: Value to serialize.
Returns:
Next write position.
"""
bval = memoryview(val.encode())
length = len(bval) + 1
pos = (pos + 4 - 1) & -4
bmap['int32'].pack_into(rawdata, pos, length)
rawdata[pos + 4:pos + 4 + length - 1] = bval
return pos + 4 + length
def serialize_array(
rawdata: memoryview,
bmap: BasetypeMap,
pos: int,
desc: Descriptor,
val: Array,
) -> int:
"""Serialize an array of items of same type.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Write position.
desc: Element type descriptor.
val: Value to serialize.
Returns:
Next write position.
Raises:
SerdeError: Unexpected element type.
"""
if desc.valtype == Valtype.BASE:
if desc.args == 'string':
for item in val:
pos = serialize_string(rawdata, bmap, pos, cast('str', item))
return pos
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
size *= len(val)
val = cast('NDArray[numpy.int_]', val)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
val = val.byteswap() # no inplace on readonly array
rawdata[pos:pos + size] = memoryview(val.tobytes())
return pos + size
if desc.valtype == Valtype.MESSAGE:
for item in val:
pos = serialize_message(rawdata, bmap, pos, item, desc.args)
return pos
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
def serialize_message(
rawdata: memoryview,
bmap: BasetypeMap,
pos: int,
message: object,
msgdef: Msgdef,
) -> int:
"""Serialize a message.
Args:
rawdata: Serialized data.
bmap: Basetype metadata.
pos: Write position.
message: Message object.
msgdef: Message definition.
Returns:
Next write position.
"""
for fieldname, desc in msgdef.fields:
val = getattr(message, fieldname)
if desc.valtype == Valtype.MESSAGE:
pos = serialize_message(rawdata, bmap, pos, val, desc.args)
elif desc.valtype == Valtype.BASE:
if desc.args == 'string':
pos = serialize_string(rawdata, bmap, pos, val)
else:
pos = serialize_number(rawdata, bmap, pos, desc.args, val)
elif desc.valtype == Valtype.ARRAY:
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
elif desc.valtype == Valtype.SEQUENCE:
size = len(val)
pos = serialize_number(rawdata, bmap, pos, 'int32', size)
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
return pos
def get_array_size(desc: Descriptor, val: Array, size: int) -> int:
"""Calculate size of an array.
Args:
desc: Element type descriptor.
val: Array to calculate size of.
size: Current size of message.
Returns:
Size of val in bytes.
Raises:
SerdeError: Unexpected element type.
"""
if desc.valtype == Valtype.BASE:
if desc.args == 'string':
for item in val:
size = (size + 4 - 1) & -4
size += 4 + len(item) + 1
return size
isize = SIZEMAP[desc.args]
size = (size + isize - 1) & -isize
return size + isize * len(val)
if desc.valtype == Valtype.MESSAGE:
for item in val:
size = get_size(item, desc.args, size)
return size
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
def get_size(message: object, msgdef: Msgdef, size: int = 0) -> int:
"""Calculate size of serialzied message.
Args:
message: Message object.
msgdef: Message definition.
size: Current size of message.
Returns:
Size of message in bytes.
Raises:
SerdeError: Unexpected array length in message.
"""
for fieldname, desc in msgdef.fields:
val = getattr(message, fieldname)
if desc.valtype == Valtype.MESSAGE:
size = get_size(val, desc.args, size)
elif desc.valtype == Valtype.BASE:
if desc.args == 'string':
size = (size + 4 - 1) & -4
size += 4 + len(val.encode()) + 1
else:
isize = SIZEMAP[desc.args]
size = (size + isize - 1) & -isize
size += isize
elif desc.valtype == Valtype.ARRAY:
subdesc, length = desc.args
if len(val) != length:
raise SerdeError(f'Unexpected array length: {len(val)} != {length}.')
size = get_array_size(subdesc, val, size)
elif desc.valtype == Valtype.SEQUENCE:
size = (size + 4 - 1) & -4
size += 4
size = get_array_size(desc.args[0], val, size)
return size
def serialize(
message: object,
typename: str,
little_endian: bool = sys.byteorder == 'little',
) -> memoryview:
"""Serialize message object to bytes.
Args:
message: Message object.
typename: Type to serialize.
little_endian: Should use little endianess.
Returns:
Serialized bytes.
"""
msgdef = get_msgdef(typename, types)
size = 4 + get_size(message, msgdef)
rawdata = memoryview(bytearray(size))
pack_into('BB', rawdata, 0, 0, little_endian)
pos = serialize_message(
rawdata[4:],
BASETYPEMAP_LE if little_endian else BASETYPEMAP_BE,
0,
message,
msgdef,
)
assert pos + 4 == size
return rawdata.toreadonly()
+444
View File
@@ -0,0 +1,444 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag1to2 converter tests."""
from __future__ import annotations
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import call, patch
import pytest
from rosbags.convert import ConverterError, convert
from rosbags.convert.__main__ import main
from rosbags.convert.converter import LATCH
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
from rosbags.rosbag1 import ReaderError
from rosbags.rosbag2 import WriterError
if TYPE_CHECKING:
from typing import Any
def test_cliwrapper(tmp_path: Path) -> None:
"""Test cli wrapper."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'ros1.bag').write_text('')
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt']), \
pytest.raises(SystemExit):
main()
assert not cvrt.called
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt', str(tmp_path / 'no.bag')]), \
pytest.raises(SystemExit):
main()
assert not cvrt.called
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt', str(tmp_path / 'ros1.bag')]):
main()
cvrt.assert_called_with(
src=tmp_path / 'ros1.bag',
dst=None,
exclude_topics=[],
include_topics=[],
)
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt',
str(tmp_path / 'ros1.bag'),
'--dst',
str(tmp_path / 'subdir')]), \
pytest.raises(SystemExit):
main()
assert not cvrt.called
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt',
str(tmp_path / 'ros1.bag'),
'--dst',
str(tmp_path / 'ros2.bag')]), \
pytest.raises(SystemExit):
main()
assert not cvrt.called
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt',
str(tmp_path / 'ros1.bag'),
'--dst',
str(tmp_path / 'target')]):
main()
cvrt.assert_called_with(
src=tmp_path / 'ros1.bag',
dst=tmp_path / 'target',
exclude_topics=[],
include_topics=[],
)
with patch.object(sys, 'argv', ['cvt', str(tmp_path / 'ros1.bag')]), \
patch('builtins.print') as mock_print, \
patch('rosbags.convert.__main__.convert', side_effect=ConverterError('exc')), \
pytest.raises(SystemExit):
main()
mock_print.assert_called_with('ERROR: exc')
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt', str(tmp_path / 'subdir')]):
main()
cvrt.assert_called_with(
src=tmp_path / 'subdir',
dst=None,
exclude_topics=[],
include_topics=[],
)
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt',
str(tmp_path / 'subdir'),
'--dst',
str(tmp_path / 'ros1.bag')]), \
pytest.raises(SystemExit):
main()
assert not cvrt.called
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt',
str(tmp_path / 'subdir'),
'--dst',
str(tmp_path / 'target.bag')]):
main()
cvrt.assert_called_with(
src=tmp_path / 'subdir',
dst=tmp_path / 'target.bag',
exclude_topics=[],
include_topics=[],
)
with patch.object(sys, 'argv', ['cvt', str(tmp_path / 'subdir')]), \
patch('builtins.print') as mock_print, \
patch('rosbags.convert.__main__.convert', side_effect=ConverterError('exc')), \
pytest.raises(SystemExit):
main()
mock_print.assert_called_with('ERROR: exc')
with patch('rosbags.convert.__main__.convert') as cvrt, \
patch.object(sys, 'argv', ['cvt',
str(tmp_path / 'ros1.bag'),
'--exclude-topic',
'/foo']):
main()
cvrt.assert_called_with(
src=tmp_path / 'ros1.bag',
dst=None,
exclude_topics=['/foo'],
include_topics=[],
)
def test_convert_1to2(tmp_path: Path) -> None:
"""Test conversion from rosbag1 to rosbag2."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'foo.bag').write_text('')
with pytest.raises(ConverterError, match='exists already'):
convert(Path('foo.bag'), tmp_path / 'subdir')
with patch('rosbags.convert.converter.Reader1') as reader, \
patch('rosbags.convert.converter.Writer2') as writer, \
patch('rosbags.convert.converter.get_types_from_msg', return_value={'typ': 'def'}), \
patch('rosbags.convert.converter.register_types') as register_types, \
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
readerinst = reader.return_value.__enter__.return_value
writerinst = writer.return_value.__enter__.return_value
connections = [
Connection(1, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False), None),
Connection(2, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, True), None),
Connection(3, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False), None),
Connection(
4,
'/other',
'typ',
'def',
'',
-1,
ConnectionExtRosbag1('caller', False),
None,
),
]
wconnections = [
Connection(1, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None),
Connection(2, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', LATCH), None),
Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None),
]
readerinst.connections = [
connections[0],
connections[1],
connections[2],
connections[3],
]
readerinst.messages.return_value = [
(connections[0], 42, b'\x42'),
(connections[1], 43, b'\x43'),
(connections[2], 44, b'\x44'),
(connections[3], 45, b'\x45'),
]
writerinst.connections = []
def add_connection(*_: Any) -> Connection: # noqa: ANN401
"""Mock for Writer.add_connection."""
writerinst.connections = [
conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
]
return wconnections[len(writerinst.connections) - 1]
writerinst.add_connection.side_effect = add_connection
ros1_to_cdr.return_value = b'666'
convert(Path('foo.bag'), None)
reader.assert_called_with(Path('foo.bag'))
readerinst.messages.assert_called_with(connections=readerinst.connections)
writer.assert_called_with(Path('foo'))
writerinst.add_connection.assert_has_calls(
[
call('/topic', 'typ', 'cdr', ''),
call('/topic', 'typ', 'cdr', LATCH),
call('/other', 'typ', 'cdr', ''),
],
)
writerinst.write.assert_has_calls(
[
call(wconnections[0], 42, b'666'),
call(wconnections[1], 43, b'666'),
call(wconnections[2], 44, b'666'),
call(wconnections[2], 45, b'666'),
],
)
register_types.assert_called_with({'typ': 'def'})
ros1_to_cdr.assert_has_calls(
[
call(b'\x42', 'typ'),
call(b'\x43', 'typ'),
call(b'\x44', 'typ'),
call(b'\x45', 'typ'),
],
)
with pytest.raises(ConverterError, match='No connections left for conversion'):
convert(Path('foo.bag'), None, ['/topic', '/other'])
writerinst.connections.clear()
ros1_to_cdr.side_effect = KeyError('exc')
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
convert(Path('foo.bag'), None)
writer.side_effect = WriterError('exc')
with pytest.raises(ConverterError, match='Writing destination bag: exc'):
convert(Path('foo.bag'), None)
reader.side_effect = ReaderError('exc')
with pytest.raises(ConverterError, match='Reading source bag: exc'):
convert(Path('foo.bag'), None)
def test_convert_2to1(tmp_path: Path) -> None:
"""Test conversion from rosbag2 to rosbag1."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'foo.bag').write_text('')
with pytest.raises(ConverterError, match='exists already'):
convert(Path('subdir'), tmp_path / 'foo.bag')
with patch('rosbags.convert.converter.Reader2') as reader, \
patch('rosbags.convert.converter.Writer1') as writer, \
patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1:
readerinst = reader.return_value.__enter__.return_value
writerinst = writer.return_value.__enter__.return_value
connections = [
Connection(
1,
'/topic',
'std_msgs/msg/Bool',
'',
'',
-1,
ConnectionExtRosbag2('', ''),
None,
),
Connection(
2,
'/topic',
'std_msgs/msg/Bool',
'',
'',
-1,
ConnectionExtRosbag2('', LATCH),
None,
),
Connection(
3,
'/other',
'std_msgs/msg/Bool',
'',
'',
-1,
ConnectionExtRosbag2('', ''),
None,
),
Connection(
4,
'/other',
'std_msgs/msg/Bool',
'',
'',
-1,
ConnectionExtRosbag2('', '0'),
None,
),
]
wconnections = [
Connection(
1,
'/topic',
'std_msgs/msg/Bool',
'',
'8b94c1b53db61fb6aed406028ad6332a',
-1,
ConnectionExtRosbag1(None, False),
None,
),
Connection(
2,
'/topic',
'std_msgs/msg/Bool',
'',
'8b94c1b53db61fb6aed406028ad6332a',
-1,
ConnectionExtRosbag1(None, True),
None,
),
Connection(
3,
'/other',
'std_msgs/msg/Bool',
'',
'8b94c1b53db61fb6aed406028ad6332a',
-1,
ConnectionExtRosbag1(None, False),
None,
),
]
readerinst.connections = [
connections[0],
connections[1],
connections[2],
connections[3],
]
readerinst.messages.return_value = [
(connections[0], 42, b'\x42'),
(connections[1], 43, b'\x43'),
(connections[2], 44, b'\x44'),
(connections[3], 45, b'\x45'),
]
writerinst.connections = []
def add_connection(*_: Any) -> Connection: # noqa: ANN401
"""Mock for Writer.add_connection."""
writerinst.connections = [
conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
]
return wconnections[len(writerinst.connections) - 1]
writerinst.add_connection.side_effect = add_connection
cdr_to_ros1.return_value = b'666'
convert(Path('foo'), None)
reader.assert_called_with(Path('foo'))
reader.return_value.__enter__.return_value.messages.assert_called_with(
connections=readerinst.connections,
)
writer.assert_called_with(Path('foo.bag'))
writer.return_value.__enter__.return_value.add_connection.assert_has_calls(
[
call(
'/topic',
'std_msgs/msg/Bool',
'bool data\n',
'8b94c1b53db61fb6aed406028ad6332a',
None,
0,
),
call(
'/topic',
'std_msgs/msg/Bool',
'bool data\n',
'8b94c1b53db61fb6aed406028ad6332a',
None,
1,
),
call(
'/other',
'std_msgs/msg/Bool',
'bool data\n',
'8b94c1b53db61fb6aed406028ad6332a',
None,
0,
),
],
)
writer.return_value.__enter__.return_value.write.assert_has_calls(
[
call(wconnections[0], 42, b'666'),
call(wconnections[1], 43, b'666'),
call(wconnections[2], 44, b'666'),
call(wconnections[2], 45, b'666'),
],
)
cdr_to_ros1.assert_has_calls(
[
call(b'\x42', 'std_msgs/msg/Bool'),
call(b'\x43', 'std_msgs/msg/Bool'),
call(b'\x44', 'std_msgs/msg/Bool'),
call(b'\x45', 'std_msgs/msg/Bool'),
],
)
with pytest.raises(ConverterError, match='No connections left for conversion'):
convert(Path('foobag'), None, ['/topic', '/other'])
writerinst.connections.clear()
cdr_to_ros1.side_effect = KeyError('exc')
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
convert(Path('foo'), None)
writer.side_effect = WriterError('exc')
with pytest.raises(ConverterError, match='Writing destination bag: exc'):
convert(Path('foo'), None)
reader.side_effect = ReaderError('exc')
with pytest.raises(ConverterError, match='Reading source bag: exc'):
convert(Path('foo'), None)
+262
View File
@@ -0,0 +1,262 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from rosbags.highlevel import AnyReader, AnyReaderError
from rosbags.interfaces import Connection
from rosbags.rosbag1 import Writer as Writer1
from rosbags.rosbag2 import Writer as Writer2
if TYPE_CHECKING:
from pathlib import Path
from typing import Sequence
HEADER = b'\x00\x01\x00\x00'
@pytest.fixture()
def bags1(tmp_path: Path) -> list[Path]:
"""Test data fixture."""
paths = [
tmp_path / 'ros1_1.bag',
tmp_path / 'ros1_2.bag',
tmp_path / 'ros1_3.bag',
tmp_path / 'bad.bag',
]
with (Writer1(paths[0])) as writer:
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
writer.write(topic1, 1, b'\x01')
writer.write(topic2, 2, b'\x02\x00')
writer.write(topic1, 9, b'\x09')
with (Writer1(paths[1])) as writer:
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
writer.write(topic1, 5, b'\x05')
with (Writer1(paths[2])) as writer:
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
writer.write(topic2, 15, b'\x15\x00')
paths[3].touch()
return paths
@pytest.fixture()
def bags2(tmp_path: Path) -> list[Path]:
"""Test data fixture."""
paths = [
tmp_path / 'ros2_1',
tmp_path / 'bad',
]
with (Writer2(paths[0])) as writer:
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
writer.write(topic1, 1, HEADER + b'\x01')
writer.write(topic2, 2, HEADER + b'\x02\x00')
writer.write(topic1, 9, HEADER + b'\x09')
writer.write(topic1, 5, HEADER + b'\x05')
writer.write(topic2, 15, HEADER + b'\x15\x00')
paths[1].mkdir()
(paths[1] / 'metadata.yaml').write_text(':')
return paths
def test_anyreader1(bags1: Sequence[Path]) -> None: # pylint: disable=redefined-outer-name
"""Test AnyReader on rosbag1."""
# pylint: disable=too-many-statements
with pytest.raises(AnyReaderError, match='at least one'):
AnyReader([])
with pytest.raises(AnyReaderError, match='missing'):
AnyReader([bags1[0] / 'badname'])
reader = AnyReader(bags1)
with pytest.raises(AssertionError):
assert reader.topics
with pytest.raises(AssertionError):
next(reader.messages())
reader = AnyReader(bags1)
with pytest.raises(AnyReaderError, match='seems to be empty'):
reader.open()
assert all(not x.bio for x in reader.readers) # type: ignore[union-attr]
with AnyReader(bags1[:3]) as reader:
assert reader.duration == 15
assert reader.start_time == 1
assert reader.end_time == 16
assert reader.message_count == 5
assert list(reader.topics.keys()) == ['/topic1', '/topic2']
assert len(reader.topics['/topic1'].connections) == 2
assert reader.topics['/topic1'].msgcount == 3
assert len(reader.topics['/topic2'].connections) == 2
assert reader.topics['/topic2'].msgcount == 2
gen = reader.messages()
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (1, b'\x01')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 1 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (2, b'\x02\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 2 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (5, b'\x05')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 5 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (9, b'\x09')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 9 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (15, b'\x15\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 21 # type: ignore
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(connections=reader.topics['/topic1'].connections)
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
with pytest.raises(StopIteration):
next(gen)
def test_anyreader2(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name
"""Test AnyReader on rosbag2."""
# pylint: disable=too-many-statements
with pytest.raises(AnyReaderError, match='multiple rosbag2'):
AnyReader(bags2)
with pytest.raises(AnyReaderError, match='YAML'):
AnyReader([bags2[1]])
with AnyReader([bags2[0]]) as reader:
assert reader.duration == 15
assert reader.start_time == 1
assert reader.end_time == 16
assert reader.message_count == 5
assert list(reader.topics.keys()) == ['/topic1', '/topic2']
assert len(reader.topics['/topic1'].connections) == 1
assert reader.topics['/topic1'].msgcount == 3
assert len(reader.topics['/topic2'].connections) == 1
assert reader.topics['/topic2'].msgcount == 2
gen = reader.messages()
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (1, HEADER + b'\x01')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 1 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (2, HEADER + b'\x02\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 2 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (5, HEADER + b'\x05')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 5 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic1'
assert nxt[1:] == (9, HEADER + b'\x09')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 9 # type: ignore
nxt = next(gen)
assert nxt[0].topic == '/topic2'
assert nxt[1:] == (15, HEADER + b'\x15\x00')
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
assert msg.data == 21 # type: ignore
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(connections=reader.topics['/topic1'].connections)
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
nxt = next(gen)
assert nxt[0].topic == '/topic1'
with pytest.raises(StopIteration):
next(gen)
def test_anyreader2_autoregister(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name
"""Test AnyReader on rosbag2."""
class MockReader:
"""Mock reader."""
# pylint: disable=too-few-public-methods
def __init__(self, paths: list[Path]):
"""Initialize mock."""
_ = paths
self.metadata = {'storage_identifier': 'mcap'}
self.connections = [
Connection(
1,
'/foo',
'test_msg/msg/Foo',
'string foo',
'msg',
0,
None, # type: ignore
self,
),
Connection(
2,
'/bar',
'test_msg/msg/Bar',
'module test_msgs { module msg { struct Bar {string bar;}; }; };',
'idl',
0,
None, # type: ignore
self,
),
Connection(
3,
'/baz',
'test_msg/msg/Baz',
'',
'',
0,
None, # type: ignore
self,
),
]
def open(self) -> None:
"""Unused."""
with patch('rosbags.highlevel.anyreader.Reader2', MockReader), \
patch('rosbags.highlevel.anyreader.register_types') as mock_register_types:
AnyReader([bags2[0]]).open()
mock_register_types.assert_called_once()
assert mock_register_types.call_args[0][0] == {
'test_msg/msg/Foo': ([], [('foo', (1, 'string'))]),
'test_msgs/msg/Bar': ([], [('bar', (1, 'string'))]),
}
+363
View File
@@ -0,0 +1,363 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Message definition parser tests."""
import pytest
from rosbags.typesys import (
TypesysError,
generate_msgdef,
get_types_from_idl,
get_types_from_msg,
register_types,
)
from rosbags.typesys.base import Nodetype
from rosbags.typesys.types import FIELDDEFS
MSG = """
# comment
bool b=true
int32 global=42
float32 f=1.33
string str= foo bar\t
std_msgs/Header header
std_msgs/msg/Bool bool
test_msgs/Bar sibling
float64 base
float64[] seq1
float64[] seq2
float64[4] array
"""
MSG_BOUNDS = """
int32[] unbounded_integer_array
int32[5] five_integers_array
int32[<=5] up_to_five_integers_array
string string_of_unbounded_size
string<=10 up_to_ten_characters_string
string[<=5] up_to_five_unbounded_strings
string<=10[] unbounded_array_of_string_up_to_ten_characters_each
string<=10[<=5] up_to_five_strings_up_to_ten_characters_each
"""
MSG_DEFAULTS = """
bool b false
uint8 i 42
uint8 o 0377
uint8 h 0xff
float32 y -314.15e-2
string name1 "John"
string name2 'Ringo'
int32[] samples [-200, -100, 0, 100, 200]
"""
MULTI_MSG = """
std_msgs/Header header
byte b
char c
Other[] o
================================================================================
MSG: std_msgs/Header
time time
================================================================================
MSG: test_msgs/Other
uint64[3] Header
uint32 static = 42
"""
CSTRING_CONFUSION_MSG = """
std_msgs/Header header
string s
================================================================================
MSG: std_msgs/Header
time time
"""
RELSIBLING_MSG = """
Header header
Other other
"""
IDL_LANG = """
// assign different literals and expressions
#ifndef FOO
#define FOO
#include <global>
#include "local"
const bool g_bool = TRUE;
const int8 g_int1 = 7;
const int8 g_int2 = 07;
const int8 g_int3 = 0x7;
const float64 g_float1 = 1.1;
const float64 g_float2 = 1e10;
const char g_char = 'c';
const string g_string1 = "";
const string<128> g_string2 = "str" "ing";
module Foo {
const int64 g_expr1 = ~1;
const int64 g_expr2 = 2 * 4;
};
#endif
"""
IDL = """
// comment in file
module test_msgs {
// comment in module
typedef std_msgs::msg::Bool Bool;
/**/ /***/ /* block comment */
/*
* block comment
*/
module msg {
// comment in submodule
typedef Bool Balias;
typedef test_msgs::msg::Bar Bar;
typedef double d4[4];
module Foo_Constants {
const int32 FOO = 32;
const int64 BAR = 64;
};
@comment(type="text", text="ignore")
struct Foo {
// comment in struct
std_msgs::msg::Header header;
Balias bool;
Bar sibling;
double/* comment in member declaration */x;
sequence<double> seq1;
sequence<double, 4> seq2;
d4 array;
};
};
struct Bar {
int i;
};
};
"""
IDL_STRINGARRAY = """
module test_msgs {
module msg {
typedef string string__3[3];
struct Strings {
string__3 values;
};
};
};
"""
def test_parse_empty_msg() -> None:
"""Test msg parser with empty message."""
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
assert ret == {'std_msgs/msg/Empty': ([], [])}
def test_parse_bounds_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
assert ret == {
'test_msgs/msg/Foo': (
[],
[
('unbounded_integer_array', (4, ((1, 'int32'), None))),
('five_integers_array', (3, ((1, 'int32'), 5))),
('up_to_five_integers_array', (4, ((1, 'int32'), None))),
('string_of_unbounded_size', (1, 'string')),
('up_to_ten_characters_string', (1, 'string')),
('up_to_five_unbounded_strings', (4, ((1, 'string'), None))),
('unbounded_array_of_string_up_to_ten_characters_each', (4, ((1, 'string'), None))),
('up_to_five_strings_up_to_ten_characters_each', (4, ((1, 'string'), None))),
],
),
}
def test_parse_defaults_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
assert ret == {
'test_msgs/msg/Foo': (
[],
[
('b', (1, 'bool')),
('i', (1, 'uint8')),
('o', (1, 'uint8')),
('h', (1, 'uint8')),
('y', (1, 'float32')),
('name1', (1, 'string')),
('name2', (1, 'string')),
('samples', (4, ((1, 'int32'), None))),
],
),
}
def test_parse_msg() -> None:
"""Test msg parser."""
with pytest.raises(TypesysError, match='Could not parse'):
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
ret = get_types_from_msg(MSG, 'test_msgs/msg/Foo')
assert 'test_msgs/msg/Foo' in ret
consts, fields = ret['test_msgs/msg/Foo']
assert consts == [
('b', 'bool', True),
('global', 'int32', 42),
('f', 'float32', 1.33),
('str', 'string', 'foo bar'),
]
assert fields[0][0] == 'header'
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][0] == 'bool'
assert fields[1][1][1] == 'std_msgs/msg/Bool'
assert fields[2][0] == 'sibling'
assert fields[2][1][1] == 'test_msgs/msg/Bar'
assert fields[3][1][0] == Nodetype.BASE
assert fields[4][1][0] == Nodetype.SEQUENCE
assert fields[5][1][0] == Nodetype.SEQUENCE
assert fields[6][1][0] == Nodetype.ARRAY
def test_parse_multi_msg() -> None:
"""Test multi msg parser."""
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 3
assert 'test_msgs/msg/Foo' in ret
assert 'std_msgs/msg/Header' in ret
assert 'test_msgs/msg/Other' in ret
fields = ret['test_msgs/msg/Foo'][1]
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][1][1] == 'uint8'
assert fields[2][1][1] == 'uint8'
consts = ret['test_msgs/msg/Other'][0]
assert consts == [('static', 'uint32', 42)]
def test_parse_cstring_confusion() -> None:
"""Test if msg separator is confused with const string."""
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 2
assert 'test_msgs/msg/Foo' in ret
assert 'std_msgs/msg/Header' in ret
consts, fields = ret['test_msgs/msg/Foo']
assert consts == []
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][1][1] == 'string'
def test_parse_relative_siblings_msg() -> None:
"""Test relative siblings with msg parser."""
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
assert ret['test_msgs/msg/Foo'][1][1][1][1] == 'test_msgs/msg/Other'
ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo')
assert ret['rel_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
def test_parse_idl() -> None:
"""Test idl parser."""
ret = get_types_from_idl(IDL_LANG)
assert ret == {}
ret = get_types_from_idl(IDL)
assert 'test_msgs/msg/Foo' in ret
consts, fields = ret['test_msgs/msg/Foo']
assert consts == [('FOO', 'int32', 32), ('BAR', 'int64', 64)]
assert fields[0][0] == 'header'
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][0] == 'bool'
assert fields[1][1][1] == 'std_msgs/msg/Bool'
assert fields[2][0] == 'sibling'
assert fields[2][1][1] == 'test_msgs/msg/Bar'
assert fields[3][1][0] == Nodetype.BASE
assert fields[4][1][0] == Nodetype.SEQUENCE
assert fields[5][1][0] == Nodetype.SEQUENCE
assert fields[6][1][0] == Nodetype.ARRAY
assert 'test_msgs/Bar' in ret
consts, fields = ret['test_msgs/Bar']
assert consts == []
assert len(fields) == 1
assert fields[0][0] == 'i'
assert fields[0][1][1] == 'int'
ret = get_types_from_idl(IDL_STRINGARRAY)
consts, fields = ret['test_msgs/msg/Strings']
assert consts == []
assert len(fields) == 1
assert fields[0][0] == 'values'
assert fields[0][1] == (Nodetype.ARRAY, ((Nodetype.BASE, 'string'), 3))
def test_register_types() -> None:
"""Test type registeration."""
assert 'foo' not in FIELDDEFS
register_types({})
register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
assert 'foo' in FIELDDEFS
register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
with pytest.raises(TypesysError, match='different definition'):
register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
def test_generate_msgdef() -> None:
"""Test message definition generator."""
res = generate_msgdef('std_msgs/msg/Header')
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')
res = generate_msgdef('geometry_msgs/msg/PointStamped')
assert res[0].split(f'{"=" * 80}\n') == [
'std_msgs/Header header\ngeometry_msgs/Point point\n',
'MSG: std_msgs/Header\nuint32 seq\ntime stamp\nstring frame_id\n',
'MSG: geometry_msgs/Point\nfloat64 x\nfloat64 y\nfloat64 z\n',
]
res = generate_msgdef('geometry_msgs/msg/Twist')
assert res[0].split(f'{"=" * 80}\n') == [
'geometry_msgs/Vector3 linear\ngeometry_msgs/Vector3 angular\n',
'MSG: geometry_msgs/Vector3\nfloat64 x\nfloat64 y\nfloat64 z\n',
]
res = generate_msgdef('shape_msgs/msg/Mesh')
assert res[0].split(f'{"=" * 80}\n') == [
'shape_msgs/MeshTriangle[] triangles\ngeometry_msgs/Point[] vertices\n',
'MSG: shape_msgs/MeshTriangle\nuint32[3] vertex_indices\n',
'MSG: geometry_msgs/Point\nfloat64 x\nfloat64 y\nfloat64 z\n',
]
res = generate_msgdef('shape_msgs/msg/Plane')
assert res[0] == 'float64[4] coef\n'
res = generate_msgdef('sensor_msgs/msg/MultiEchoLaserScan')
assert len(res[0].split('=' * 80)) == 3
register_types(get_types_from_msg('time[3] times\nuint8 foo=42', 'foo_msgs/Timelist'))
res = generate_msgdef('foo_msgs/msg/Timelist')
assert res[0] == 'uint8 foo=42\ntime[3] times\n'
with pytest.raises(TypesysError, match='is unknown'):
generate_msgdef('foo_msgs/msg/Badname')
+751
View File
@@ -0,0 +1,751 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
# pylint: disable=redefined-outer-name
from __future__ import annotations
import sqlite3
import struct
from io import BytesIO
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
import pytest
import zstandard
from rosbags.rosbag2 import Reader, ReaderError, Writer
from .test_serde import MSG_JOINT, MSG_MAGN, MSG_MAGN_BIG, MSG_POLY
if TYPE_CHECKING:
from typing import BinaryIO, Iterable
from _pytest.fixtures import SubRequest
METADATA = """
rosbag2_bagfile_information:
version: 4
storage_identifier: sqlite3
relative_file_paths:
- db.db3{extension}
duration:
nanoseconds: 42
starting_time:
nanoseconds_since_epoch: 666
message_count: 4
topics_with_message_count:
- topic_metadata:
name: /poly
type: geometry_msgs/msg/Polygon
serialization_format: cdr
offered_qos_profiles: ""
message_count: 1
- topic_metadata:
name: /magn
type: sensor_msgs/msg/MagneticField
serialization_format: cdr
offered_qos_profiles: ""
message_count: 2
- topic_metadata:
name: /joint
type: trajectory_msgs/msg/JointTrajectory
serialization_format: cdr
offered_qos_profiles: ""
message_count: 1
compression_format: {compression_format}
compression_mode: {compression_mode}
"""
METADATA_EMPTY = """
rosbag2_bagfile_information:
version: 6
storage_identifier: sqlite3
relative_file_paths:
- db.db3
duration:
nanoseconds: 0
starting_time:
nanoseconds_since_epoch: 0
message_count: 0
topics_with_message_count: []
compression_format: ""
compression_mode: ""
files:
- duration:
nanoseconds: 0
message_count: 0
path: db.db3
starting_time:
nanoseconds_since_epoch: 0
custom_data:
key1: value1
key2: value2
"""
@pytest.fixture(params=['none', 'file', 'message'])
def bag(request: SubRequest, tmp_path: Path) -> Path:
"""Manually contruct bag."""
(tmp_path / 'metadata.yaml').write_text(
METADATA.format(
extension='' if request.param != 'file' else '.zstd',
compression_format='""' if request.param == 'none' else 'zstd',
compression_mode='""' if request.param == 'none' else request.param.upper(),
),
)
comp = zstandard.ZstdCompressor()
dbpath = tmp_path / 'db.db3'
dbh = sqlite3.connect(dbpath)
dbh.executescript(Writer.SQLITE_SCHEMA)
cur = dbh.cursor()
cur.execute(
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
(1, '/poly', 'geometry_msgs/msg/Polygon', 'cdr', ''),
)
cur.execute(
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
(2, '/magn', 'sensor_msgs/msg/MagneticField', 'cdr', ''),
)
cur.execute(
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
(3, '/joint', 'trajectory_msgs/msg/JointTrajectory', 'cdr', ''),
)
cur.execute(
'INSERT INTO messages VALUES(?, ?, ?, ?)',
(1, 1, 666, MSG_POLY[0] if request.param != 'message' else comp.compress(MSG_POLY[0])),
)
cur.execute(
'INSERT INTO messages VALUES(?, ?, ?, ?)',
(2, 2, 708, MSG_MAGN[0] if request.param != 'message' else comp.compress(MSG_MAGN[0])),
)
cur.execute(
'INSERT INTO messages VALUES(?, ?, ?, ?)',
(
3,
2,
708,
MSG_MAGN_BIG[0] if request.param != 'message' else comp.compress(MSG_MAGN_BIG[0]),
),
)
cur.execute(
'INSERT INTO messages VALUES(?, ?, ?, ?)',
(4, 3, 708, MSG_JOINT[0] if request.param != 'message' else comp.compress(MSG_JOINT[0])),
)
dbh.commit()
if request.param == 'file':
with dbpath.open('rb') as ifh, (tmp_path / 'db.db3.zstd').open('wb') as ofh:
comp.copy_stream(ifh, ofh)
dbpath.unlink()
return tmp_path
def test_empty_bag(tmp_path: Path) -> None:
"""Test bags with broken fs layout."""
(tmp_path / 'metadata.yaml').write_text(METADATA_EMPTY)
dbpath = tmp_path / 'db.db3'
dbh = sqlite3.connect(dbpath)
dbh.executescript(Writer.SQLITE_SCHEMA)
with Reader(tmp_path) as reader:
assert reader.message_count == 0
assert reader.start_time == 2**63 - 1
assert reader.end_time == 0
assert reader.duration == 0
assert not list(reader.messages())
assert reader.custom_data['key1'] == 'value1'
assert reader.custom_data['key2'] == 'value2'
def test_reader(bag: Path) -> None:
"""Test reader and deserializer on simple bag."""
with Reader(bag) as reader:
assert reader.duration == 43
assert reader.start_time == 666
assert reader.end_time == 709
assert reader.message_count == 4
if reader.compression_mode:
assert reader.compression_format == 'zstd'
assert [x.id for x in reader.connections] == [1, 2, 3]
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
gen = reader.messages()
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/poly'
assert connection.msgtype == 'geometry_msgs/msg/Polygon'
assert timestamp == 666
assert rawdata == MSG_POLY[0]
for idx in range(2):
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/magn'
assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
assert timestamp == 708
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/joint'
assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
with pytest.raises(StopIteration):
next(gen)
def test_message_filters(bag: Path) -> None:
"""Test reader filters messages."""
with Reader(bag) as reader:
magn_connections = [x for x in reader.connections if x.topic == '/magn']
gen = reader.messages(connections=magn_connections)
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(start=667)
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/joint'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(stop=667)
connection, _, _ = next(gen)
assert connection.topic == '/poly'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(connections=magn_connections, stop=667)
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(start=666, stop=666)
with pytest.raises(StopIteration):
next(gen)
def test_user_errors(bag: Path) -> None:
"""Test user errors."""
reader = Reader(bag)
with pytest.raises(ReaderError, match='Rosbag is not open'):
next(reader.messages())
def test_failure_cases(tmp_path: Path) -> None:
"""Test bags with broken fs layout."""
with pytest.raises(ReaderError, match='not read metadata'):
Reader(tmp_path)
metadata = tmp_path / 'metadata.yaml'
metadata.write_text('')
with pytest.raises(ReaderError, match='not read'), \
mock.patch.object(Path, 'read_text', side_effect=PermissionError):
Reader(tmp_path)
metadata.write_text(' invalid:\nthis is not yaml')
with pytest.raises(ReaderError, match='not load YAML from'):
Reader(tmp_path)
metadata.write_text('foo:')
with pytest.raises(ReaderError, match='key is missing'):
Reader(tmp_path)
metadata.write_text(
METADATA.format(
extension='',
compression_format='""',
compression_mode='""',
).replace('version: 4', 'version: 999'),
)
with pytest.raises(ReaderError, match='version 999'):
Reader(tmp_path)
metadata.write_text(
METADATA.format(
extension='',
compression_format='""',
compression_mode='""',
).replace('sqlite3', 'hdf5'),
)
with pytest.raises(ReaderError, match='Storage plugin'):
Reader(tmp_path)
metadata.write_text(
METADATA.format(
extension='',
compression_format='""',
compression_mode='""',
),
)
with pytest.raises(ReaderError, match='files are missing'):
Reader(tmp_path)
(tmp_path / 'db.db3').write_text('')
metadata.write_text(
METADATA.format(
extension='',
compression_format='""',
compression_mode='""',
).replace('cdr', 'bson'),
)
with pytest.raises(ReaderError, match='Serialization format'):
Reader(tmp_path)
metadata.write_text(
METADATA.format(
extension='',
compression_format='"gz"',
compression_mode='"file"',
),
)
with pytest.raises(ReaderError, match='Compression format'):
Reader(tmp_path)
metadata.write_text(
METADATA.format(
extension='',
compression_format='""',
compression_mode='""',
),
)
with pytest.raises(ReaderError, match='not open database'), \
Reader(tmp_path) as reader:
next(reader.messages())
def write_record(bio: BinaryIO, opcode: int, records: Iterable[bytes]) -> None:
"""Write record."""
data = b''.join(records)
bio.write(bytes([opcode]) + struct.pack('<Q', len(data)) + data)
def make_string(text: str) -> bytes:
"""Serialize string."""
data = text.encode()
return struct.pack('<I', len(data)) + data
MCAP_HEADER = b'\x89MCAP0\r\n'
SCHEMAS = [
(
0x03,
(
struct.pack('<H', 1),
make_string('geometry_msgs/msg/Polygon'),
make_string('ros2msg'),
make_string('string foo'),
),
),
(
0x03,
(
struct.pack('<H', 2),
make_string('sensor_msgs/msg/MagneticField'),
make_string('ros2msg'),
make_string('string foo'),
),
),
(
0x03,
(
struct.pack('<H', 3),
make_string('trajectory_msgs/msg/JointTrajectory'),
make_string('ros2msg'),
make_string('string foo'),
),
),
]
CHANNELS = [
(
0x04,
(
struct.pack('<H', 1),
struct.pack('<H', 1),
make_string('/poly'),
make_string('cdr'),
make_string(''),
),
),
(
0x04,
(
struct.pack('<H', 2),
struct.pack('<H', 2),
make_string('/magn'),
make_string('cdr'),
make_string(''),
),
),
(
0x04,
(
struct.pack('<H', 3),
struct.pack('<H', 3),
make_string('/joint'),
make_string('cdr'),
make_string(''),
),
),
]
@pytest.fixture(
params=['unindexed', 'partially_indexed', 'indexed', 'chunked_unindexed', 'chunked_indexed'],
)
def bag_mcap(request: SubRequest, tmp_path: Path) -> Path:
"""Manually contruct mcap bag."""
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
(tmp_path / 'metadata.yaml').write_text(
METADATA.format(
extension='.mcap',
compression_format='""',
compression_mode='""',
).replace('sqlite3', 'mcap'),
)
path = tmp_path / 'db.db3.mcap'
bio: BinaryIO
messages: list[tuple[int, int, int]] = []
chunks = []
with path.open('wb') as bio:
realbio = bio
bio.write(MCAP_HEADER)
write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap')))
if request.param.startswith('chunked'):
bio = BytesIO()
messages = []
write_record(bio, *SCHEMAS[0])
write_record(bio, *CHANNELS[0])
messages.append((1, 666, bio.tell()))
write_record(
bio,
0x05,
(
struct.pack('<H', 1),
struct.pack('<I', 1),
struct.pack('<Q', 666),
struct.pack('<Q', 666),
MSG_POLY[0],
),
)
if request.param.startswith('chunked'):
assert isinstance(bio, BytesIO)
chunk_start = realbio.tell()
compression = make_string('')
uncompressed_size = struct.pack('<Q', len(bio.getbuffer()))
compressed_size = struct.pack('<Q', len(bio.getbuffer()))
write_record(
realbio,
0x06,
(
struct.pack('<Q', 666),
struct.pack('<Q', 666),
uncompressed_size,
struct.pack('<I', 0),
compression,
compressed_size,
bio.getbuffer(),
),
)
message_index_offsets = []
message_index_start = realbio.tell()
for channel_id, group in groupby(messages, key=lambda x: x[0]):
message_index_offsets.append((channel_id, realbio.tell()))
tpls = [y for x in group for y in x[1:]]
write_record(
realbio,
0x07,
(
struct.pack('<H', channel_id),
struct.pack('<I', 8 * len(tpls)),
struct.pack('<' + 'Q' * len(tpls), *tpls),
),
)
chunk = [
struct.pack('<Q', 666),
struct.pack('<Q', 666),
struct.pack('<Q', chunk_start),
struct.pack('<Q', message_index_start - chunk_start),
struct.pack('<I', 10 * len(message_index_offsets)),
*(struct.pack('<HQ', *x) for x in message_index_offsets),
struct.pack('<Q',
realbio.tell() - message_index_start),
compression,
compressed_size,
uncompressed_size,
]
chunks.append(chunk)
bio = BytesIO()
messages = []
write_record(bio, *SCHEMAS[1])
write_record(bio, *CHANNELS[1])
messages.append((2, 708, bio.tell()))
write_record(
bio,
0x05,
(
struct.pack('<H', 2),
struct.pack('<I', 1),
struct.pack('<Q', 708),
struct.pack('<Q', 708),
MSG_MAGN[0],
),
)
messages.append((2, 708, bio.tell()))
write_record(
bio,
0x05,
(
struct.pack('<H', 2),
struct.pack('<I', 2),
struct.pack('<Q', 708),
struct.pack('<Q', 708),
MSG_MAGN_BIG[0],
),
)
write_record(bio, *SCHEMAS[2])
write_record(bio, *CHANNELS[2])
messages.append((3, 708, bio.tell()))
write_record(
bio,
0x05,
(
struct.pack('<H', 3),
struct.pack('<I', 1),
struct.pack('<Q', 708),
struct.pack('<Q', 708),
MSG_JOINT[0],
),
)
if request.param.startswith('chunked'):
assert isinstance(bio, BytesIO)
chunk_start = realbio.tell()
compression = make_string('')
uncompressed_size = struct.pack('<Q', len(bio.getbuffer()))
compressed_size = struct.pack('<Q', len(bio.getbuffer()))
write_record(
realbio,
0x06,
(
struct.pack('<Q', 708),
struct.pack('<Q', 708),
uncompressed_size,
struct.pack('<I', 0),
compression,
compressed_size,
bio.getbuffer(),
),
)
message_index_offsets = []
message_index_start = realbio.tell()
for channel_id, group in groupby(messages, key=lambda x: x[0]):
message_index_offsets.append((channel_id, realbio.tell()))
tpls = [y for x in group for y in x[1:]]
write_record(
realbio,
0x07,
(
struct.pack('<H', channel_id),
struct.pack('<I', 8 * len(tpls)),
struct.pack('<' + 'Q' * len(tpls), *tpls),
),
)
chunk = [
struct.pack('<Q', 708),
struct.pack('<Q', 708),
struct.pack('<Q', chunk_start),
struct.pack('<Q', message_index_start - chunk_start),
struct.pack('<I', 10 * len(message_index_offsets)),
*(struct.pack('<HQ', *x) for x in message_index_offsets),
struct.pack('<Q',
realbio.tell() - message_index_start),
compression,
compressed_size,
uncompressed_size,
]
chunks.append(chunk)
bio = realbio
messages = []
if request.param in ['indexed', 'partially_indexed', 'chunked_indexed']:
summary_start = bio.tell()
for schema in SCHEMAS:
write_record(bio, *schema)
if request.param != 'partially_indexed':
for channel in CHANNELS:
write_record(bio, *channel)
if request.param == 'chunked_indexed':
for chunk in chunks:
write_record(bio, 0x08, chunk)
summary_offset_start = 0
write_record(bio, 0x0a, (b'ignored',))
write_record(
bio,
0x0b,
(
struct.pack('<Q', 4),
struct.pack('<H', 3),
struct.pack('<I', 3),
struct.pack('<I', 0),
struct.pack('<I', 0),
struct.pack('<I', 0 if request.param == 'indexed' else 1),
struct.pack('<Q', 666),
struct.pack('<Q', 708),
struct.pack('<I', 0),
),
)
write_record(bio, 0x0d, (b'ignored',))
write_record(bio, 0xff, (b'ignored',))
else:
summary_start = 0
summary_offset_start = 0
write_record(
bio,
0x02,
(
struct.pack('<Q', summary_start),
struct.pack('<Q', summary_offset_start),
struct.pack('<I', 0),
),
)
bio.write(MCAP_HEADER)
return tmp_path
def test_reader_mcap(bag_mcap: Path) -> None:
"""Test reader and deserializer on simple bag."""
with Reader(bag_mcap) as reader:
assert reader.duration == 43
assert reader.start_time == 666
assert reader.end_time == 709
assert reader.message_count == 4
if reader.compression_mode:
assert reader.compression_format == 'zstd'
assert [x.id for x in reader.connections] == [1, 2, 3]
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
gen = reader.messages()
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/poly'
assert connection.msgtype == 'geometry_msgs/msg/Polygon'
assert timestamp == 666
assert rawdata == MSG_POLY[0]
for idx in range(2):
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/magn'
assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
assert timestamp == 708
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
connection, timestamp, rawdata = next(gen)
assert connection.topic == '/joint'
assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
with pytest.raises(StopIteration):
next(gen)
def test_message_filters_mcap(bag_mcap: Path) -> None:
"""Test reader filters messages."""
with Reader(bag_mcap) as reader:
magn_connections = [x for x in reader.connections if x.topic == '/magn']
gen = reader.messages(connections=magn_connections)
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(start=667)
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/magn'
connection, _, _ = next(gen)
assert connection.topic == '/joint'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(stop=667)
connection, _, _ = next(gen)
assert connection.topic == '/poly'
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(connections=magn_connections, stop=667)
with pytest.raises(StopIteration):
next(gen)
gen = reader.messages(start=666, stop=666)
with pytest.raises(StopIteration):
next(gen)
def test_bag_mcap_files(tmp_path: Path) -> None:
"""Test bad mcap files."""
(tmp_path / 'metadata.yaml').write_text(
METADATA.format(
extension='.mcap',
compression_format='""',
compression_mode='""',
).replace('sqlite3', 'mcap'),
)
path = tmp_path / 'db.db3.mcap'
path.touch()
reader = Reader(tmp_path)
path.unlink()
with pytest.raises(ReaderError, match='Could not open'):
reader.open()
path.touch()
with pytest.raises(ReaderError, match='seems to be empty'):
Reader(tmp_path).open()
path.write_bytes(b'xxxxxxxx')
with pytest.raises(ReaderError, match='magic is invalid'):
Reader(tmp_path).open()
path.write_bytes(b'\x89MCAP0\r\n\xFF')
with pytest.raises(ReaderError, match='Unexpected record'):
Reader(tmp_path).open()
with path.open('wb') as bio:
bio.write(b'\x89MCAP0\r\n')
write_record(bio, 0x01, (make_string('ros1'), make_string('test_mcap')))
with pytest.raises(ReaderError, match='Profile is not'):
Reader(tmp_path).open()
with path.open('wb') as bio:
bio.write(b'\x89MCAP0\r\n')
write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap')))
with pytest.raises(ReaderError, match='File end magic is invalid'):
Reader(tmp_path).open()
+422
View File
@@ -0,0 +1,422 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
from __future__ import annotations
from collections import defaultdict
from struct import pack
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag1.reader import IndexData
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Sequence, Union
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
"""Serialize record header."""
if isinstance(data, dict):
fields = []
for key, value in data.items():
field = b'='.join([key.encode(), value])
fields.append(pack('<L', len(field)) + field)
data = b''.join(fields)
return pack('<L', len(data)) + data
def create_default_header() -> dict[str, bytes]:
"""Create empty rosbag header."""
return {
'op': b'\x03',
'conn_count': pack('<L', 0),
'chunk_count': pack('<L', 0),
}
def create_connection(
cid: int = 1,
topic: int = 0,
typ: int = 0,
) -> tuple[dict[str, bytes], dict[str, bytes]]:
"""Create connection record."""
return {
'op': b'\x07',
'conn': pack('<L', cid),
'topic': f'/topic{topic}'.encode(),
}, {
'type': f'foo_msgs/msg/Foo{typ}'.encode(),
'md5sum': b'AAAA',
'message_definition': b'MSGDEF',
}
def create_message(
cid: int = 1,
time: int = 0,
msg: int = 0,
) -> tuple[dict[str, Union[bytes, int]], bytes]:
"""Create message record."""
return {
'op': b'\x02',
'conn': cid,
'time': time,
}, f'MSGCONTENT{msg}'.encode()
def write_bag( # pylint: disable=too-many-locals
bag: Path,
header: dict[str, bytes],
chunks: Sequence[Any] = (),
) -> None:
"""Write bag file."""
magic = b'#ROSBAG V2.0\n'
pos = 13 + 4096
conn_count = 0
chunk_count = len(chunks or [])
chunks_bytes = b''
connections = b''
chunkinfos = b''
if chunks:
for chunk in chunks:
chunk_bytes = b''
start_time = 2**32 - 1
end_time = 0
counts: dict[int, int] = defaultdict(int)
index = {}
offset = 0
for head, data in chunk:
if head.get('op') == b'\x07':
conn_count += 1
add = ser(head) + ser(data)
chunk_bytes += add
connections += add
elif head.get('op') == b'\x02':
time = head['time']
head['time'] = pack('<LL', head['time'], 0)
conn = head['conn']
head['conn'] = pack('<L', head['conn'])
start_time = min([start_time, time])
end_time = max([end_time, time])
counts[conn] += 1
if conn not in index:
index[conn] = {
'count': 0,
'msgs': b'',
}
index[conn]['count'] += 1 # type: ignore
index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
add = ser(head) + ser(data)
chunk_bytes += add
offset = len(chunk_bytes)
else:
add = ser(head) + ser(data)
chunk_bytes += add
chunk_bytes = ser(
{
'op': b'\x05',
'compression': b'none',
'size': pack('<L', len(chunk_bytes)),
},
) + ser(chunk_bytes)
for conn, data in index.items():
chunk_bytes += ser(
{
'op': b'\x04',
'ver': pack('<L', 1),
'conn': pack('<L', conn),
'count': pack('<L', data['count']),
},
) + ser(data['msgs'])
chunks_bytes += chunk_bytes
chunkinfos += ser(
{
'op': b'\x06',
'ver': pack('<L', 1),
'chunk_pos': pack('<Q', pos),
'start_time': pack('<LL', start_time, 0),
'end_time': pack('<LL', end_time, 0),
'count': pack('<L', len(counts.keys())),
},
) + ser(b''.join([pack('<LL', x, y) for x, y in counts.items()]))
pos += len(chunk_bytes)
header['conn_count'] = pack('<L', conn_count)
header['chunk_count'] = pack('<L', chunk_count)
if 'index_pos' not in header:
header['index_pos'] = pack('<Q', pos)
header_bytes = ser(header)
header_bytes += b'\x20' * (4096 - len(header_bytes))
bag.write_bytes(b''.join([
magic,
header_bytes,
chunks_bytes,
connections,
chunkinfos,
]))
def test_indexdata() -> None:
"""Test IndexData sort sorder."""
x42_1_0 = IndexData(42, 1, 0)
x42_2_0 = IndexData(42, 2, 0)
x43_3_0 = IndexData(43, 3, 0)
assert not x42_1_0 < x42_2_0
assert x42_1_0 <= x42_2_0
assert x42_1_0 == x42_2_0
assert not x42_1_0 != x42_2_0 # noqa
assert x42_1_0 >= x42_2_0
assert not x42_1_0 > x42_2_0
assert x42_1_0 < x43_3_0
assert x42_1_0 <= x43_3_0
assert not x42_1_0 == x43_3_0 # noqa
assert x42_1_0 != x43_3_0
assert not x42_1_0 >= x43_3_0
assert not x42_1_0 > x43_3_0
def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test reader and deserializer on simple bag."""
# empty bag
bag = tmp_path / 'test.bag'
write_bag(bag, create_default_header())
with Reader(bag) as reader:
assert reader.message_count == 0
assert reader.start_time == 2**63 - 1
assert reader.end_time == 0
assert reader.duration == 0
assert not list(reader.messages())
# empty bag, explicit encryptor
bag = tmp_path / 'test.bag'
write_bag(bag, {**create_default_header(), 'encryptor': b''})
with Reader(bag) as reader:
assert reader.message_count == 0
# single message
write_bag(
bag,
create_default_header(),
chunks=[[
create_connection(),
create_message(time=42),
]],
)
with Reader(bag) as reader:
assert reader.message_count == 1
assert reader.duration == 1
assert reader.start_time == 42 * 10**9
assert reader.end_time == 42 * 10**9 + 1
assert len(reader.topics.keys()) == 1
assert reader.topics['/topic0'].msgcount == 1
msgs = list(reader.messages())
assert len(msgs) == 1
# sorts by time on same topic
write_bag(
bag,
create_default_header(),
chunks=[
[
create_connection(),
create_message(time=10, msg=10),
create_message(time=5, msg=5),
],
],
)
with Reader(bag) as reader:
assert reader.message_count == 2
assert reader.duration == 5 * 10**9 + 1
assert reader.start_time == 5 * 10**9
assert reader.end_time == 10 * 10**9 + 1
assert len(reader.topics.keys()) == 1
assert reader.topics['/topic0'].msgcount == 2
msgs = list(reader.messages())
assert len(msgs) == 2
assert msgs[0][0].topic == '/topic0'
assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][0].topic == '/topic0'
assert msgs[1][2] == b'MSGCONTENT10'
# sorts by time on different topic
write_bag(
bag,
create_default_header(),
chunks=[
[
create_connection(),
create_message(time=10, msg=10),
create_connection(cid=2, topic=2),
create_message(cid=2, time=5, msg=5),
],
],
)
with Reader(bag) as reader:
assert len(reader.topics.keys()) == 2
assert reader.topics['/topic0'].msgcount == 1
assert reader.topics['/topic2'].msgcount == 1
msgs = list(reader.messages())
assert len(msgs) == 2
assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][2] == b'MSGCONTENT10'
connections = [x for x in reader.connections if x.topic == '/topic0']
msgs = list(reader.messages(connections))
assert len(msgs) == 1
assert msgs[0][2] == b'MSGCONTENT10'
msgs = list(reader.messages(start=7 * 10**9))
assert len(msgs) == 1
assert msgs[0][2] == b'MSGCONTENT10'
msgs = list(reader.messages(stop=7 * 10**9))
assert len(msgs) == 1
assert msgs[0][2] == b'MSGCONTENT5'
def test_user_errors(tmp_path: Path) -> None:
"""Test user errors."""
bag = tmp_path / 'test.bag'
write_bag(bag, create_default_header(), chunks=[[
create_connection(),
create_message(),
]])
reader = Reader(bag)
with pytest.raises(ReaderError, match='is not open'):
next(reader.messages())
def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test failure cases."""
bag = tmp_path / 'test.bag'
with pytest.raises(ReaderError, match='does not exist'):
Reader(bag).open()
bag.write_text('')
with patch('pathlib.Path.open', side_effect=IOError), \
pytest.raises(ReaderError, match='not open'):
Reader(bag).open()
with pytest.raises(ReaderError, match='empty'):
Reader(bag).open()
bag.write_text('#BADMAGIC')
with pytest.raises(ReaderError, match='magic is invalid'):
Reader(bag).open()
bag.write_text('#ROSBAG V3.0\n')
with pytest.raises(ReaderError, match='Bag version 300 is not supported.'):
Reader(bag).open()
bag.write_bytes(b'#ROSBAG V2.0\x0a\x00')
with pytest.raises(ReaderError, match='Header could not be read from file.'):
Reader(bag).open()
bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00')
with pytest.raises(ReaderError, match='Header could not be read from file.'):
Reader(bag).open()
bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00\x01')
with pytest.raises(ReaderError, match='Header field size could not be read.'):
Reader(bag).open()
bag.write_bytes(b'#ROSBAG V2.0\x0a\x04\x00\x00\x00\x01\x00\x00\x00')
with pytest.raises(ReaderError, match='Declared field size is too large for header.'):
Reader(bag).open()
bag.write_bytes(b'#ROSBAG V2.0\x0a\x05\x00\x00\x00\x01\x00\x00\x00x')
with pytest.raises(ReaderError, match='Header field could not be parsed.'):
Reader(bag).open()
write_bag(bag, {'encryptor': b'enc', **create_default_header()})
with pytest.raises(ReaderError, match='is not supported'):
Reader(bag).open()
write_bag(bag, {**create_default_header(), 'index_pos': pack('<Q', 0)})
with pytest.raises(ReaderError, match='Bag is not indexed'):
Reader(bag).open()
write_bag(bag, create_default_header(), chunks=[[
create_connection(),
create_message(),
]])
bag.write_bytes(bag.read_bytes().replace(b'none', b'COMP'))
with pytest.raises(ReaderError, match='Compression \'COMP\' is not supported.'):
Reader(bag).open()
write_bag(bag, create_default_header(), chunks=[[
create_connection(),
create_message(),
]])
bag.write_bytes(bag.read_bytes().replace(b'ver=\x01', b'ver=\x02'))
with pytest.raises(ReaderError, match='CHUNK_INFO version 2 is not supported.'):
Reader(bag).open()
write_bag(bag, create_default_header(), chunks=[[
create_connection(),
create_message(),
]])
bag.write_bytes(bag.read_bytes().replace(b'ver=\x01', b'ver=\x02', 1))
with pytest.raises(ReaderError, match='IDXDATA version 2 is not supported.'):
Reader(bag).open()
write_bag(bag, create_default_header(), chunks=[[
create_connection(),
create_message(),
]])
bag.write_bytes(bag.read_bytes().replace(b'op=\x02', b'op=\x00', 1))
with Reader(bag) as reader, \
pytest.raises(ReaderError, match='Expected to find message data.'):
next(reader.messages())
write_bag(bag, create_default_header(), chunks=[[
create_connection(),
create_message(),
]])
bag.write_bytes(bag.read_bytes().replace(b'op=\x03', b'op=\x02', 1))
with pytest.raises(ReaderError, match='Record of type \'MSGDATA\' is unexpected.'):
Reader(bag).open()
# bad uint8 field
write_bag(
bag,
create_default_header(),
chunks=[[
({}, {}),
create_connection(),
create_message(),
]],
)
with Reader(bag) as reader, \
pytest.raises(ReaderError, match='field \'op\''):
next(reader.messages())
# bad uint32, uint64, time field
for name in ('conn_count', 'chunk_pos', 'time'):
write_bag(bag, create_default_header(), chunks=[[create_connection(), create_message()]])
bag.write_bytes(bag.read_bytes().replace(name.encode(), b'x' * len(name), 1))
if name == 'time':
with pytest.raises(ReaderError, match=f'field \'{name}\''), \
Reader(bag) as reader:
next(reader.messages())
else:
with pytest.raises(ReaderError, match=f'field \'{name}\''):
Reader(bag).open()
+45
View File
@@ -0,0 +1,45 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Test full data roundtrip."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from rosbags.rosbag2 import Reader, Writer
from rosbags.serde import deserialize_cdr, serialize_cdr
if TYPE_CHECKING:
from pathlib import Path
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods
"""Dummy class."""
data = 1.25
path = tmp_path / 'rosbag2'
wbag = Writer(path)
wbag.set_compression(mode, wbag.CompressionFormat.ZSTD)
with wbag:
msgtype = 'std_msgs/msg/Float64'
wconnection = wbag.add_connection('/test', msgtype)
wbag.write(wconnection, 42, serialize_cdr(Foo, msgtype))
rbag = Reader(path)
with rbag:
gen = rbag.messages()
rconnection, _, raw = next(gen)
assert rconnection.topic == wconnection.topic
assert rconnection.msgtype == wconnection.msgtype
assert rconnection.ext == wconnection.ext
msg = deserialize_cdr(raw, rconnection.msgtype)
assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration):
next(gen)
+44
View File
@@ -0,0 +1,44 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Test full data roundtrip."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from rosbags.rosbag1 import Reader, Writer
from rosbags.serde import cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr
if TYPE_CHECKING:
from pathlib import Path
from typing import Optional
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods
"""Dummy class."""
data = 1.25
path = tmp_path / 'test.bag'
wbag = Writer(path)
if fmt:
wbag.set_compression(fmt)
with wbag:
msgtype = 'std_msgs/msg/Float64'
conn = wbag.add_connection('/test', msgtype)
wbag.write(conn, 42, cdr_to_ros1(serialize_cdr(Foo, msgtype), msgtype))
rbag = Reader(path)
with rbag:
gen = rbag.messages()
connection, _, raw = next(gen)
msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype)
assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration):
next(gen)
+513
View File
@@ -0,0 +1,513 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Serializer and deserializer tests."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import numpy
import pytest
from rosbags.serde import (
SerdeError,
cdr_to_ros1,
deserialize_cdr,
deserialize_ros1,
ros1_to_cdr,
serialize_cdr,
serialize_ros1,
)
from rosbags.serde.messages import get_msgdef
from rosbags.typesys import get_types_from_msg, register_types, types
from rosbags.typesys.types import builtin_interfaces__msg__Time as Time
from rosbags.typesys.types import geometry_msgs__msg__Polygon as Polygon
from rosbags.typesys.types import sensor_msgs__msg__MagneticField as MagneticField
from rosbags.typesys.types import std_msgs__msg__Header as Header
from .cdr import deserialize, serialize
if TYPE_CHECKING:
from typing import Any, Generator, Union
MSG_POLY = (
(
b'\x00\x01\x00\x00' # header
b'\x02\x00\x00\x00' # number of points = 2
b'\x00\x00\x80\x3f' # x = 1
b'\x00\x00\x00\x40' # y = 2
b'\x00\x00\x40\x40' # z = 3
b'\x00\x00\xa0\x3f' # x = 1.25
b'\x00\x00\x10\x40' # y = 2.25
b'\x00\x00\x50\x40' # z = 3.25
),
'geometry_msgs/msg/Polygon',
True,
)
MSG_MAGN = (
(
b'\x00\x01\x00\x00' # header
b'\xc4\x02\x00\x00\x00\x01\x00\x00' # timestamp = 708s 256ns
b'\x06\x00\x00\x00foo42\x00' # frameid 'foo42'
b'\x00\x00\x00\x00\x00\x00' # padding
b'\x00\x00\x00\x00\x00\x00\x60\x40' # x = 128
b'\x00\x00\x00\x00\x00\x00\x60\x40' # y = 128
b'\x00\x00\x00\x00\x00\x00\x60\x40' # z = 128
b'\x00\x00\x00\x00\x00\x00\xF0\x3F' # covariance matrix = 3x3 diag
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\xF0\x3F'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\xF0\x3F'
),
'sensor_msgs/msg/MagneticField',
True,
)
MSG_MAGN_BIG = (
(
b'\x00\x00\x00\x00' # header
b'\x00\x00\x02\xc4\x00\x00\x01\x00' # timestamp = 708s 256ns
b'\x00\x00\x00\x06foo42\x00' # frameid 'foo42'
b'\x00\x00\x00\x00\x00\x00' # padding
b'\x40\x60\x00\x00\x00\x00\x00\x00' # x = 128
b'\x40\x60\x00\x00\x00\x00\x00\x00' # y = 128
b'\x40\x60\x00\x00\x00\x00\x00\x00' # z = 128
b'\x3F\xF0\x00\x00\x00\x00\x00\x00' # covariance matrix = 3x3 diag
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x3F\xF0\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x3F\xF0\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00' # garbage
),
'sensor_msgs/msg/MagneticField',
False,
)
MSG_JOINT = (
(
b'\x00\x01\x00\x00' # header
b'\xc4\x02\x00\x00\x00\x01\x00\x00' # timestamp = 708s 256ns
b'\x04\x00\x00\x00bar\x00' # frameid 'bar'
b'\x02\x00\x00\x00' # number of strings
b'\x02\x00\x00\x00a\x00' # string 'a'
b'\x00\x00' # padding
b'\x02\x00\x00\x00b\x00' # string 'b'
b'\x00\x00' # padding
b'\x00\x00\x00\x00' # number of points
b'\x00\x00\x00' # garbage
),
'trajectory_msgs/msg/JointTrajectory',
True,
)
MESSAGES = [MSG_POLY, MSG_MAGN, MSG_MAGN_BIG, MSG_JOINT]
STATIC_64_64 = """
uint64[2] u64
"""
STATIC_64_16 = """
uint64 u64
uint16 u16
"""
STATIC_16_64 = """
uint16 u16
uint64 u64
"""
DYNAMIC_64_64 = """
uint64[] u64
"""
DYNAMIC_64_B_64 = """
uint64 u64
bool b
float64 f64
"""
DYNAMIC_64_S = """
uint64 u64
string s
"""
DYNAMIC_S_64 = """
string s
uint64 u64
"""
CUSTOM = """
string base_str
float32 base_f32
test_msgs/msg/static_64_64 msg_s66
test_msgs/msg/static_64_16 msg_s61
test_msgs/msg/static_16_64 msg_s16
test_msgs/msg/dynamic_64_64 msg_d66
test_msgs/msg/dynamic_64_b_64 msg_d6b6
test_msgs/msg/dynamic_64_s msg_d6s
test_msgs/msg/dynamic_s_64 msg_ds6
string[2] arr_base_str
float32[2] arr_base_f32
test_msgs/msg/static_64_64[2] arr_msg_s66
test_msgs/msg/static_64_16[2] arr_msg_s61
test_msgs/msg/static_16_64[2] arr_msg_s16
test_msgs/msg/dynamic_64_64[2] arr_msg_d66
test_msgs/msg/dynamic_64_b_64[2] arr_msg_d6b6
test_msgs/msg/dynamic_64_s[2] arr_msg_d6s
test_msgs/msg/dynamic_s_64[2] arr_msg_ds6
string[] seq_base_str
float32[] seq_base_f32
test_msgs/msg/static_64_64[] seq_msg_s66
test_msgs/msg/static_64_16[] seq_msg_s61
test_msgs/msg/static_16_64[] seq_msg_s16
test_msgs/msg/dynamic_64_64[] seq_msg_d66
test_msgs/msg/dynamic_64_b_64[] seq_msg_d6b6
test_msgs/msg/dynamic_64_s[] seq_msg_d6s
test_msgs/msg/dynamic_s_64[] seq_msg_ds6
"""
SU64_B = """
uint64[] su64
bool b
"""
SU64_U64 = """
uint64[] su64
uint64 u64
"""
SMSG_U64 = """
su64_u64[] seq
uint64 u64
"""
@pytest.fixture()
def _comparable() -> Generator[None, None, None]:
"""Make messages containing numpy arrays comparable.
Notes:
This solution is necessary as numpy.ndarray is not directly patchable.
"""
frombuffer = numpy.frombuffer
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
lhs = self._mock_wraps # pylint: disable=protected-access
rhs = getattr(other, '_mock_wraps', other)
return (lhs == rhs).all() # type: ignore
class CNDArray(MagicMock):
"""Mock ndarray."""
def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401
super().__init__(*args, **kwargs)
self.dtype = kwargs['wraps'].dtype
self.reshape = kwargs['wraps'].reshape
self.__eq__ = arreq # type: ignore
def byteswap(self, *args: Any) -> CNDArray: # noqa: ANN401
"""Wrap return value also in mock."""
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: # noqa: ANN401
return CNDArray(wraps=frombuffer(*args, **kwargs))
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
yield
@pytest.mark.parametrize('message', MESSAGES)
def test_serde(message: tuple[bytes, str, bool]) -> None:
"""Test serialization deserialization roundtrip."""
rawdata, typ, is_little = message
serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little)
assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little)
assert serdeser == rawdata[:len(serdeser)]
assert len(rawdata) - len(serdeser) < 4
assert all(x == 0 for x in rawdata[len(serdeser):])
if rawdata[1] == 1:
rawdata = cdr_to_ros1(rawdata, typ)
serdeser = serialize_ros1(deserialize_ros1(rawdata, typ), typ)
assert serdeser == rawdata
@pytest.mark.usefixtures('_comparable')
def test_deserializer() -> None:
"""Test deserializer."""
msg = deserialize_cdr(*MSG_POLY[:2])
assert msg == deserialize(*MSG_POLY[:2])
assert isinstance(msg, Polygon)
assert len(msg.points) == 2
assert msg.points[0].x == 1
assert msg.points[0].y == 2
assert msg.points[0].z == 3
assert msg.points[1].x == 1.25
assert msg.points[1].y == 2.25
assert msg.points[1].z == 3.25
msg_ros1 = deserialize_ros1(cdr_to_ros1(*MSG_POLY[:2]), MSG_POLY[1])
assert msg_ros1 == msg
msg = deserialize_cdr(*MSG_MAGN[:2])
assert msg == deserialize(*MSG_MAGN[:2])
assert isinstance(msg, MagneticField)
assert 'MagneticField' in repr(msg)
assert msg.header.stamp.sec == 708
assert msg.header.stamp.nanosec == 256
assert msg.header.frame_id == 'foo42'
field = msg.magnetic_field
assert (field.x, field.y, field.z) == (128., 128., 128.)
diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3))
assert (diag == [1., 1., 1.]).all()
msg_ros1 = deserialize_ros1(cdr_to_ros1(*MSG_MAGN[:2]), MSG_MAGN[1])
assert msg_ros1 == msg
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
assert isinstance(msg_big, MagneticField)
assert msg.magnetic_field == msg_big.magnetic_field
@pytest.mark.usefixtures('_comparable')
def test_serializer() -> None:
"""Test serializer."""
class Foo: # pylint: disable=too-few-public-methods
"""Dummy class."""
data = 7
msg = Foo()
ret = serialize_cdr(msg, 'std_msgs/msg/Int8', True)
assert ret == serialize(msg, 'std_msgs/msg/Int8', True)
assert ret == b'\x00\x01\x00\x00\x07'
ret = serialize_cdr(msg, 'std_msgs/msg/Int8', False)
assert ret == serialize(msg, 'std_msgs/msg/Int8', False)
assert ret == b'\x00\x00\x00\x00\x07'
ret = serialize_cdr(msg, 'std_msgs/msg/Int16', True)
assert ret == serialize(msg, 'std_msgs/msg/Int16', True)
assert ret == b'\x00\x01\x00\x00\x07\x00'
ret = serialize_cdr(msg, 'std_msgs/msg/Int16', False)
assert ret == serialize(msg, 'std_msgs/msg/Int16', False)
assert ret == b'\x00\x00\x00\x00\x00\x07'
@pytest.mark.usefixtures('_comparable')
def test_serializer_errors() -> None:
"""Test seralizer with broken messages."""
class Foo: # pylint: disable=too-few-public-methods
"""Dummy class."""
coef: numpy.ndarray[Any, numpy.dtype[numpy.int_]] = numpy.array([1, 2, 3, 4])
msg = Foo()
ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
assert ret == serialize(msg, 'shape_msgs/msg/Plane', True)
msg.coef = numpy.array([1, 2, 3, 4, 4])
with pytest.raises(SerdeError, match='array length'):
serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
@pytest.mark.usefixtures('_comparable')
def test_custom_type() -> None:
"""Test custom type."""
cname = 'test_msgs/msg/custom'
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
register_types(dict(get_types_from_msg(STATIC_64_16, 'test_msgs/msg/static_64_16')))
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
register_types(dict(get_types_from_msg(DYNAMIC_64_64, 'test_msgs/msg/dynamic_64_64')))
register_types(dict(get_types_from_msg(DYNAMIC_64_B_64, 'test_msgs/msg/dynamic_64_b_64')))
register_types(dict(get_types_from_msg(DYNAMIC_64_S, 'test_msgs/msg/dynamic_64_s')))
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
register_types(dict(get_types_from_msg(CUSTOM, cname)))
static_64_64 = get_msgdef('test_msgs/msg/static_64_64', types).cls
static_64_16 = get_msgdef('test_msgs/msg/static_64_16', types).cls
static_16_64 = get_msgdef('test_msgs/msg/static_16_64', types).cls
dynamic_64_64 = get_msgdef('test_msgs/msg/dynamic_64_64', types).cls
dynamic_64_b_64 = get_msgdef('test_msgs/msg/dynamic_64_b_64', types).cls
dynamic_64_s = get_msgdef('test_msgs/msg/dynamic_64_s', types).cls
dynamic_s_64 = get_msgdef('test_msgs/msg/dynamic_s_64', types).cls
custom = get_msgdef('test_msgs/msg/custom', types).cls
msg = custom(
'str',
1.5,
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
static_64_16(64, 16),
static_16_64(16, 64),
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
dynamic_64_b_64(64, True, 1.25),
dynamic_64_s(64, 's'),
dynamic_s_64('s', 64),
# arrays
['str_1', ''],
numpy.array([1.5, 0.75], dtype=numpy.float32),
[
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
],
[static_64_16(64, 16), static_64_16(64, 16)],
[static_16_64(16, 64), static_16_64(16, 64)],
[
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
],
[
dynamic_64_b_64(64, True, 1.25),
dynamic_64_b_64(64, True, 1.25),
],
[dynamic_64_s(64, 's'), dynamic_64_s(64, 's')],
[dynamic_s_64('s', 64), dynamic_s_64('s', 64)],
# sequences
['str_1', ''],
numpy.array([1.5, 0.75], dtype=numpy.float32),
[
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
],
[static_64_16(64, 16), static_64_16(64, 16)],
[static_16_64(16, 64), static_16_64(16, 64)],
[
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
],
[
dynamic_64_b_64(64, True, 1.25),
dynamic_64_b_64(64, True, 1.25),
],
[dynamic_64_s(64, 's'), dynamic_64_s(64, 's')],
[dynamic_s_64('s', 64), dynamic_s_64('s', 64)],
)
res = deserialize_cdr(serialize_cdr(msg, cname), cname)
assert res == deserialize(serialize(msg, cname), cname)
assert res == msg
res = deserialize_ros1(serialize_ros1(msg, cname), cname)
assert res == msg
def test_ros1_to_cdr() -> None:
"""Test ROS1 to CDR conversion."""
msgtype = 'test_msgs/msg/static_16_64'
register_types(dict(get_types_from_msg(STATIC_16_64, msgtype)))
msg_ros = (b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02'
)
assert ros1_to_cdr(msg_ros, msgtype) == msg_cdr
assert serialize_cdr(deserialize_ros1(msg_ros, msgtype), msgtype) == msg_cdr
msgtype = 'test_msgs/msg/dynamic_s_64'
register_types(dict(get_types_from_msg(DYNAMIC_S_64, msgtype)))
msg_ros = (b'\x01\x00\x00\x00X'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x02\x00\x00\x00X\x00'
b'\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02'
)
assert ros1_to_cdr(msg_ros, msgtype) == msg_cdr
assert serialize_cdr(deserialize_ros1(msg_ros, msgtype), msgtype) == msg_cdr
def test_cdr_to_ros1() -> None:
"""Test CDR to ROS1 conversion."""
msgtype = 'test_msgs/msg/static_16_64'
register_types(dict(get_types_from_msg(STATIC_16_64, msgtype)))
msg_ros = (b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02'
)
assert cdr_to_ros1(msg_cdr, msgtype) == msg_ros
assert serialize_ros1(deserialize_cdr(msg_cdr, msgtype), msgtype) == msg_ros
msgtype = 'test_msgs/msg/dynamic_s_64'
register_types(dict(get_types_from_msg(DYNAMIC_S_64, msgtype)))
msg_ros = (b'\x01\x00\x00\x00X'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x02\x00\x00\x00X\x00'
b'\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02'
)
assert cdr_to_ros1(msg_cdr, msgtype) == msg_ros
assert serialize_ros1(deserialize_cdr(msg_cdr, msgtype), msgtype) == msg_ros
header = Header(stamp=Time(42, 666), frame_id='frame')
msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header')
assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame'
@pytest.mark.usefixtures('_comparable')
def test_padding_empty_sequence() -> None:
"""Test empty sequences do not add item padding."""
register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b')))
su64_b = get_msgdef('test_msgs/msg/su64_b', types).cls
msg = su64_b(numpy.array([], dtype=numpy.uint64), True)
cdr = serialize_cdr(msg, msg.__msgtype__)
assert cdr[4:] == b'\x00\x00\x00\x00\x01'
ros1 = cdr_to_ros1(cdr, msg.__msgtype__)
assert ros1 == cdr[4:]
assert ros1_to_cdr(ros1, msg.__msgtype__) == cdr
assert deserialize_cdr(cdr, msg.__msgtype__) == msg
@pytest.mark.usefixtures('_comparable')
def test_align_after_empty_sequence() -> None:
"""Test alignment after empty sequences."""
register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64')))
register_types(dict(get_types_from_msg(SMSG_U64, 'test_msgs/msg/smsg_u64')))
su64_u64 = get_msgdef('test_msgs/msg/su64_u64', types).cls
smsg_u64 = get_msgdef('test_msgs/msg/smsg_u64', types).cls
msg1 = su64_u64(numpy.array([], dtype=numpy.uint64), 42)
msg2 = smsg_u64([], 42)
cdr = serialize_cdr(msg1, msg1.__msgtype__)
assert cdr[4:] == b'\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x00\x00\x00\x00\x00\x00\x00'
assert serialize_cdr(msg2, msg2.__msgtype__) == cdr
ros1 = cdr_to_ros1(cdr, msg1.__msgtype__)
assert ros1 == b'\x00\x00\x00\x00\x2a\x00\x00\x00\x00\x00\x00\x00'
assert cdr_to_ros1(cdr, msg2.__msgtype__) == ros1
assert ros1_to_cdr(ros1, msg1.__msgtype__) == cdr
assert deserialize_cdr(cdr, msg1.__msgtype__) == msg1
assert deserialize_cdr(cdr, msg2.__msgtype__) == msg2
+127
View File
@@ -0,0 +1,127 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Writer tests."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from rosbags.interfaces import Connection, ConnectionExtRosbag2
from rosbags.rosbag2 import Writer, WriterError
if TYPE_CHECKING:
from pathlib import Path
def test_writer(tmp_path: Path) -> None:
"""Test Writer."""
path = tmp_path / 'rosbag2'
with Writer(path) as bag:
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert (path / 'rosbag2.db3').exists()
size = (path / 'rosbag2.db3').stat().st_size
path = tmp_path / 'compress_none'
bag = Writer(path)
bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD)
with bag:
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert (path / 'compress_none.db3').exists()
assert size == (path / 'compress_none.db3').stat().st_size
path = tmp_path / 'compress_file'
bag = Writer(path)
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
with bag:
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert not (path / 'compress_file.db3').exists()
assert (path / 'compress_file.db3.zstd').exists()
path = tmp_path / 'compress_message'
bag = Writer(path)
bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD)
with bag:
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
bag.write(connection, 42, b'\x00')
bag.write(connection, 666, b'\x01' * 4096)
assert (path / 'metadata.yaml').exists()
assert (path / 'compress_message.db3').exists()
assert size > (path / 'compress_message.db3').stat().st_size
path = tmp_path / 'with_custom_data'
bag = Writer(path)
bag.open()
bag.set_custom_data('key1', 'value1')
with pytest.raises(WriterError, match='non-string value'):
bag.set_custom_data('key1', 42) # type: ignore
bag.close()
assert b'key1: value1' in (path / 'metadata.yaml').read_bytes()
def test_failure_cases(tmp_path: Path) -> None:
"""Test writer failure cases."""
with pytest.raises(WriterError, match='exists'):
Writer(tmp_path)
bag = Writer(tmp_path / 'race')
(tmp_path / 'race').mkdir()
with pytest.raises(WriterError, match='exists'):
bag.open()
bag = Writer(tmp_path / 'compress_after_open')
bag.open()
with pytest.raises(WriterError, match='already open'):
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
bag = Writer(tmp_path / 'topic')
with pytest.raises(WriterError, match='was not opened'):
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
bag = Writer(tmp_path / 'write')
with pytest.raises(WriterError, match='was not opened'):
bag.write(
Connection(
1,
'/tf',
'tf_msgs/msg/tf2',
'',
'',
0,
ConnectionExtRosbag2('cdr', ''),
None,
),
0,
b'',
)
bag = Writer(tmp_path / 'topic')
bag.open()
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
with pytest.raises(WriterError, match='only be added once'):
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
bag = Writer(tmp_path / 'notopic')
bag.open()
connection = Connection(
1,
'/tf',
'tf_msgs/msg/tf2',
'',
'',
0,
ConnectionExtRosbag2('cdr', ''),
None,
)
with pytest.raises(WriterError, match='unknown connection'):
bag.write(connection, 42, b'\x00')
+201
View File
@@ -0,0 +1,201 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Writer tests."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import Mock
import pytest
from rosbags.rosbag1 import Writer, WriterError
if TYPE_CHECKING:
from pathlib import Path
from typing import Optional
def test_no_overwrite(tmp_path: Path) -> None:
"""Test writer does not touch existing files."""
path = tmp_path / 'test.bag'
path.write_text('foo')
with pytest.raises(WriterError, match='exists'):
Writer(path).open()
path.unlink()
writer = Writer(path)
path.write_text('foo')
with pytest.raises(WriterError, match='exists'):
writer.open()
def test_empty(tmp_path: Path) -> None:
"""Test empty bag."""
path = tmp_path / 'test.bag'
with Writer(path):
pass
data = path.read_bytes()
assert len(data) == 13 + 4096
def test_add_connection(tmp_path: Path) -> None:
"""Test adding of connections."""
path = tmp_path / 'test.bag'
with pytest.raises(WriterError, match='not opened'):
Writer(path).add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
with Writer(path) as writer:
res = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
assert res.id == 0
data = path.read_bytes()
assert data.count(b'MESSAGE_DEFINITION') == 2
assert data.count(b'HASH') == 2
path.unlink()
with Writer(path) as writer:
res = writer.add_connection('/foo', 'std_msgs/msg/Int8')
assert res.id == 0
data = path.read_bytes()
assert data.count(b'int8 data') == 2
assert data.count(b'27ffa0c9c4b8fb8492252bcad9e5c57b') == 2
path.unlink()
with Writer(path) as writer:
writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
with pytest.raises(WriterError, match='can only be added once'):
writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
path.unlink()
with Writer(path) as writer:
res1 = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
res2 = writer.add_connection(
'/foo',
'test_msgs/msg/Test',
'MESSAGE_DEFINITION',
'HASH',
callerid='src',
)
res3 = writer.add_connection(
'/foo',
'test_msgs/msg/Test',
'MESSAGE_DEFINITION',
'HASH',
latching=1,
)
assert (res1.id, res2.id, res3.id) == (0, 1, 2)
def test_write_errors(tmp_path: Path) -> None:
"""Test write errors."""
path = tmp_path / 'test.bag'
with pytest.raises(WriterError, match='not opened'):
Writer(path).write(Mock(), 42, b'DEADBEEF')
with Writer(path) as writer, \
pytest.raises(WriterError, match='is no connection'):
writer.write(Mock(), 42, b'DEADBEEF')
path.unlink()
def test_write_simple(tmp_path: Path) -> None:
"""Test writing of messages."""
path = tmp_path / 'test.bag'
with Writer(path) as writer:
conn_foo = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
conn_latching = writer.add_connection(
'/foo',
'test_msgs/msg/Test',
'MESSAGE_DEFINITION',
'HASH',
latching=1,
)
conn_bar = writer.add_connection(
'/bar',
'test_msgs/msg/Bar',
'OTHER_DEFINITION',
'HASH',
callerid='src',
)
writer.add_connection('/baz', 'test_msgs/msg/Baz', 'NEVER_WRITTEN', 'HASH')
writer.write(conn_foo, 42, b'DEADBEEF')
writer.write(conn_latching, 42, b'DEADBEEF')
writer.write(conn_bar, 43, b'SECRET')
writer.write(conn_bar, 43, b'SUBSEQUENT')
res = path.read_bytes()
assert res.count(b'op=\x05') == 1
assert res.count(b'op=\x06') == 1
assert res.count(b'MESSAGE_DEFINITION') == 4
assert res.count(b'latching=1') == 2
assert res.count(b'OTHER_DEFINITION') == 2
assert res.count(b'callerid=src') == 2
assert res.count(b'NEVER_WRITTEN') == 2
assert res.count(b'DEADBEEF') == 2
assert res.count(b'SECRET') == 1
assert res.count(b'SUBSEQUENT') == 1
path.unlink()
with Writer(path) as writer:
writer.chunk_threshold = 256
conn_foo = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
conn_latching = writer.add_connection(
'/foo',
'test_msgs/msg/Test',
'MESSAGE_DEFINITION',
'HASH',
latching=1,
)
conn_bar = writer.add_connection(
'/bar',
'test_msgs/msg/Bar',
'OTHER_DEFINITION',
'HASH',
callerid='src',
)
writer.add_connection('/baz', 'test_msgs/msg/Baz', 'NEVER_WRITTEN', 'HASH')
writer.write(conn_foo, 42, b'DEADBEEF')
writer.write(conn_latching, 42, b'DEADBEEF')
writer.write(conn_bar, 43, b'SECRET')
writer.write(conn_bar, 43, b'SUBSEQUENT')
res = path.read_bytes()
assert res.count(b'op=\x05') == 2
assert res.count(b'op=\x06') == 2
assert res.count(b'MESSAGE_DEFINITION') == 4
assert res.count(b'latching=1') == 2
assert res.count(b'OTHER_DEFINITION') == 2
assert res.count(b'callerid=src') == 2
assert res.count(b'NEVER_WRITTEN') == 2
assert res.count(b'DEADBEEF') == 2
assert res.count(b'SECRET') == 1
assert res.count(b'SUBSEQUENT') == 1
path.unlink()
def test_compression_errors(tmp_path: Path) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
with Writer(path) as writer, \
pytest.raises(WriterError, match='already open'):
writer.set_compression(writer.CompressionFormat.BZ2)
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
writer = Writer(path)
if fmt:
writer.set_compression(fmt)
with writer:
conn = writer.add_connection('/foo', 'std_msgs/msg/Int8')
writer.write(conn, 42, b'\x42')
data = path.read_bytes()
assert data.count(f'compression={fmt.name.lower() if fmt else "none"}'.encode()) == 1