Add 'rosbags/' from commit 'c80625df279c154c6ec069cbac30faa319755e47'
git-subtree-dir: rosbags git-subtree-mainline:48df1fbdf4git-subtree-split:c80625df27
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag tests."""
|
||||
@@ -0,0 +1,445 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reference CDR message serializer and deserializer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from struct import Struct, pack_into, unpack_from
|
||||
from typing import TYPE_CHECKING, Dict, List, Union, cast
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from rosbags.serde.messages import SerdeError, get_msgdef
|
||||
from rosbags.serde.typing import Msgdef
|
||||
from rosbags.serde.utils import SIZEMAP, Valtype
|
||||
from rosbags.typesys import types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Tuple
|
||||
|
||||
from rosbags.serde.typing import Descriptor
|
||||
|
||||
Array = Union[List[Msgdef], List[str], numpy.ndarray]
|
||||
BasetypeMap = Dict[str, Struct]
|
||||
BASETYPEMAP_LE: BasetypeMap = {
|
||||
'bool': Struct('?'),
|
||||
'int8': Struct('b'),
|
||||
'int16': Struct('<h'),
|
||||
'int32': Struct('<i'),
|
||||
'int64': Struct('<q'),
|
||||
'uint8': Struct('B'),
|
||||
'uint16': Struct('<H'),
|
||||
'uint32': Struct('<I'),
|
||||
'uint64': Struct('<Q'),
|
||||
'float32': Struct('<f'),
|
||||
'float64': Struct('<d'),
|
||||
}
|
||||
|
||||
BASETYPEMAP_BE: BasetypeMap = {
|
||||
'bool': Struct('?'),
|
||||
'int8': Struct('b'),
|
||||
'int16': Struct('>h'),
|
||||
'int32': Struct('>i'),
|
||||
'int64': Struct('>q'),
|
||||
'uint8': Struct('B'),
|
||||
'uint16': Struct('>H'),
|
||||
'uint32': Struct('>I'),
|
||||
'uint64': Struct('>Q'),
|
||||
'float32': Struct('>f'),
|
||||
'float64': Struct('>d'),
|
||||
}
|
||||
|
||||
|
||||
def deserialize_number(rawdata: bytes, bmap: BasetypeMap, pos: int, basetype: str) \
|
||||
-> Tuple[Union[bool, float, int], int]:
|
||||
"""Deserialize a single boolean, float, or int.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Read position.
|
||||
basetype: Number type string.
|
||||
|
||||
Returns:
|
||||
Deserialized number and new read position.
|
||||
|
||||
"""
|
||||
dtype, size = bmap[basetype], SIZEMAP[basetype]
|
||||
pos = (pos + size - 1) & -size
|
||||
return dtype.unpack_from(rawdata, pos)[0], pos + size
|
||||
|
||||
|
||||
def deserialize_string(rawdata: bytes, bmap: BasetypeMap, pos: int) \
|
||||
-> Tuple[str, int]:
|
||||
"""Deserialize a string value.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Read position.
|
||||
|
||||
Returns:
|
||||
Deserialized string and new read position.
|
||||
|
||||
"""
|
||||
pos = (pos + 4 - 1) & -4
|
||||
length = bmap['int32'].unpack_from(rawdata, pos)[0]
|
||||
val = bytes(rawdata[pos + 4:pos + 4 + length - 1])
|
||||
return val.decode(), pos + 4 + length
|
||||
|
||||
|
||||
def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, desc: Descriptor) \
|
||||
-> Tuple[Array, int]:
|
||||
"""Deserialize an array of items of same type.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Read position.
|
||||
num: Number of elements.
|
||||
desc: Element type descriptor.
|
||||
|
||||
Returns:
|
||||
Deserialized array and new read position.
|
||||
|
||||
Raises:
|
||||
SerdeError: Unexpected element type.
|
||||
|
||||
"""
|
||||
if desc.valtype == Valtype.BASE:
|
||||
if desc.args == 'string':
|
||||
strs = []
|
||||
while (num := num - 1) >= 0:
|
||||
val, pos = deserialize_string(rawdata, bmap, pos)
|
||||
strs.append(val)
|
||||
return strs, pos
|
||||
|
||||
size = SIZEMAP[desc.args]
|
||||
pos = (pos + size - 1) & -size
|
||||
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
|
||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||
ndarr = ndarr.byteswap() # no inplace on readonly array
|
||||
return ndarr, pos + num * SIZEMAP[desc.args]
|
||||
|
||||
if desc.valtype == Valtype.MESSAGE:
|
||||
msgs = []
|
||||
while (num := num - 1) >= 0:
|
||||
msg, pos = deserialize_message(rawdata, bmap, pos, desc.args)
|
||||
msgs.append(msg)
|
||||
return msgs, pos
|
||||
|
||||
raise SerdeError(f'Nested arrays {desc!r} are not supported.')
|
||||
|
||||
|
||||
def deserialize_message(rawdata: bytes, bmap: BasetypeMap, pos: int, msgdef: Msgdef) \
|
||||
-> Tuple[Msgdef, int]:
|
||||
"""Deserialize a message.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Read position.
|
||||
msgdef: Message definition.
|
||||
|
||||
Returns:
|
||||
Deserialized message and new read position.
|
||||
|
||||
"""
|
||||
values: List[Any] = []
|
||||
|
||||
for _, desc in msgdef.fields:
|
||||
if desc.valtype == Valtype.MESSAGE:
|
||||
obj, pos = deserialize_message(rawdata, bmap, pos, desc.args)
|
||||
values.append(obj)
|
||||
|
||||
elif desc.valtype == Valtype.BASE:
|
||||
if desc.args == 'string':
|
||||
val, pos = deserialize_string(rawdata, bmap, pos)
|
||||
values.append(val)
|
||||
else:
|
||||
num, pos = deserialize_number(rawdata, bmap, pos, desc.args)
|
||||
values.append(num)
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
subdesc, length = desc.args
|
||||
arr, pos = deserialize_array(rawdata, bmap, pos, length, subdesc)
|
||||
values.append(arr)
|
||||
|
||||
elif desc.valtype == Valtype.SEQUENCE:
|
||||
size, pos = deserialize_number(rawdata, bmap, pos, 'int32')
|
||||
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args[0])
|
||||
values.append(arr)
|
||||
|
||||
return msgdef.cls(*values), pos
|
||||
|
||||
|
||||
def deserialize(rawdata: bytes, typename: str) -> Msgdef:
|
||||
"""Deserialize raw data into a message object.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
typename: Type to deserialize.
|
||||
|
||||
Returns:
|
||||
Deserialized message object.
|
||||
|
||||
"""
|
||||
_, little_endian = unpack_from('BB', rawdata, 0)
|
||||
|
||||
msgdef = get_msgdef(typename, types)
|
||||
obj, _ = deserialize_message(
|
||||
rawdata[4:],
|
||||
BASETYPEMAP_LE if little_endian else BASETYPEMAP_BE,
|
||||
0,
|
||||
msgdef,
|
||||
)
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def serialize_number(
|
||||
rawdata: memoryview,
|
||||
bmap: BasetypeMap,
|
||||
pos: int,
|
||||
basetype: str,
|
||||
val: Union[bool, float, int],
|
||||
) -> int:
|
||||
"""Serialize a single boolean, float, or int.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Write position.
|
||||
basetype: Number type string.
|
||||
val: Value to serialize.
|
||||
|
||||
Returns:
|
||||
Next write position.
|
||||
|
||||
"""
|
||||
dtype, size = bmap[basetype], SIZEMAP[basetype]
|
||||
pos = (pos + size - 1) & -size
|
||||
dtype.pack_into(rawdata, pos, val)
|
||||
return pos + size
|
||||
|
||||
|
||||
def serialize_string(rawdata: memoryview, bmap: BasetypeMap, pos: int, val: str) \
|
||||
-> int:
|
||||
"""Deserialize a string value.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Write position.
|
||||
val: Value to serialize.
|
||||
|
||||
Returns:
|
||||
Next write position.
|
||||
|
||||
"""
|
||||
bval = memoryview(val.encode())
|
||||
length = len(bval) + 1
|
||||
|
||||
pos = (pos + 4 - 1) & -4
|
||||
bmap['int32'].pack_into(rawdata, pos, length)
|
||||
rawdata[pos + 4:pos + 4 + length - 1] = bval
|
||||
return pos + 4 + length
|
||||
|
||||
|
||||
def serialize_array(
|
||||
rawdata: memoryview,
|
||||
bmap: BasetypeMap,
|
||||
pos: int,
|
||||
desc: Descriptor,
|
||||
val: Array,
|
||||
) -> int:
|
||||
"""Serialize an array of items of same type.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Write position.
|
||||
desc: Element type descriptor.
|
||||
val: Value to serialize.
|
||||
|
||||
Returns:
|
||||
Next write position.
|
||||
|
||||
Raises:
|
||||
SerdeError: Unexpected element type.
|
||||
|
||||
"""
|
||||
if desc.valtype == Valtype.BASE:
|
||||
if desc.args == 'string':
|
||||
for item in val:
|
||||
pos = serialize_string(rawdata, bmap, pos, cast('str', item))
|
||||
return pos
|
||||
|
||||
size = SIZEMAP[desc.args]
|
||||
pos = (pos + size - 1) & -size
|
||||
size *= len(val)
|
||||
val = cast('NDArray[numpy.int_]', val)
|
||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||
val = val.byteswap() # no inplace on readonly array
|
||||
rawdata[pos:pos + size] = memoryview(val.tobytes())
|
||||
return pos + size
|
||||
|
||||
if desc.valtype == Valtype.MESSAGE:
|
||||
for item in val:
|
||||
pos = serialize_message(rawdata, bmap, pos, item, desc.args)
|
||||
return pos
|
||||
|
||||
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
|
||||
|
||||
|
||||
def serialize_message(
|
||||
rawdata: memoryview,
|
||||
bmap: BasetypeMap,
|
||||
pos: int,
|
||||
message: object,
|
||||
msgdef: Msgdef,
|
||||
) -> int:
|
||||
"""Serialize a message.
|
||||
|
||||
Args:
|
||||
rawdata: Serialized data.
|
||||
bmap: Basetype metadata.
|
||||
pos: Write position.
|
||||
message: Message object.
|
||||
msgdef: Message definition.
|
||||
|
||||
Returns:
|
||||
Next write position.
|
||||
|
||||
"""
|
||||
for fieldname, desc in msgdef.fields:
|
||||
val = getattr(message, fieldname)
|
||||
if desc.valtype == Valtype.MESSAGE:
|
||||
pos = serialize_message(rawdata, bmap, pos, val, desc.args)
|
||||
|
||||
elif desc.valtype == Valtype.BASE:
|
||||
if desc.args == 'string':
|
||||
pos = serialize_string(rawdata, bmap, pos, val)
|
||||
else:
|
||||
pos = serialize_number(rawdata, bmap, pos, desc.args, val)
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
|
||||
|
||||
elif desc.valtype == Valtype.SEQUENCE:
|
||||
size = len(val)
|
||||
pos = serialize_number(rawdata, bmap, pos, 'int32', size)
|
||||
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def get_array_size(desc: Descriptor, val: Array, size: int) -> int:
|
||||
"""Calculate size of an array.
|
||||
|
||||
Args:
|
||||
desc: Element type descriptor.
|
||||
val: Array to calculate size of.
|
||||
size: Current size of message.
|
||||
|
||||
Returns:
|
||||
Size of val in bytes.
|
||||
|
||||
Raises:
|
||||
SerdeError: Unexpected element type.
|
||||
|
||||
"""
|
||||
if desc.valtype == Valtype.BASE:
|
||||
if desc.args == 'string':
|
||||
for item in val:
|
||||
size = (size + 4 - 1) & -4
|
||||
size += 4 + len(item) + 1
|
||||
return size
|
||||
|
||||
isize = SIZEMAP[desc.args]
|
||||
size = (size + isize - 1) & -isize
|
||||
return size + isize * len(val)
|
||||
|
||||
if desc.valtype == Valtype.MESSAGE:
|
||||
for item in val:
|
||||
size = get_size(item, desc.args, size)
|
||||
return size
|
||||
|
||||
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
|
||||
|
||||
|
||||
def get_size(message: object, msgdef: Msgdef, size: int = 0) -> int:
|
||||
"""Calculate size of serialzied message.
|
||||
|
||||
Args:
|
||||
message: Message object.
|
||||
msgdef: Message definition.
|
||||
size: Current size of message.
|
||||
|
||||
Returns:
|
||||
Size of message in bytes.
|
||||
|
||||
Raises:
|
||||
SerdeError: Unexpected array length in message.
|
||||
|
||||
"""
|
||||
for fieldname, desc in msgdef.fields:
|
||||
val = getattr(message, fieldname)
|
||||
if desc.valtype == Valtype.MESSAGE:
|
||||
size = get_size(val, desc.args, size)
|
||||
|
||||
elif desc.valtype == Valtype.BASE:
|
||||
if desc.args == 'string':
|
||||
size = (size + 4 - 1) & -4
|
||||
size += 4 + len(val.encode()) + 1
|
||||
else:
|
||||
isize = SIZEMAP[desc.args]
|
||||
size = (size + isize - 1) & -isize
|
||||
size += isize
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
subdesc, length = desc.args
|
||||
if len(val) != length:
|
||||
raise SerdeError(f'Unexpected array length: {len(val)} != {length}.')
|
||||
size = get_array_size(subdesc, val, size)
|
||||
|
||||
elif desc.valtype == Valtype.SEQUENCE:
|
||||
size = (size + 4 - 1) & -4
|
||||
size += 4
|
||||
size = get_array_size(desc.args[0], val, size)
|
||||
|
||||
return size
|
||||
|
||||
|
||||
def serialize(
|
||||
message: object,
|
||||
typename: str,
|
||||
little_endian: bool = sys.byteorder == 'little',
|
||||
) -> memoryview:
|
||||
"""Serialize message object to bytes.
|
||||
|
||||
Args:
|
||||
message: Message object.
|
||||
typename: Type to serialize.
|
||||
little_endian: Should use little endianess.
|
||||
|
||||
Returns:
|
||||
Serialized bytes.
|
||||
|
||||
"""
|
||||
msgdef = get_msgdef(typename, types)
|
||||
size = 4 + get_size(message, msgdef)
|
||||
rawdata = memoryview(bytearray(size))
|
||||
|
||||
pack_into('BB', rawdata, 0, 0, little_endian)
|
||||
pos = serialize_message(
|
||||
rawdata[4:],
|
||||
BASETYPEMAP_LE if little_endian else BASETYPEMAP_BE,
|
||||
0,
|
||||
message,
|
||||
msgdef,
|
||||
)
|
||||
assert pos + 4 == size
|
||||
return rawdata.toreadonly()
|
||||
@@ -0,0 +1,444 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag1to2 converter tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.convert import ConverterError, convert
|
||||
from rosbags.convert.__main__ import main
|
||||
from rosbags.convert.converter import LATCH
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag1, ConnectionExtRosbag2
|
||||
from rosbags.rosbag1 import ReaderError
|
||||
from rosbags.rosbag2 import WriterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def test_cliwrapper(tmp_path: Path) -> None:
|
||||
"""Test cli wrapper."""
|
||||
(tmp_path / 'subdir').mkdir()
|
||||
(tmp_path / 'ros1.bag').write_text('')
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt']), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
assert not cvrt.called
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt', str(tmp_path / 'no.bag')]), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
assert not cvrt.called
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt', str(tmp_path / 'ros1.bag')]):
|
||||
main()
|
||||
cvrt.assert_called_with(
|
||||
src=tmp_path / 'ros1.bag',
|
||||
dst=None,
|
||||
exclude_topics=[],
|
||||
include_topics=[],
|
||||
)
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt',
|
||||
str(tmp_path / 'ros1.bag'),
|
||||
'--dst',
|
||||
str(tmp_path / 'subdir')]), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
assert not cvrt.called
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt',
|
||||
str(tmp_path / 'ros1.bag'),
|
||||
'--dst',
|
||||
str(tmp_path / 'ros2.bag')]), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
assert not cvrt.called
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt',
|
||||
str(tmp_path / 'ros1.bag'),
|
||||
'--dst',
|
||||
str(tmp_path / 'target')]):
|
||||
main()
|
||||
cvrt.assert_called_with(
|
||||
src=tmp_path / 'ros1.bag',
|
||||
dst=tmp_path / 'target',
|
||||
exclude_topics=[],
|
||||
include_topics=[],
|
||||
)
|
||||
|
||||
with patch.object(sys, 'argv', ['cvt', str(tmp_path / 'ros1.bag')]), \
|
||||
patch('builtins.print') as mock_print, \
|
||||
patch('rosbags.convert.__main__.convert', side_effect=ConverterError('exc')), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
mock_print.assert_called_with('ERROR: exc')
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt', str(tmp_path / 'subdir')]):
|
||||
main()
|
||||
cvrt.assert_called_with(
|
||||
src=tmp_path / 'subdir',
|
||||
dst=None,
|
||||
exclude_topics=[],
|
||||
include_topics=[],
|
||||
)
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt',
|
||||
str(tmp_path / 'subdir'),
|
||||
'--dst',
|
||||
str(tmp_path / 'ros1.bag')]), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
assert not cvrt.called
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt',
|
||||
str(tmp_path / 'subdir'),
|
||||
'--dst',
|
||||
str(tmp_path / 'target.bag')]):
|
||||
main()
|
||||
cvrt.assert_called_with(
|
||||
src=tmp_path / 'subdir',
|
||||
dst=tmp_path / 'target.bag',
|
||||
exclude_topics=[],
|
||||
include_topics=[],
|
||||
)
|
||||
|
||||
with patch.object(sys, 'argv', ['cvt', str(tmp_path / 'subdir')]), \
|
||||
patch('builtins.print') as mock_print, \
|
||||
patch('rosbags.convert.__main__.convert', side_effect=ConverterError('exc')), \
|
||||
pytest.raises(SystemExit):
|
||||
main()
|
||||
mock_print.assert_called_with('ERROR: exc')
|
||||
|
||||
with patch('rosbags.convert.__main__.convert') as cvrt, \
|
||||
patch.object(sys, 'argv', ['cvt',
|
||||
str(tmp_path / 'ros1.bag'),
|
||||
'--exclude-topic',
|
||||
'/foo']):
|
||||
main()
|
||||
cvrt.assert_called_with(
|
||||
src=tmp_path / 'ros1.bag',
|
||||
dst=None,
|
||||
exclude_topics=['/foo'],
|
||||
include_topics=[],
|
||||
)
|
||||
|
||||
|
||||
def test_convert_1to2(tmp_path: Path) -> None:
|
||||
"""Test conversion from rosbag1 to rosbag2."""
|
||||
(tmp_path / 'subdir').mkdir()
|
||||
(tmp_path / 'foo.bag').write_text('')
|
||||
|
||||
with pytest.raises(ConverterError, match='exists already'):
|
||||
convert(Path('foo.bag'), tmp_path / 'subdir')
|
||||
|
||||
with patch('rosbags.convert.converter.Reader1') as reader, \
|
||||
patch('rosbags.convert.converter.Writer2') as writer, \
|
||||
patch('rosbags.convert.converter.get_types_from_msg', return_value={'typ': 'def'}), \
|
||||
patch('rosbags.convert.converter.register_types') as register_types, \
|
||||
patch('rosbags.convert.converter.ros1_to_cdr') as ros1_to_cdr:
|
||||
|
||||
readerinst = reader.return_value.__enter__.return_value
|
||||
writerinst = writer.return_value.__enter__.return_value
|
||||
|
||||
connections = [
|
||||
Connection(1, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False), None),
|
||||
Connection(2, '/topic', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, True), None),
|
||||
Connection(3, '/other', 'typ', 'def', '', -1, ConnectionExtRosbag1(None, False), None),
|
||||
Connection(
|
||||
4,
|
||||
'/other',
|
||||
'typ',
|
||||
'def',
|
||||
'',
|
||||
-1,
|
||||
ConnectionExtRosbag1('caller', False),
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
wconnections = [
|
||||
Connection(1, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None),
|
||||
Connection(2, '/topic', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', LATCH), None),
|
||||
Connection(3, '/other', 'typ', '', '', -1, ConnectionExtRosbag2('cdr', ''), None),
|
||||
]
|
||||
|
||||
readerinst.connections = [
|
||||
connections[0],
|
||||
connections[1],
|
||||
connections[2],
|
||||
connections[3],
|
||||
]
|
||||
|
||||
readerinst.messages.return_value = [
|
||||
(connections[0], 42, b'\x42'),
|
||||
(connections[1], 43, b'\x43'),
|
||||
(connections[2], 44, b'\x44'),
|
||||
(connections[3], 45, b'\x45'),
|
||||
]
|
||||
|
||||
writerinst.connections = []
|
||||
|
||||
def add_connection(*_: Any) -> Connection: # noqa: ANN401
|
||||
"""Mock for Writer.add_connection."""
|
||||
writerinst.connections = [
|
||||
conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
|
||||
]
|
||||
return wconnections[len(writerinst.connections) - 1]
|
||||
|
||||
writerinst.add_connection.side_effect = add_connection
|
||||
|
||||
ros1_to_cdr.return_value = b'666'
|
||||
|
||||
convert(Path('foo.bag'), None)
|
||||
|
||||
reader.assert_called_with(Path('foo.bag'))
|
||||
readerinst.messages.assert_called_with(connections=readerinst.connections)
|
||||
|
||||
writer.assert_called_with(Path('foo'))
|
||||
writerinst.add_connection.assert_has_calls(
|
||||
[
|
||||
call('/topic', 'typ', 'cdr', ''),
|
||||
call('/topic', 'typ', 'cdr', LATCH),
|
||||
call('/other', 'typ', 'cdr', ''),
|
||||
],
|
||||
)
|
||||
writerinst.write.assert_has_calls(
|
||||
[
|
||||
call(wconnections[0], 42, b'666'),
|
||||
call(wconnections[1], 43, b'666'),
|
||||
call(wconnections[2], 44, b'666'),
|
||||
call(wconnections[2], 45, b'666'),
|
||||
],
|
||||
)
|
||||
|
||||
register_types.assert_called_with({'typ': 'def'})
|
||||
ros1_to_cdr.assert_has_calls(
|
||||
[
|
||||
call(b'\x42', 'typ'),
|
||||
call(b'\x43', 'typ'),
|
||||
call(b'\x44', 'typ'),
|
||||
call(b'\x45', 'typ'),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ConverterError, match='No connections left for conversion'):
|
||||
convert(Path('foo.bag'), None, ['/topic', '/other'])
|
||||
|
||||
writerinst.connections.clear()
|
||||
ros1_to_cdr.side_effect = KeyError('exc')
|
||||
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
|
||||
convert(Path('foo.bag'), None)
|
||||
|
||||
writer.side_effect = WriterError('exc')
|
||||
with pytest.raises(ConverterError, match='Writing destination bag: exc'):
|
||||
convert(Path('foo.bag'), None)
|
||||
|
||||
reader.side_effect = ReaderError('exc')
|
||||
with pytest.raises(ConverterError, match='Reading source bag: exc'):
|
||||
convert(Path('foo.bag'), None)
|
||||
|
||||
|
||||
def test_convert_2to1(tmp_path: Path) -> None:
|
||||
"""Test conversion from rosbag2 to rosbag1."""
|
||||
(tmp_path / 'subdir').mkdir()
|
||||
(tmp_path / 'foo.bag').write_text('')
|
||||
|
||||
with pytest.raises(ConverterError, match='exists already'):
|
||||
convert(Path('subdir'), tmp_path / 'foo.bag')
|
||||
|
||||
with patch('rosbags.convert.converter.Reader2') as reader, \
|
||||
patch('rosbags.convert.converter.Writer1') as writer, \
|
||||
patch('rosbags.convert.converter.cdr_to_ros1') as cdr_to_ros1:
|
||||
|
||||
readerinst = reader.return_value.__enter__.return_value
|
||||
writerinst = writer.return_value.__enter__.return_value
|
||||
|
||||
connections = [
|
||||
Connection(
|
||||
1,
|
||||
'/topic',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'',
|
||||
-1,
|
||||
ConnectionExtRosbag2('', ''),
|
||||
None,
|
||||
),
|
||||
Connection(
|
||||
2,
|
||||
'/topic',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'',
|
||||
-1,
|
||||
ConnectionExtRosbag2('', LATCH),
|
||||
None,
|
||||
),
|
||||
Connection(
|
||||
3,
|
||||
'/other',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'',
|
||||
-1,
|
||||
ConnectionExtRosbag2('', ''),
|
||||
None,
|
||||
),
|
||||
Connection(
|
||||
4,
|
||||
'/other',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'',
|
||||
-1,
|
||||
ConnectionExtRosbag2('', '0'),
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
wconnections = [
|
||||
Connection(
|
||||
1,
|
||||
'/topic',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'8b94c1b53db61fb6aed406028ad6332a',
|
||||
-1,
|
||||
ConnectionExtRosbag1(None, False),
|
||||
None,
|
||||
),
|
||||
Connection(
|
||||
2,
|
||||
'/topic',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'8b94c1b53db61fb6aed406028ad6332a',
|
||||
-1,
|
||||
ConnectionExtRosbag1(None, True),
|
||||
None,
|
||||
),
|
||||
Connection(
|
||||
3,
|
||||
'/other',
|
||||
'std_msgs/msg/Bool',
|
||||
'',
|
||||
'8b94c1b53db61fb6aed406028ad6332a',
|
||||
-1,
|
||||
ConnectionExtRosbag1(None, False),
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
readerinst.connections = [
|
||||
connections[0],
|
||||
connections[1],
|
||||
connections[2],
|
||||
connections[3],
|
||||
]
|
||||
|
||||
readerinst.messages.return_value = [
|
||||
(connections[0], 42, b'\x42'),
|
||||
(connections[1], 43, b'\x43'),
|
||||
(connections[2], 44, b'\x44'),
|
||||
(connections[3], 45, b'\x45'),
|
||||
]
|
||||
|
||||
writerinst.connections = []
|
||||
|
||||
def add_connection(*_: Any) -> Connection: # noqa: ANN401
|
||||
"""Mock for Writer.add_connection."""
|
||||
writerinst.connections = [
|
||||
conn for _, conn in zip(range(len(writerinst.connections) + 1), wconnections)
|
||||
]
|
||||
return wconnections[len(writerinst.connections) - 1]
|
||||
|
||||
writerinst.add_connection.side_effect = add_connection
|
||||
|
||||
cdr_to_ros1.return_value = b'666'
|
||||
|
||||
convert(Path('foo'), None)
|
||||
|
||||
reader.assert_called_with(Path('foo'))
|
||||
reader.return_value.__enter__.return_value.messages.assert_called_with(
|
||||
connections=readerinst.connections,
|
||||
)
|
||||
|
||||
writer.assert_called_with(Path('foo.bag'))
|
||||
writer.return_value.__enter__.return_value.add_connection.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
'/topic',
|
||||
'std_msgs/msg/Bool',
|
||||
'bool data\n',
|
||||
'8b94c1b53db61fb6aed406028ad6332a',
|
||||
None,
|
||||
0,
|
||||
),
|
||||
call(
|
||||
'/topic',
|
||||
'std_msgs/msg/Bool',
|
||||
'bool data\n',
|
||||
'8b94c1b53db61fb6aed406028ad6332a',
|
||||
None,
|
||||
1,
|
||||
),
|
||||
call(
|
||||
'/other',
|
||||
'std_msgs/msg/Bool',
|
||||
'bool data\n',
|
||||
'8b94c1b53db61fb6aed406028ad6332a',
|
||||
None,
|
||||
0,
|
||||
),
|
||||
],
|
||||
)
|
||||
writer.return_value.__enter__.return_value.write.assert_has_calls(
|
||||
[
|
||||
call(wconnections[0], 42, b'666'),
|
||||
call(wconnections[1], 43, b'666'),
|
||||
call(wconnections[2], 44, b'666'),
|
||||
call(wconnections[2], 45, b'666'),
|
||||
],
|
||||
)
|
||||
|
||||
cdr_to_ros1.assert_has_calls(
|
||||
[
|
||||
call(b'\x42', 'std_msgs/msg/Bool'),
|
||||
call(b'\x43', 'std_msgs/msg/Bool'),
|
||||
call(b'\x44', 'std_msgs/msg/Bool'),
|
||||
call(b'\x45', 'std_msgs/msg/Bool'),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ConverterError, match='No connections left for conversion'):
|
||||
convert(Path('foobag'), None, ['/topic', '/other'])
|
||||
|
||||
writerinst.connections.clear()
|
||||
cdr_to_ros1.side_effect = KeyError('exc')
|
||||
with pytest.raises(ConverterError, match='Converting rosbag: .*exc'):
|
||||
convert(Path('foo'), None)
|
||||
|
||||
writer.side_effect = WriterError('exc')
|
||||
with pytest.raises(ConverterError, match='Writing destination bag: exc'):
|
||||
convert(Path('foo'), None)
|
||||
|
||||
reader.side_effect = ReaderError('exc')
|
||||
with pytest.raises(ConverterError, match='Reading source bag: exc'):
|
||||
convert(Path('foo'), None)
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reader tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.highlevel import AnyReader, AnyReaderError
|
||||
from rosbags.interfaces import Connection
|
||||
from rosbags.rosbag1 import Writer as Writer1
|
||||
from rosbags.rosbag2 import Writer as Writer2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
HEADER = b'\x00\x01\x00\x00'
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bags1(tmp_path: Path) -> list[Path]:
|
||||
"""Test data fixture."""
|
||||
paths = [
|
||||
tmp_path / 'ros1_1.bag',
|
||||
tmp_path / 'ros1_2.bag',
|
||||
tmp_path / 'ros1_3.bag',
|
||||
tmp_path / 'bad.bag',
|
||||
]
|
||||
with (Writer1(paths[0])) as writer:
|
||||
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
|
||||
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
|
||||
writer.write(topic1, 1, b'\x01')
|
||||
writer.write(topic2, 2, b'\x02\x00')
|
||||
writer.write(topic1, 9, b'\x09')
|
||||
with (Writer1(paths[1])) as writer:
|
||||
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
|
||||
writer.write(topic1, 5, b'\x05')
|
||||
with (Writer1(paths[2])) as writer:
|
||||
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
|
||||
writer.write(topic2, 15, b'\x15\x00')
|
||||
|
||||
paths[3].touch()
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bags2(tmp_path: Path) -> list[Path]:
|
||||
"""Test data fixture."""
|
||||
paths = [
|
||||
tmp_path / 'ros2_1',
|
||||
tmp_path / 'bad',
|
||||
]
|
||||
with (Writer2(paths[0])) as writer:
|
||||
topic1 = writer.add_connection('/topic1', 'std_msgs/msg/Int8')
|
||||
topic2 = writer.add_connection('/topic2', 'std_msgs/msg/Int16')
|
||||
writer.write(topic1, 1, HEADER + b'\x01')
|
||||
writer.write(topic2, 2, HEADER + b'\x02\x00')
|
||||
writer.write(topic1, 9, HEADER + b'\x09')
|
||||
writer.write(topic1, 5, HEADER + b'\x05')
|
||||
writer.write(topic2, 15, HEADER + b'\x15\x00')
|
||||
|
||||
paths[1].mkdir()
|
||||
(paths[1] / 'metadata.yaml').write_text(':')
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
def test_anyreader1(bags1: Sequence[Path]) -> None: # pylint: disable=redefined-outer-name
|
||||
"""Test AnyReader on rosbag1."""
|
||||
# pylint: disable=too-many-statements
|
||||
with pytest.raises(AnyReaderError, match='at least one'):
|
||||
AnyReader([])
|
||||
|
||||
with pytest.raises(AnyReaderError, match='missing'):
|
||||
AnyReader([bags1[0] / 'badname'])
|
||||
|
||||
reader = AnyReader(bags1)
|
||||
with pytest.raises(AssertionError):
|
||||
assert reader.topics
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
next(reader.messages())
|
||||
|
||||
reader = AnyReader(bags1)
|
||||
with pytest.raises(AnyReaderError, match='seems to be empty'):
|
||||
reader.open()
|
||||
assert all(not x.bio for x in reader.readers) # type: ignore[union-attr]
|
||||
|
||||
with AnyReader(bags1[:3]) as reader:
|
||||
assert reader.duration == 15
|
||||
assert reader.start_time == 1
|
||||
assert reader.end_time == 16
|
||||
assert reader.message_count == 5
|
||||
assert list(reader.topics.keys()) == ['/topic1', '/topic2']
|
||||
assert len(reader.topics['/topic1'].connections) == 2
|
||||
assert reader.topics['/topic1'].msgcount == 3
|
||||
assert len(reader.topics['/topic2'].connections) == 2
|
||||
assert reader.topics['/topic2'].msgcount == 2
|
||||
|
||||
gen = reader.messages()
|
||||
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
assert nxt[1:] == (1, b'\x01')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 1 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic2'
|
||||
assert nxt[1:] == (2, b'\x02\x00')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 2 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
assert nxt[1:] == (5, b'\x05')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 5 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
assert nxt[1:] == (9, b'\x09')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 9 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic2'
|
||||
assert nxt[1:] == (15, b'\x15\x00')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 21 # type: ignore
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(connections=reader.topics['/topic1'].connections)
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_anyreader2(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name
|
||||
"""Test AnyReader on rosbag2."""
|
||||
# pylint: disable=too-many-statements
|
||||
with pytest.raises(AnyReaderError, match='multiple rosbag2'):
|
||||
AnyReader(bags2)
|
||||
|
||||
with pytest.raises(AnyReaderError, match='YAML'):
|
||||
AnyReader([bags2[1]])
|
||||
|
||||
with AnyReader([bags2[0]]) as reader:
|
||||
assert reader.duration == 15
|
||||
assert reader.start_time == 1
|
||||
assert reader.end_time == 16
|
||||
assert reader.message_count == 5
|
||||
assert list(reader.topics.keys()) == ['/topic1', '/topic2']
|
||||
assert len(reader.topics['/topic1'].connections) == 1
|
||||
assert reader.topics['/topic1'].msgcount == 3
|
||||
assert len(reader.topics['/topic2'].connections) == 1
|
||||
assert reader.topics['/topic2'].msgcount == 2
|
||||
|
||||
gen = reader.messages()
|
||||
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
assert nxt[1:] == (1, HEADER + b'\x01')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 1 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic2'
|
||||
assert nxt[1:] == (2, HEADER + b'\x02\x00')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 2 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
assert nxt[1:] == (5, HEADER + b'\x05')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 5 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
assert nxt[1:] == (9, HEADER + b'\x09')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 9 # type: ignore
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic2'
|
||||
assert nxt[1:] == (15, HEADER + b'\x15\x00')
|
||||
msg = reader.deserialize(nxt[2], nxt[0].msgtype)
|
||||
assert msg.data == 21 # type: ignore
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(connections=reader.topics['/topic1'].connections)
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
nxt = next(gen)
|
||||
assert nxt[0].topic == '/topic1'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_anyreader2_autoregister(bags2: list[Path]) -> None: # pylint: disable=redefined-outer-name
|
||||
"""Test AnyReader on rosbag2."""
|
||||
|
||||
class MockReader:
|
||||
"""Mock reader."""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
|
||||
def __init__(self, paths: list[Path]):
|
||||
"""Initialize mock."""
|
||||
_ = paths
|
||||
self.metadata = {'storage_identifier': 'mcap'}
|
||||
self.connections = [
|
||||
Connection(
|
||||
1,
|
||||
'/foo',
|
||||
'test_msg/msg/Foo',
|
||||
'string foo',
|
||||
'msg',
|
||||
0,
|
||||
None, # type: ignore
|
||||
self,
|
||||
),
|
||||
Connection(
|
||||
2,
|
||||
'/bar',
|
||||
'test_msg/msg/Bar',
|
||||
'module test_msgs { module msg { struct Bar {string bar;}; }; };',
|
||||
'idl',
|
||||
0,
|
||||
None, # type: ignore
|
||||
self,
|
||||
),
|
||||
Connection(
|
||||
3,
|
||||
'/baz',
|
||||
'test_msg/msg/Baz',
|
||||
'',
|
||||
'',
|
||||
0,
|
||||
None, # type: ignore
|
||||
self,
|
||||
),
|
||||
]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Unused."""
|
||||
|
||||
with patch('rosbags.highlevel.anyreader.Reader2', MockReader), \
|
||||
patch('rosbags.highlevel.anyreader.register_types') as mock_register_types:
|
||||
AnyReader([bags2[0]]).open()
|
||||
mock_register_types.assert_called_once()
|
||||
assert mock_register_types.call_args[0][0] == {
|
||||
'test_msg/msg/Foo': ([], [('foo', (1, 'string'))]),
|
||||
'test_msgs/msg/Bar': ([], [('bar', (1, 'string'))]),
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Message definition parser tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.typesys import (
|
||||
TypesysError,
|
||||
generate_msgdef,
|
||||
get_types_from_idl,
|
||||
get_types_from_msg,
|
||||
register_types,
|
||||
)
|
||||
from rosbags.typesys.base import Nodetype
|
||||
from rosbags.typesys.types import FIELDDEFS
|
||||
|
||||
MSG = """
|
||||
# comment
|
||||
|
||||
bool b=true
|
||||
int32 global=42
|
||||
float32 f=1.33
|
||||
string str= foo bar\t
|
||||
|
||||
std_msgs/Header header
|
||||
std_msgs/msg/Bool bool
|
||||
test_msgs/Bar sibling
|
||||
float64 base
|
||||
float64[] seq1
|
||||
float64[] seq2
|
||||
float64[4] array
|
||||
"""
|
||||
|
||||
MSG_BOUNDS = """
|
||||
int32[] unbounded_integer_array
|
||||
int32[5] five_integers_array
|
||||
int32[<=5] up_to_five_integers_array
|
||||
|
||||
string string_of_unbounded_size
|
||||
string<=10 up_to_ten_characters_string
|
||||
|
||||
string[<=5] up_to_five_unbounded_strings
|
||||
string<=10[] unbounded_array_of_string_up_to_ten_characters_each
|
||||
string<=10[<=5] up_to_five_strings_up_to_ten_characters_each
|
||||
"""
|
||||
|
||||
MSG_DEFAULTS = """
|
||||
bool b false
|
||||
uint8 i 42
|
||||
uint8 o 0377
|
||||
uint8 h 0xff
|
||||
float32 y -314.15e-2
|
||||
string name1 "John"
|
||||
string name2 'Ringo'
|
||||
int32[] samples [-200, -100, 0, 100, 200]
|
||||
"""
|
||||
|
||||
MULTI_MSG = """
|
||||
std_msgs/Header header
|
||||
byte b
|
||||
char c
|
||||
Other[] o
|
||||
|
||||
================================================================================
|
||||
MSG: std_msgs/Header
|
||||
time time
|
||||
|
||||
================================================================================
|
||||
MSG: test_msgs/Other
|
||||
uint64[3] Header
|
||||
uint32 static = 42
|
||||
"""
|
||||
|
||||
CSTRING_CONFUSION_MSG = """
|
||||
std_msgs/Header header
|
||||
string s
|
||||
|
||||
================================================================================
|
||||
MSG: std_msgs/Header
|
||||
time time
|
||||
"""
|
||||
|
||||
RELSIBLING_MSG = """
|
||||
Header header
|
||||
Other other
|
||||
"""
|
||||
|
||||
IDL_LANG = """
|
||||
// assign different literals and expressions
|
||||
|
||||
#ifndef FOO
|
||||
#define FOO
|
||||
|
||||
#include <global>
|
||||
#include "local"
|
||||
|
||||
const bool g_bool = TRUE;
|
||||
const int8 g_int1 = 7;
|
||||
const int8 g_int2 = 07;
|
||||
const int8 g_int3 = 0x7;
|
||||
const float64 g_float1 = 1.1;
|
||||
const float64 g_float2 = 1e10;
|
||||
const char g_char = 'c';
|
||||
const string g_string1 = "";
|
||||
const string<128> g_string2 = "str" "ing";
|
||||
|
||||
module Foo {
|
||||
const int64 g_expr1 = ~1;
|
||||
const int64 g_expr2 = 2 * 4;
|
||||
};
|
||||
|
||||
#endif
|
||||
"""
|
||||
|
||||
IDL = """
|
||||
// comment in file
|
||||
module test_msgs {
|
||||
// comment in module
|
||||
typedef std_msgs::msg::Bool Bool;
|
||||
|
||||
/**/ /***/ /* block comment */
|
||||
|
||||
/*
|
||||
* block comment
|
||||
*/
|
||||
|
||||
module msg {
|
||||
// comment in submodule
|
||||
typedef Bool Balias;
|
||||
typedef test_msgs::msg::Bar Bar;
|
||||
typedef double d4[4];
|
||||
|
||||
module Foo_Constants {
|
||||
const int32 FOO = 32;
|
||||
const int64 BAR = 64;
|
||||
};
|
||||
|
||||
@comment(type="text", text="ignore")
|
||||
struct Foo {
|
||||
// comment in struct
|
||||
std_msgs::msg::Header header;
|
||||
Balias bool;
|
||||
Bar sibling;
|
||||
double/* comment in member declaration */x;
|
||||
sequence<double> seq1;
|
||||
sequence<double, 4> seq2;
|
||||
d4 array;
|
||||
};
|
||||
};
|
||||
|
||||
struct Bar {
|
||||
int i;
|
||||
};
|
||||
};
|
||||
"""
|
||||
|
||||
IDL_STRINGARRAY = """
|
||||
module test_msgs {
|
||||
module msg {
|
||||
typedef string string__3[3];
|
||||
struct Strings {
|
||||
string__3 values;
|
||||
};
|
||||
};
|
||||
};
|
||||
"""
|
||||
|
||||
|
||||
def test_parse_empty_msg() -> None:
|
||||
"""Test msg parser with empty message."""
|
||||
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
|
||||
assert ret == {'std_msgs/msg/Empty': ([], [])}
|
||||
|
||||
|
||||
def test_parse_bounds_msg() -> None:
|
||||
"""Test msg parser."""
|
||||
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
|
||||
assert ret == {
|
||||
'test_msgs/msg/Foo': (
|
||||
[],
|
||||
[
|
||||
('unbounded_integer_array', (4, ((1, 'int32'), None))),
|
||||
('five_integers_array', (3, ((1, 'int32'), 5))),
|
||||
('up_to_five_integers_array', (4, ((1, 'int32'), None))),
|
||||
('string_of_unbounded_size', (1, 'string')),
|
||||
('up_to_ten_characters_string', (1, 'string')),
|
||||
('up_to_five_unbounded_strings', (4, ((1, 'string'), None))),
|
||||
('unbounded_array_of_string_up_to_ten_characters_each', (4, ((1, 'string'), None))),
|
||||
('up_to_five_strings_up_to_ten_characters_each', (4, ((1, 'string'), None))),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_parse_defaults_msg() -> None:
|
||||
"""Test msg parser."""
|
||||
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
|
||||
assert ret == {
|
||||
'test_msgs/msg/Foo': (
|
||||
[],
|
||||
[
|
||||
('b', (1, 'bool')),
|
||||
('i', (1, 'uint8')),
|
||||
('o', (1, 'uint8')),
|
||||
('h', (1, 'uint8')),
|
||||
('y', (1, 'float32')),
|
||||
('name1', (1, 'string')),
|
||||
('name2', (1, 'string')),
|
||||
('samples', (4, ((1, 'int32'), None))),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_parse_msg() -> None:
|
||||
"""Test msg parser."""
|
||||
with pytest.raises(TypesysError, match='Could not parse'):
|
||||
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
|
||||
ret = get_types_from_msg(MSG, 'test_msgs/msg/Foo')
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
consts, fields = ret['test_msgs/msg/Foo']
|
||||
assert consts == [
|
||||
('b', 'bool', True),
|
||||
('global', 'int32', 42),
|
||||
('f', 'float32', 1.33),
|
||||
('str', 'string', 'foo bar'),
|
||||
]
|
||||
assert fields[0][0] == 'header'
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][0] == 'bool'
|
||||
assert fields[1][1][1] == 'std_msgs/msg/Bool'
|
||||
assert fields[2][0] == 'sibling'
|
||||
assert fields[2][1][1] == 'test_msgs/msg/Bar'
|
||||
assert fields[3][1][0] == Nodetype.BASE
|
||||
assert fields[4][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[5][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[6][1][0] == Nodetype.ARRAY
|
||||
|
||||
|
||||
def test_parse_multi_msg() -> None:
|
||||
"""Test multi msg parser."""
|
||||
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
|
||||
assert len(ret) == 3
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
assert 'std_msgs/msg/Header' in ret
|
||||
assert 'test_msgs/msg/Other' in ret
|
||||
fields = ret['test_msgs/msg/Foo'][1]
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][1][1] == 'uint8'
|
||||
assert fields[2][1][1] == 'uint8'
|
||||
consts = ret['test_msgs/msg/Other'][0]
|
||||
assert consts == [('static', 'uint32', 42)]
|
||||
|
||||
|
||||
def test_parse_cstring_confusion() -> None:
|
||||
"""Test if msg separator is confused with const string."""
|
||||
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
|
||||
assert len(ret) == 2
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
assert 'std_msgs/msg/Header' in ret
|
||||
consts, fields = ret['test_msgs/msg/Foo']
|
||||
assert consts == []
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][1][1] == 'string'
|
||||
|
||||
|
||||
def test_parse_relative_siblings_msg() -> None:
|
||||
"""Test relative siblings with msg parser."""
|
||||
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
|
||||
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
||||
assert ret['test_msgs/msg/Foo'][1][1][1][1] == 'test_msgs/msg/Other'
|
||||
|
||||
ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo')
|
||||
assert ret['rel_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
||||
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
|
||||
|
||||
|
||||
def test_parse_idl() -> None:
|
||||
"""Test idl parser."""
|
||||
ret = get_types_from_idl(IDL_LANG)
|
||||
assert ret == {}
|
||||
|
||||
ret = get_types_from_idl(IDL)
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
consts, fields = ret['test_msgs/msg/Foo']
|
||||
assert consts == [('FOO', 'int32', 32), ('BAR', 'int64', 64)]
|
||||
assert fields[0][0] == 'header'
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][0] == 'bool'
|
||||
assert fields[1][1][1] == 'std_msgs/msg/Bool'
|
||||
assert fields[2][0] == 'sibling'
|
||||
assert fields[2][1][1] == 'test_msgs/msg/Bar'
|
||||
assert fields[3][1][0] == Nodetype.BASE
|
||||
assert fields[4][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[5][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[6][1][0] == Nodetype.ARRAY
|
||||
|
||||
assert 'test_msgs/Bar' in ret
|
||||
consts, fields = ret['test_msgs/Bar']
|
||||
assert consts == []
|
||||
assert len(fields) == 1
|
||||
assert fields[0][0] == 'i'
|
||||
assert fields[0][1][1] == 'int'
|
||||
|
||||
ret = get_types_from_idl(IDL_STRINGARRAY)
|
||||
consts, fields = ret['test_msgs/msg/Strings']
|
||||
assert consts == []
|
||||
assert len(fields) == 1
|
||||
assert fields[0][0] == 'values'
|
||||
assert fields[0][1] == (Nodetype.ARRAY, ((Nodetype.BASE, 'string'), 3))
|
||||
|
||||
|
||||
def test_register_types() -> None:
|
||||
"""Test type registeration."""
|
||||
assert 'foo' not in FIELDDEFS
|
||||
register_types({})
|
||||
register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
|
||||
assert 'foo' in FIELDDEFS
|
||||
|
||||
register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
|
||||
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
|
||||
|
||||
with pytest.raises(TypesysError, match='different definition'):
|
||||
register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
|
||||
|
||||
|
||||
def test_generate_msgdef() -> None:
|
||||
"""Test message definition generator."""
|
||||
res = generate_msgdef('std_msgs/msg/Header')
|
||||
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')
|
||||
|
||||
res = generate_msgdef('geometry_msgs/msg/PointStamped')
|
||||
assert res[0].split(f'{"=" * 80}\n') == [
|
||||
'std_msgs/Header header\ngeometry_msgs/Point point\n',
|
||||
'MSG: std_msgs/Header\nuint32 seq\ntime stamp\nstring frame_id\n',
|
||||
'MSG: geometry_msgs/Point\nfloat64 x\nfloat64 y\nfloat64 z\n',
|
||||
]
|
||||
|
||||
res = generate_msgdef('geometry_msgs/msg/Twist')
|
||||
assert res[0].split(f'{"=" * 80}\n') == [
|
||||
'geometry_msgs/Vector3 linear\ngeometry_msgs/Vector3 angular\n',
|
||||
'MSG: geometry_msgs/Vector3\nfloat64 x\nfloat64 y\nfloat64 z\n',
|
||||
]
|
||||
|
||||
res = generate_msgdef('shape_msgs/msg/Mesh')
|
||||
assert res[0].split(f'{"=" * 80}\n') == [
|
||||
'shape_msgs/MeshTriangle[] triangles\ngeometry_msgs/Point[] vertices\n',
|
||||
'MSG: shape_msgs/MeshTriangle\nuint32[3] vertex_indices\n',
|
||||
'MSG: geometry_msgs/Point\nfloat64 x\nfloat64 y\nfloat64 z\n',
|
||||
]
|
||||
|
||||
res = generate_msgdef('shape_msgs/msg/Plane')
|
||||
assert res[0] == 'float64[4] coef\n'
|
||||
|
||||
res = generate_msgdef('sensor_msgs/msg/MultiEchoLaserScan')
|
||||
assert len(res[0].split('=' * 80)) == 3
|
||||
|
||||
register_types(get_types_from_msg('time[3] times\nuint8 foo=42', 'foo_msgs/Timelist'))
|
||||
res = generate_msgdef('foo_msgs/msg/Timelist')
|
||||
assert res[0] == 'uint8 foo=42\ntime[3] times\n'
|
||||
|
||||
with pytest.raises(TypesysError, match='is unknown'):
|
||||
generate_msgdef('foo_msgs/msg/Badname')
|
||||
@@ -0,0 +1,751 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reader tests."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import zstandard
|
||||
|
||||
from rosbags.rosbag2 import Reader, ReaderError, Writer
|
||||
|
||||
from .test_serde import MSG_JOINT, MSG_MAGN, MSG_MAGN_BIG, MSG_POLY
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import BinaryIO, Iterable
|
||||
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
METADATA = """
|
||||
rosbag2_bagfile_information:
|
||||
version: 4
|
||||
storage_identifier: sqlite3
|
||||
relative_file_paths:
|
||||
- db.db3{extension}
|
||||
duration:
|
||||
nanoseconds: 42
|
||||
starting_time:
|
||||
nanoseconds_since_epoch: 666
|
||||
message_count: 4
|
||||
topics_with_message_count:
|
||||
- topic_metadata:
|
||||
name: /poly
|
||||
type: geometry_msgs/msg/Polygon
|
||||
serialization_format: cdr
|
||||
offered_qos_profiles: ""
|
||||
message_count: 1
|
||||
- topic_metadata:
|
||||
name: /magn
|
||||
type: sensor_msgs/msg/MagneticField
|
||||
serialization_format: cdr
|
||||
offered_qos_profiles: ""
|
||||
message_count: 2
|
||||
- topic_metadata:
|
||||
name: /joint
|
||||
type: trajectory_msgs/msg/JointTrajectory
|
||||
serialization_format: cdr
|
||||
offered_qos_profiles: ""
|
||||
message_count: 1
|
||||
compression_format: {compression_format}
|
||||
compression_mode: {compression_mode}
|
||||
"""
|
||||
|
||||
METADATA_EMPTY = """
|
||||
rosbag2_bagfile_information:
|
||||
version: 6
|
||||
storage_identifier: sqlite3
|
||||
relative_file_paths:
|
||||
- db.db3
|
||||
duration:
|
||||
nanoseconds: 0
|
||||
starting_time:
|
||||
nanoseconds_since_epoch: 0
|
||||
message_count: 0
|
||||
topics_with_message_count: []
|
||||
compression_format: ""
|
||||
compression_mode: ""
|
||||
files:
|
||||
- duration:
|
||||
nanoseconds: 0
|
||||
message_count: 0
|
||||
path: db.db3
|
||||
starting_time:
|
||||
nanoseconds_since_epoch: 0
|
||||
custom_data:
|
||||
key1: value1
|
||||
key2: value2
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture(params=['none', 'file', 'message'])
|
||||
def bag(request: SubRequest, tmp_path: Path) -> Path:
|
||||
"""Manually contruct bag."""
|
||||
(tmp_path / 'metadata.yaml').write_text(
|
||||
METADATA.format(
|
||||
extension='' if request.param != 'file' else '.zstd',
|
||||
compression_format='""' if request.param == 'none' else 'zstd',
|
||||
compression_mode='""' if request.param == 'none' else request.param.upper(),
|
||||
),
|
||||
)
|
||||
|
||||
comp = zstandard.ZstdCompressor()
|
||||
|
||||
dbpath = tmp_path / 'db.db3'
|
||||
dbh = sqlite3.connect(dbpath)
|
||||
dbh.executescript(Writer.SQLITE_SCHEMA)
|
||||
|
||||
cur = dbh.cursor()
|
||||
cur.execute(
|
||||
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
|
||||
(1, '/poly', 'geometry_msgs/msg/Polygon', 'cdr', ''),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
|
||||
(2, '/magn', 'sensor_msgs/msg/MagneticField', 'cdr', ''),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
|
||||
(3, '/joint', 'trajectory_msgs/msg/JointTrajectory', 'cdr', ''),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(1, 1, 666, MSG_POLY[0] if request.param != 'message' else comp.compress(MSG_POLY[0])),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(2, 2, 708, MSG_MAGN[0] if request.param != 'message' else comp.compress(MSG_MAGN[0])),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(
|
||||
3,
|
||||
2,
|
||||
708,
|
||||
MSG_MAGN_BIG[0] if request.param != 'message' else comp.compress(MSG_MAGN_BIG[0]),
|
||||
),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(4, 3, 708, MSG_JOINT[0] if request.param != 'message' else comp.compress(MSG_JOINT[0])),
|
||||
)
|
||||
dbh.commit()
|
||||
|
||||
if request.param == 'file':
|
||||
with dbpath.open('rb') as ifh, (tmp_path / 'db.db3.zstd').open('wb') as ofh:
|
||||
comp.copy_stream(ifh, ofh)
|
||||
dbpath.unlink()
|
||||
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_empty_bag(tmp_path: Path) -> None:
|
||||
"""Test bags with broken fs layout."""
|
||||
(tmp_path / 'metadata.yaml').write_text(METADATA_EMPTY)
|
||||
dbpath = tmp_path / 'db.db3'
|
||||
dbh = sqlite3.connect(dbpath)
|
||||
dbh.executescript(Writer.SQLITE_SCHEMA)
|
||||
|
||||
with Reader(tmp_path) as reader:
|
||||
assert reader.message_count == 0
|
||||
assert reader.start_time == 2**63 - 1
|
||||
assert reader.end_time == 0
|
||||
assert reader.duration == 0
|
||||
assert not list(reader.messages())
|
||||
assert reader.custom_data['key1'] == 'value1'
|
||||
assert reader.custom_data['key2'] == 'value2'
|
||||
|
||||
|
||||
def test_reader(bag: Path) -> None:
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
with Reader(bag) as reader:
|
||||
assert reader.duration == 43
|
||||
assert reader.start_time == 666
|
||||
assert reader.end_time == 709
|
||||
assert reader.message_count == 4
|
||||
if reader.compression_mode:
|
||||
assert reader.compression_format == 'zstd'
|
||||
assert [x.id for x in reader.connections] == [1, 2, 3]
|
||||
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
|
||||
gen = reader.messages()
|
||||
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/poly'
|
||||
assert connection.msgtype == 'geometry_msgs/msg/Polygon'
|
||||
assert timestamp == 666
|
||||
assert rawdata == MSG_POLY[0]
|
||||
|
||||
for idx in range(2):
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
|
||||
assert timestamp == 708
|
||||
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
|
||||
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/joint'
|
||||
assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_message_filters(bag: Path) -> None:
|
||||
"""Test reader filters messages."""
|
||||
with Reader(bag) as reader:
|
||||
magn_connections = [x for x in reader.connections if x.topic == '/magn']
|
||||
gen = reader.messages(connections=magn_connections)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=667)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/joint'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(stop=667)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/poly'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(connections=magn_connections, stop=667)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=666, stop=666)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_user_errors(bag: Path) -> None:
|
||||
"""Test user errors."""
|
||||
reader = Reader(bag)
|
||||
with pytest.raises(ReaderError, match='Rosbag is not open'):
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path) -> None:
|
||||
"""Test bags with broken fs layout."""
|
||||
with pytest.raises(ReaderError, match='not read metadata'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata = tmp_path / 'metadata.yaml'
|
||||
|
||||
metadata.write_text('')
|
||||
with pytest.raises(ReaderError, match='not read'), \
|
||||
mock.patch.object(Path, 'read_text', side_effect=PermissionError):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(' invalid:\nthis is not yaml')
|
||||
with pytest.raises(ReaderError, match='not load YAML from'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text('foo:')
|
||||
with pytest.raises(ReaderError, match='key is missing'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('version: 4', 'version: 999'),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='version 999'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('sqlite3', 'hdf5'),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='Storage plugin'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='files are missing'):
|
||||
Reader(tmp_path)
|
||||
|
||||
(tmp_path / 'db.db3').write_text('')
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('cdr', 'bson'),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='Serialization format'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='"gz"',
|
||||
compression_mode='"file"',
|
||||
),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='Compression format'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='not open database'), \
|
||||
Reader(tmp_path) as reader:
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def write_record(bio: BinaryIO, opcode: int, records: Iterable[bytes]) -> None:
|
||||
"""Write record."""
|
||||
data = b''.join(records)
|
||||
bio.write(bytes([opcode]) + struct.pack('<Q', len(data)) + data)
|
||||
|
||||
|
||||
def make_string(text: str) -> bytes:
|
||||
"""Serialize string."""
|
||||
data = text.encode()
|
||||
return struct.pack('<I', len(data)) + data
|
||||
|
||||
|
||||
MCAP_HEADER = b'\x89MCAP0\r\n'
|
||||
|
||||
SCHEMAS = [
|
||||
(
|
||||
0x03,
|
||||
(
|
||||
struct.pack('<H', 1),
|
||||
make_string('geometry_msgs/msg/Polygon'),
|
||||
make_string('ros2msg'),
|
||||
make_string('string foo'),
|
||||
),
|
||||
),
|
||||
(
|
||||
0x03,
|
||||
(
|
||||
struct.pack('<H', 2),
|
||||
make_string('sensor_msgs/msg/MagneticField'),
|
||||
make_string('ros2msg'),
|
||||
make_string('string foo'),
|
||||
),
|
||||
),
|
||||
(
|
||||
0x03,
|
||||
(
|
||||
struct.pack('<H', 3),
|
||||
make_string('trajectory_msgs/msg/JointTrajectory'),
|
||||
make_string('ros2msg'),
|
||||
make_string('string foo'),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
CHANNELS = [
|
||||
(
|
||||
0x04,
|
||||
(
|
||||
struct.pack('<H', 1),
|
||||
struct.pack('<H', 1),
|
||||
make_string('/poly'),
|
||||
make_string('cdr'),
|
||||
make_string(''),
|
||||
),
|
||||
),
|
||||
(
|
||||
0x04,
|
||||
(
|
||||
struct.pack('<H', 2),
|
||||
struct.pack('<H', 2),
|
||||
make_string('/magn'),
|
||||
make_string('cdr'),
|
||||
make_string(''),
|
||||
),
|
||||
),
|
||||
(
|
||||
0x04,
|
||||
(
|
||||
struct.pack('<H', 3),
|
||||
struct.pack('<H', 3),
|
||||
make_string('/joint'),
|
||||
make_string('cdr'),
|
||||
make_string(''),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=['unindexed', 'partially_indexed', 'indexed', 'chunked_unindexed', 'chunked_indexed'],
|
||||
)
|
||||
def bag_mcap(request: SubRequest, tmp_path: Path) -> Path:
|
||||
"""Manually contruct mcap bag."""
|
||||
# pylint: disable=too-many-locals
|
||||
# pylint: disable=too-many-statements
|
||||
(tmp_path / 'metadata.yaml').write_text(
|
||||
METADATA.format(
|
||||
extension='.mcap',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('sqlite3', 'mcap'),
|
||||
)
|
||||
|
||||
path = tmp_path / 'db.db3.mcap'
|
||||
bio: BinaryIO
|
||||
messages: list[tuple[int, int, int]] = []
|
||||
chunks = []
|
||||
with path.open('wb') as bio:
|
||||
realbio = bio
|
||||
bio.write(MCAP_HEADER)
|
||||
write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap')))
|
||||
|
||||
if request.param.startswith('chunked'):
|
||||
bio = BytesIO()
|
||||
messages = []
|
||||
|
||||
write_record(bio, *SCHEMAS[0])
|
||||
write_record(bio, *CHANNELS[0])
|
||||
messages.append((1, 666, bio.tell()))
|
||||
write_record(
|
||||
bio,
|
||||
0x05,
|
||||
(
|
||||
struct.pack('<H', 1),
|
||||
struct.pack('<I', 1),
|
||||
struct.pack('<Q', 666),
|
||||
struct.pack('<Q', 666),
|
||||
MSG_POLY[0],
|
||||
),
|
||||
)
|
||||
|
||||
if request.param.startswith('chunked'):
|
||||
assert isinstance(bio, BytesIO)
|
||||
chunk_start = realbio.tell()
|
||||
compression = make_string('')
|
||||
uncompressed_size = struct.pack('<Q', len(bio.getbuffer()))
|
||||
compressed_size = struct.pack('<Q', len(bio.getbuffer()))
|
||||
write_record(
|
||||
realbio,
|
||||
0x06,
|
||||
(
|
||||
struct.pack('<Q', 666),
|
||||
struct.pack('<Q', 666),
|
||||
uncompressed_size,
|
||||
struct.pack('<I', 0),
|
||||
compression,
|
||||
compressed_size,
|
||||
bio.getbuffer(),
|
||||
),
|
||||
)
|
||||
message_index_offsets = []
|
||||
message_index_start = realbio.tell()
|
||||
for channel_id, group in groupby(messages, key=lambda x: x[0]):
|
||||
message_index_offsets.append((channel_id, realbio.tell()))
|
||||
tpls = [y for x in group for y in x[1:]]
|
||||
write_record(
|
||||
realbio,
|
||||
0x07,
|
||||
(
|
||||
struct.pack('<H', channel_id),
|
||||
struct.pack('<I', 8 * len(tpls)),
|
||||
struct.pack('<' + 'Q' * len(tpls), *tpls),
|
||||
),
|
||||
)
|
||||
chunk = [
|
||||
struct.pack('<Q', 666),
|
||||
struct.pack('<Q', 666),
|
||||
struct.pack('<Q', chunk_start),
|
||||
struct.pack('<Q', message_index_start - chunk_start),
|
||||
struct.pack('<I', 10 * len(message_index_offsets)),
|
||||
*(struct.pack('<HQ', *x) for x in message_index_offsets),
|
||||
struct.pack('<Q',
|
||||
realbio.tell() - message_index_start),
|
||||
compression,
|
||||
compressed_size,
|
||||
uncompressed_size,
|
||||
]
|
||||
chunks.append(chunk)
|
||||
bio = BytesIO()
|
||||
messages = []
|
||||
|
||||
write_record(bio, *SCHEMAS[1])
|
||||
write_record(bio, *CHANNELS[1])
|
||||
messages.append((2, 708, bio.tell()))
|
||||
write_record(
|
||||
bio,
|
||||
0x05,
|
||||
(
|
||||
struct.pack('<H', 2),
|
||||
struct.pack('<I', 1),
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<Q', 708),
|
||||
MSG_MAGN[0],
|
||||
),
|
||||
)
|
||||
messages.append((2, 708, bio.tell()))
|
||||
write_record(
|
||||
bio,
|
||||
0x05,
|
||||
(
|
||||
struct.pack('<H', 2),
|
||||
struct.pack('<I', 2),
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<Q', 708),
|
||||
MSG_MAGN_BIG[0],
|
||||
),
|
||||
)
|
||||
|
||||
write_record(bio, *SCHEMAS[2])
|
||||
write_record(bio, *CHANNELS[2])
|
||||
messages.append((3, 708, bio.tell()))
|
||||
write_record(
|
||||
bio,
|
||||
0x05,
|
||||
(
|
||||
struct.pack('<H', 3),
|
||||
struct.pack('<I', 1),
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<Q', 708),
|
||||
MSG_JOINT[0],
|
||||
),
|
||||
)
|
||||
|
||||
if request.param.startswith('chunked'):
|
||||
assert isinstance(bio, BytesIO)
|
||||
chunk_start = realbio.tell()
|
||||
compression = make_string('')
|
||||
uncompressed_size = struct.pack('<Q', len(bio.getbuffer()))
|
||||
compressed_size = struct.pack('<Q', len(bio.getbuffer()))
|
||||
write_record(
|
||||
realbio,
|
||||
0x06,
|
||||
(
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<Q', 708),
|
||||
uncompressed_size,
|
||||
struct.pack('<I', 0),
|
||||
compression,
|
||||
compressed_size,
|
||||
bio.getbuffer(),
|
||||
),
|
||||
)
|
||||
message_index_offsets = []
|
||||
message_index_start = realbio.tell()
|
||||
for channel_id, group in groupby(messages, key=lambda x: x[0]):
|
||||
message_index_offsets.append((channel_id, realbio.tell()))
|
||||
tpls = [y for x in group for y in x[1:]]
|
||||
write_record(
|
||||
realbio,
|
||||
0x07,
|
||||
(
|
||||
struct.pack('<H', channel_id),
|
||||
struct.pack('<I', 8 * len(tpls)),
|
||||
struct.pack('<' + 'Q' * len(tpls), *tpls),
|
||||
),
|
||||
)
|
||||
chunk = [
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<Q', chunk_start),
|
||||
struct.pack('<Q', message_index_start - chunk_start),
|
||||
struct.pack('<I', 10 * len(message_index_offsets)),
|
||||
*(struct.pack('<HQ', *x) for x in message_index_offsets),
|
||||
struct.pack('<Q',
|
||||
realbio.tell() - message_index_start),
|
||||
compression,
|
||||
compressed_size,
|
||||
uncompressed_size,
|
||||
]
|
||||
chunks.append(chunk)
|
||||
bio = realbio
|
||||
messages = []
|
||||
|
||||
if request.param in ['indexed', 'partially_indexed', 'chunked_indexed']:
|
||||
summary_start = bio.tell()
|
||||
for schema in SCHEMAS:
|
||||
write_record(bio, *schema)
|
||||
if request.param != 'partially_indexed':
|
||||
for channel in CHANNELS:
|
||||
write_record(bio, *channel)
|
||||
if request.param == 'chunked_indexed':
|
||||
for chunk in chunks:
|
||||
write_record(bio, 0x08, chunk)
|
||||
|
||||
summary_offset_start = 0
|
||||
write_record(bio, 0x0a, (b'ignored',))
|
||||
write_record(
|
||||
bio,
|
||||
0x0b,
|
||||
(
|
||||
struct.pack('<Q', 4),
|
||||
struct.pack('<H', 3),
|
||||
struct.pack('<I', 3),
|
||||
struct.pack('<I', 0),
|
||||
struct.pack('<I', 0),
|
||||
struct.pack('<I', 0 if request.param == 'indexed' else 1),
|
||||
struct.pack('<Q', 666),
|
||||
struct.pack('<Q', 708),
|
||||
struct.pack('<I', 0),
|
||||
),
|
||||
)
|
||||
write_record(bio, 0x0d, (b'ignored',))
|
||||
write_record(bio, 0xff, (b'ignored',))
|
||||
else:
|
||||
summary_start = 0
|
||||
summary_offset_start = 0
|
||||
|
||||
write_record(
|
||||
bio,
|
||||
0x02,
|
||||
(
|
||||
struct.pack('<Q', summary_start),
|
||||
struct.pack('<Q', summary_offset_start),
|
||||
struct.pack('<I', 0),
|
||||
),
|
||||
)
|
||||
bio.write(MCAP_HEADER)
|
||||
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_reader_mcap(bag_mcap: Path) -> None:
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
with Reader(bag_mcap) as reader:
|
||||
assert reader.duration == 43
|
||||
assert reader.start_time == 666
|
||||
assert reader.end_time == 709
|
||||
assert reader.message_count == 4
|
||||
if reader.compression_mode:
|
||||
assert reader.compression_format == 'zstd'
|
||||
assert [x.id for x in reader.connections] == [1, 2, 3]
|
||||
assert [*reader.topics.keys()] == ['/poly', '/magn', '/joint']
|
||||
gen = reader.messages()
|
||||
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/poly'
|
||||
assert connection.msgtype == 'geometry_msgs/msg/Polygon'
|
||||
assert timestamp == 666
|
||||
assert rawdata == MSG_POLY[0]
|
||||
|
||||
for idx in range(2):
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
assert connection.msgtype == 'sensor_msgs/msg/MagneticField'
|
||||
assert timestamp == 708
|
||||
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
|
||||
|
||||
connection, timestamp, rawdata = next(gen)
|
||||
assert connection.topic == '/joint'
|
||||
assert connection.msgtype == 'trajectory_msgs/msg/JointTrajectory'
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_message_filters_mcap(bag_mcap: Path) -> None:
|
||||
"""Test reader filters messages."""
|
||||
with Reader(bag_mcap) as reader:
|
||||
magn_connections = [x for x in reader.connections if x.topic == '/magn']
|
||||
gen = reader.messages(connections=magn_connections)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=667)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/magn'
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/joint'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(stop=667)
|
||||
connection, _, _ = next(gen)
|
||||
assert connection.topic == '/poly'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(connections=magn_connections, stop=667)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=666, stop=666)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_bag_mcap_files(tmp_path: Path) -> None:
|
||||
"""Test bad mcap files."""
|
||||
(tmp_path / 'metadata.yaml').write_text(
|
||||
METADATA.format(
|
||||
extension='.mcap',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('sqlite3', 'mcap'),
|
||||
)
|
||||
|
||||
path = tmp_path / 'db.db3.mcap'
|
||||
path.touch()
|
||||
reader = Reader(tmp_path)
|
||||
path.unlink()
|
||||
with pytest.raises(ReaderError, match='Could not open'):
|
||||
reader.open()
|
||||
|
||||
path.touch()
|
||||
with pytest.raises(ReaderError, match='seems to be empty'):
|
||||
Reader(tmp_path).open()
|
||||
|
||||
path.write_bytes(b'xxxxxxxx')
|
||||
with pytest.raises(ReaderError, match='magic is invalid'):
|
||||
Reader(tmp_path).open()
|
||||
|
||||
path.write_bytes(b'\x89MCAP0\r\n\xFF')
|
||||
with pytest.raises(ReaderError, match='Unexpected record'):
|
||||
Reader(tmp_path).open()
|
||||
|
||||
with path.open('wb') as bio:
|
||||
bio.write(b'\x89MCAP0\r\n')
|
||||
write_record(bio, 0x01, (make_string('ros1'), make_string('test_mcap')))
|
||||
with pytest.raises(ReaderError, match='Profile is not'):
|
||||
Reader(tmp_path).open()
|
||||
|
||||
with path.open('wb') as bio:
|
||||
bio.write(b'\x89MCAP0\r\n')
|
||||
write_record(bio, 0x01, (make_string('ros2'), make_string('test_mcap')))
|
||||
with pytest.raises(ReaderError, match='File end magic is invalid'):
|
||||
Reader(tmp_path).open()
|
||||
@@ -0,0 +1,422 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reader tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from struct import pack
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag1 import Reader, ReaderError
|
||||
from rosbags.rosbag1.reader import IndexData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
|
||||
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
|
||||
"""Serialize record header."""
|
||||
if isinstance(data, dict):
|
||||
fields = []
|
||||
for key, value in data.items():
|
||||
field = b'='.join([key.encode(), value])
|
||||
fields.append(pack('<L', len(field)) + field)
|
||||
data = b''.join(fields)
|
||||
return pack('<L', len(data)) + data
|
||||
|
||||
|
||||
def create_default_header() -> dict[str, bytes]:
|
||||
"""Create empty rosbag header."""
|
||||
return {
|
||||
'op': b'\x03',
|
||||
'conn_count': pack('<L', 0),
|
||||
'chunk_count': pack('<L', 0),
|
||||
}
|
||||
|
||||
|
||||
def create_connection(
|
||||
cid: int = 1,
|
||||
topic: int = 0,
|
||||
typ: int = 0,
|
||||
) -> tuple[dict[str, bytes], dict[str, bytes]]:
|
||||
"""Create connection record."""
|
||||
return {
|
||||
'op': b'\x07',
|
||||
'conn': pack('<L', cid),
|
||||
'topic': f'/topic{topic}'.encode(),
|
||||
}, {
|
||||
'type': f'foo_msgs/msg/Foo{typ}'.encode(),
|
||||
'md5sum': b'AAAA',
|
||||
'message_definition': b'MSGDEF',
|
||||
}
|
||||
|
||||
|
||||
def create_message(
|
||||
cid: int = 1,
|
||||
time: int = 0,
|
||||
msg: int = 0,
|
||||
) -> tuple[dict[str, Union[bytes, int]], bytes]:
|
||||
"""Create message record."""
|
||||
return {
|
||||
'op': b'\x02',
|
||||
'conn': cid,
|
||||
'time': time,
|
||||
}, f'MSGCONTENT{msg}'.encode()
|
||||
|
||||
|
||||
def write_bag( # pylint: disable=too-many-locals
|
||||
bag: Path,
|
||||
header: dict[str, bytes],
|
||||
chunks: Sequence[Any] = (),
|
||||
) -> None:
|
||||
"""Write bag file."""
|
||||
magic = b'#ROSBAG V2.0\n'
|
||||
|
||||
pos = 13 + 4096
|
||||
conn_count = 0
|
||||
chunk_count = len(chunks or [])
|
||||
|
||||
chunks_bytes = b''
|
||||
connections = b''
|
||||
chunkinfos = b''
|
||||
if chunks:
|
||||
for chunk in chunks:
|
||||
chunk_bytes = b''
|
||||
start_time = 2**32 - 1
|
||||
end_time = 0
|
||||
counts: dict[int, int] = defaultdict(int)
|
||||
index = {}
|
||||
offset = 0
|
||||
|
||||
for head, data in chunk:
|
||||
if head.get('op') == b'\x07':
|
||||
conn_count += 1
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
connections += add
|
||||
elif head.get('op') == b'\x02':
|
||||
time = head['time']
|
||||
head['time'] = pack('<LL', head['time'], 0)
|
||||
conn = head['conn']
|
||||
head['conn'] = pack('<L', head['conn'])
|
||||
|
||||
start_time = min([start_time, time])
|
||||
end_time = max([end_time, time])
|
||||
|
||||
counts[conn] += 1
|
||||
if conn not in index:
|
||||
index[conn] = {
|
||||
'count': 0,
|
||||
'msgs': b'',
|
||||
}
|
||||
index[conn]['count'] += 1 # type: ignore
|
||||
index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
|
||||
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
offset = len(chunk_bytes)
|
||||
else:
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
|
||||
chunk_bytes = ser(
|
||||
{
|
||||
'op': b'\x05',
|
||||
'compression': b'none',
|
||||
'size': pack('<L', len(chunk_bytes)),
|
||||
},
|
||||
) + ser(chunk_bytes)
|
||||
for conn, data in index.items():
|
||||
chunk_bytes += ser(
|
||||
{
|
||||
'op': b'\x04',
|
||||
'ver': pack('<L', 1),
|
||||
'conn': pack('<L', conn),
|
||||
'count': pack('<L', data['count']),
|
||||
},
|
||||
) + ser(data['msgs'])
|
||||
|
||||
chunks_bytes += chunk_bytes
|
||||
chunkinfos += ser(
|
||||
{
|
||||
'op': b'\x06',
|
||||
'ver': pack('<L', 1),
|
||||
'chunk_pos': pack('<Q', pos),
|
||||
'start_time': pack('<LL', start_time, 0),
|
||||
'end_time': pack('<LL', end_time, 0),
|
||||
'count': pack('<L', len(counts.keys())),
|
||||
},
|
||||
) + ser(b''.join([pack('<LL', x, y) for x, y in counts.items()]))
|
||||
pos += len(chunk_bytes)
|
||||
|
||||
header['conn_count'] = pack('<L', conn_count)
|
||||
header['chunk_count'] = pack('<L', chunk_count)
|
||||
if 'index_pos' not in header:
|
||||
header['index_pos'] = pack('<Q', pos)
|
||||
|
||||
header_bytes = ser(header)
|
||||
header_bytes += b'\x20' * (4096 - len(header_bytes))
|
||||
|
||||
bag.write_bytes(b''.join([
|
||||
magic,
|
||||
header_bytes,
|
||||
chunks_bytes,
|
||||
connections,
|
||||
chunkinfos,
|
||||
]))
|
||||
|
||||
|
||||
def test_indexdata() -> None:
|
||||
"""Test IndexData sort sorder."""
|
||||
x42_1_0 = IndexData(42, 1, 0)
|
||||
x42_2_0 = IndexData(42, 2, 0)
|
||||
x43_3_0 = IndexData(43, 3, 0)
|
||||
|
||||
assert not x42_1_0 < x42_2_0
|
||||
assert x42_1_0 <= x42_2_0
|
||||
assert x42_1_0 == x42_2_0
|
||||
assert not x42_1_0 != x42_2_0 # noqa
|
||||
assert x42_1_0 >= x42_2_0
|
||||
assert not x42_1_0 > x42_2_0
|
||||
|
||||
assert x42_1_0 < x43_3_0
|
||||
assert x42_1_0 <= x43_3_0
|
||||
assert not x42_1_0 == x43_3_0 # noqa
|
||||
assert x42_1_0 != x43_3_0
|
||||
assert not x42_1_0 >= x43_3_0
|
||||
assert not x42_1_0 > x43_3_0
|
||||
|
||||
|
||||
def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
# empty bag
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, create_default_header())
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 0
|
||||
assert reader.start_time == 2**63 - 1
|
||||
assert reader.end_time == 0
|
||||
assert reader.duration == 0
|
||||
assert not list(reader.messages())
|
||||
|
||||
# empty bag, explicit encryptor
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, {**create_default_header(), 'encryptor': b''})
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 0
|
||||
|
||||
# single message
|
||||
write_bag(
|
||||
bag,
|
||||
create_default_header(),
|
||||
chunks=[[
|
||||
create_connection(),
|
||||
create_message(time=42),
|
||||
]],
|
||||
)
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 1
|
||||
assert reader.duration == 1
|
||||
assert reader.start_time == 42 * 10**9
|
||||
assert reader.end_time == 42 * 10**9 + 1
|
||||
assert len(reader.topics.keys()) == 1
|
||||
assert reader.topics['/topic0'].msgcount == 1
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 1
|
||||
|
||||
# sorts by time on same topic
|
||||
write_bag(
|
||||
bag,
|
||||
create_default_header(),
|
||||
chunks=[
|
||||
[
|
||||
create_connection(),
|
||||
create_message(time=10, msg=10),
|
||||
create_message(time=5, msg=5),
|
||||
],
|
||||
],
|
||||
)
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 2
|
||||
assert reader.duration == 5 * 10**9 + 1
|
||||
assert reader.start_time == 5 * 10**9
|
||||
assert reader.end_time == 10 * 10**9 + 1
|
||||
assert len(reader.topics.keys()) == 1
|
||||
assert reader.topics['/topic0'].msgcount == 2
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0][0].topic == '/topic0'
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
assert msgs[1][0].topic == '/topic0'
|
||||
assert msgs[1][2] == b'MSGCONTENT10'
|
||||
|
||||
# sorts by time on different topic
|
||||
write_bag(
|
||||
bag,
|
||||
create_default_header(),
|
||||
chunks=[
|
||||
[
|
||||
create_connection(),
|
||||
create_message(time=10, msg=10),
|
||||
create_connection(cid=2, topic=2),
|
||||
create_message(cid=2, time=5, msg=5),
|
||||
],
|
||||
],
|
||||
)
|
||||
with Reader(bag) as reader:
|
||||
assert len(reader.topics.keys()) == 2
|
||||
assert reader.topics['/topic0'].msgcount == 1
|
||||
assert reader.topics['/topic2'].msgcount == 1
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
assert msgs[1][2] == b'MSGCONTENT10'
|
||||
|
||||
connections = [x for x in reader.connections if x.topic == '/topic0']
|
||||
msgs = list(reader.messages(connections))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][2] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(start=7 * 10**9))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][2] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(stop=7 * 10**9))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
|
||||
|
||||
def test_user_errors(tmp_path: Path) -> None:
|
||||
"""Test user errors."""
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
|
||||
reader = Reader(bag)
|
||||
with pytest.raises(ReaderError, match='is not open'):
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
|
||||
"""Test failure cases."""
|
||||
bag = tmp_path / 'test.bag'
|
||||
with pytest.raises(ReaderError, match='does not exist'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_text('')
|
||||
with patch('pathlib.Path.open', side_effect=IOError), \
|
||||
pytest.raises(ReaderError, match='not open'):
|
||||
Reader(bag).open()
|
||||
|
||||
with pytest.raises(ReaderError, match='empty'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_text('#BADMAGIC')
|
||||
with pytest.raises(ReaderError, match='magic is invalid'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_text('#ROSBAG V3.0\n')
|
||||
with pytest.raises(ReaderError, match='Bag version 300 is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x00')
|
||||
with pytest.raises(ReaderError, match='Header could not be read from file.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00')
|
||||
with pytest.raises(ReaderError, match='Header could not be read from file.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00\x01')
|
||||
with pytest.raises(ReaderError, match='Header field size could not be read.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x04\x00\x00\x00\x01\x00\x00\x00')
|
||||
with pytest.raises(ReaderError, match='Declared field size is too large for header.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x05\x00\x00\x00\x01\x00\x00\x00x')
|
||||
with pytest.raises(ReaderError, match='Header field could not be parsed.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, {'encryptor': b'enc', **create_default_header()})
|
||||
with pytest.raises(ReaderError, match='is not supported'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, {**create_default_header(), 'index_pos': pack('<Q', 0)})
|
||||
with pytest.raises(ReaderError, match='Bag is not indexed'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'none', b'COMP'))
|
||||
with pytest.raises(ReaderError, match='Compression \'COMP\' is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'ver=\x01', b'ver=\x02'))
|
||||
with pytest.raises(ReaderError, match='CHUNK_INFO version 2 is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'ver=\x01', b'ver=\x02', 1))
|
||||
with pytest.raises(ReaderError, match='IDXDATA version 2 is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'op=\x02', b'op=\x00', 1))
|
||||
with Reader(bag) as reader, \
|
||||
pytest.raises(ReaderError, match='Expected to find message data.'):
|
||||
next(reader.messages())
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'op=\x03', b'op=\x02', 1))
|
||||
with pytest.raises(ReaderError, match='Record of type \'MSGDATA\' is unexpected.'):
|
||||
Reader(bag).open()
|
||||
|
||||
# bad uint8 field
|
||||
write_bag(
|
||||
bag,
|
||||
create_default_header(),
|
||||
chunks=[[
|
||||
({}, {}),
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]],
|
||||
)
|
||||
with Reader(bag) as reader, \
|
||||
pytest.raises(ReaderError, match='field \'op\''):
|
||||
next(reader.messages())
|
||||
|
||||
# bad uint32, uint64, time field
|
||||
for name in ('conn_count', 'chunk_pos', 'time'):
|
||||
write_bag(bag, create_default_header(), chunks=[[create_connection(), create_message()]])
|
||||
bag.write_bytes(bag.read_bytes().replace(name.encode(), b'x' * len(name), 1))
|
||||
if name == 'time':
|
||||
with pytest.raises(ReaderError, match=f'field \'{name}\''), \
|
||||
Reader(bag) as reader:
|
||||
next(reader.messages())
|
||||
else:
|
||||
with pytest.raises(ReaderError, match=f'field \'{name}\''):
|
||||
Reader(bag).open()
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag2 import Reader, Writer
|
||||
from rosbags.serde import deserialize_cdr, serialize_cdr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
|
||||
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
data = 1.25
|
||||
|
||||
path = tmp_path / 'rosbag2'
|
||||
wbag = Writer(path)
|
||||
wbag.set_compression(mode, wbag.CompressionFormat.ZSTD)
|
||||
with wbag:
|
||||
msgtype = 'std_msgs/msg/Float64'
|
||||
wconnection = wbag.add_connection('/test', msgtype)
|
||||
wbag.write(wconnection, 42, serialize_cdr(Foo, msgtype))
|
||||
|
||||
rbag = Reader(path)
|
||||
with rbag:
|
||||
gen = rbag.messages()
|
||||
rconnection, _, raw = next(gen)
|
||||
assert rconnection.topic == wconnection.topic
|
||||
assert rconnection.msgtype == wconnection.msgtype
|
||||
assert rconnection.ext == wconnection.ext
|
||||
msg = deserialize_cdr(raw, rconnection.msgtype)
|
||||
assert getattr(msg, 'data', None) == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag1 import Reader, Writer
|
||||
from rosbags.serde import cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
data = 1.25
|
||||
|
||||
path = tmp_path / 'test.bag'
|
||||
wbag = Writer(path)
|
||||
if fmt:
|
||||
wbag.set_compression(fmt)
|
||||
with wbag:
|
||||
msgtype = 'std_msgs/msg/Float64'
|
||||
conn = wbag.add_connection('/test', msgtype)
|
||||
wbag.write(conn, 42, cdr_to_ros1(serialize_cdr(Foo, msgtype), msgtype))
|
||||
|
||||
rbag = Reader(path)
|
||||
with rbag:
|
||||
gen = rbag.messages()
|
||||
connection, _, raw = next(gen)
|
||||
msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype)
|
||||
assert getattr(msg, 'data', None) == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
@@ -0,0 +1,513 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Serializer and deserializer tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from rosbags.serde import (
|
||||
SerdeError,
|
||||
cdr_to_ros1,
|
||||
deserialize_cdr,
|
||||
deserialize_ros1,
|
||||
ros1_to_cdr,
|
||||
serialize_cdr,
|
||||
serialize_ros1,
|
||||
)
|
||||
from rosbags.serde.messages import get_msgdef
|
||||
from rosbags.typesys import get_types_from_msg, register_types, types
|
||||
from rosbags.typesys.types import builtin_interfaces__msg__Time as Time
|
||||
from rosbags.typesys.types import geometry_msgs__msg__Polygon as Polygon
|
||||
from rosbags.typesys.types import sensor_msgs__msg__MagneticField as MagneticField
|
||||
from rosbags.typesys.types import std_msgs__msg__Header as Header
|
||||
|
||||
from .cdr import deserialize, serialize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Generator, Union
|
||||
|
||||
MSG_POLY = (
|
||||
(
|
||||
b'\x00\x01\x00\x00' # header
|
||||
b'\x02\x00\x00\x00' # number of points = 2
|
||||
b'\x00\x00\x80\x3f' # x = 1
|
||||
b'\x00\x00\x00\x40' # y = 2
|
||||
b'\x00\x00\x40\x40' # z = 3
|
||||
b'\x00\x00\xa0\x3f' # x = 1.25
|
||||
b'\x00\x00\x10\x40' # y = 2.25
|
||||
b'\x00\x00\x50\x40' # z = 3.25
|
||||
),
|
||||
'geometry_msgs/msg/Polygon',
|
||||
True,
|
||||
)
|
||||
|
||||
MSG_MAGN = (
|
||||
(
|
||||
b'\x00\x01\x00\x00' # header
|
||||
b'\xc4\x02\x00\x00\x00\x01\x00\x00' # timestamp = 708s 256ns
|
||||
b'\x06\x00\x00\x00foo42\x00' # frameid 'foo42'
|
||||
b'\x00\x00\x00\x00\x00\x00' # padding
|
||||
b'\x00\x00\x00\x00\x00\x00\x60\x40' # x = 128
|
||||
b'\x00\x00\x00\x00\x00\x00\x60\x40' # y = 128
|
||||
b'\x00\x00\x00\x00\x00\x00\x60\x40' # z = 128
|
||||
b'\x00\x00\x00\x00\x00\x00\xF0\x3F' # covariance matrix = 3x3 diag
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\xF0\x3F'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\xF0\x3F'
|
||||
),
|
||||
'sensor_msgs/msg/MagneticField',
|
||||
True,
|
||||
)
|
||||
|
||||
MSG_MAGN_BIG = (
|
||||
(
|
||||
b'\x00\x00\x00\x00' # header
|
||||
b'\x00\x00\x02\xc4\x00\x00\x01\x00' # timestamp = 708s 256ns
|
||||
b'\x00\x00\x00\x06foo42\x00' # frameid 'foo42'
|
||||
b'\x00\x00\x00\x00\x00\x00' # padding
|
||||
b'\x40\x60\x00\x00\x00\x00\x00\x00' # x = 128
|
||||
b'\x40\x60\x00\x00\x00\x00\x00\x00' # y = 128
|
||||
b'\x40\x60\x00\x00\x00\x00\x00\x00' # z = 128
|
||||
b'\x3F\xF0\x00\x00\x00\x00\x00\x00' # covariance matrix = 3x3 diag
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x3F\xF0\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
b'\x3F\xF0\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00' # garbage
|
||||
),
|
||||
'sensor_msgs/msg/MagneticField',
|
||||
False,
|
||||
)
|
||||
|
||||
MSG_JOINT = (
|
||||
(
|
||||
b'\x00\x01\x00\x00' # header
|
||||
b'\xc4\x02\x00\x00\x00\x01\x00\x00' # timestamp = 708s 256ns
|
||||
b'\x04\x00\x00\x00bar\x00' # frameid 'bar'
|
||||
b'\x02\x00\x00\x00' # number of strings
|
||||
b'\x02\x00\x00\x00a\x00' # string 'a'
|
||||
b'\x00\x00' # padding
|
||||
b'\x02\x00\x00\x00b\x00' # string 'b'
|
||||
b'\x00\x00' # padding
|
||||
b'\x00\x00\x00\x00' # number of points
|
||||
b'\x00\x00\x00' # garbage
|
||||
),
|
||||
'trajectory_msgs/msg/JointTrajectory',
|
||||
True,
|
||||
)
|
||||
|
||||
MESSAGES = [MSG_POLY, MSG_MAGN, MSG_MAGN_BIG, MSG_JOINT]
|
||||
|
||||
STATIC_64_64 = """
|
||||
uint64[2] u64
|
||||
"""
|
||||
|
||||
STATIC_64_16 = """
|
||||
uint64 u64
|
||||
uint16 u16
|
||||
"""
|
||||
|
||||
STATIC_16_64 = """
|
||||
uint16 u16
|
||||
uint64 u64
|
||||
"""
|
||||
|
||||
DYNAMIC_64_64 = """
|
||||
uint64[] u64
|
||||
"""
|
||||
|
||||
DYNAMIC_64_B_64 = """
|
||||
uint64 u64
|
||||
bool b
|
||||
float64 f64
|
||||
"""
|
||||
|
||||
DYNAMIC_64_S = """
|
||||
uint64 u64
|
||||
string s
|
||||
"""
|
||||
|
||||
DYNAMIC_S_64 = """
|
||||
string s
|
||||
uint64 u64
|
||||
"""
|
||||
|
||||
CUSTOM = """
|
||||
string base_str
|
||||
float32 base_f32
|
||||
test_msgs/msg/static_64_64 msg_s66
|
||||
test_msgs/msg/static_64_16 msg_s61
|
||||
test_msgs/msg/static_16_64 msg_s16
|
||||
test_msgs/msg/dynamic_64_64 msg_d66
|
||||
test_msgs/msg/dynamic_64_b_64 msg_d6b6
|
||||
test_msgs/msg/dynamic_64_s msg_d6s
|
||||
test_msgs/msg/dynamic_s_64 msg_ds6
|
||||
|
||||
string[2] arr_base_str
|
||||
float32[2] arr_base_f32
|
||||
test_msgs/msg/static_64_64[2] arr_msg_s66
|
||||
test_msgs/msg/static_64_16[2] arr_msg_s61
|
||||
test_msgs/msg/static_16_64[2] arr_msg_s16
|
||||
test_msgs/msg/dynamic_64_64[2] arr_msg_d66
|
||||
test_msgs/msg/dynamic_64_b_64[2] arr_msg_d6b6
|
||||
test_msgs/msg/dynamic_64_s[2] arr_msg_d6s
|
||||
test_msgs/msg/dynamic_s_64[2] arr_msg_ds6
|
||||
|
||||
string[] seq_base_str
|
||||
float32[] seq_base_f32
|
||||
test_msgs/msg/static_64_64[] seq_msg_s66
|
||||
test_msgs/msg/static_64_16[] seq_msg_s61
|
||||
test_msgs/msg/static_16_64[] seq_msg_s16
|
||||
test_msgs/msg/dynamic_64_64[] seq_msg_d66
|
||||
test_msgs/msg/dynamic_64_b_64[] seq_msg_d6b6
|
||||
test_msgs/msg/dynamic_64_s[] seq_msg_d6s
|
||||
test_msgs/msg/dynamic_s_64[] seq_msg_ds6
|
||||
"""
|
||||
|
||||
SU64_B = """
|
||||
uint64[] su64
|
||||
bool b
|
||||
"""
|
||||
|
||||
SU64_U64 = """
|
||||
uint64[] su64
|
||||
uint64 u64
|
||||
"""
|
||||
|
||||
SMSG_U64 = """
|
||||
su64_u64[] seq
|
||||
uint64 u64
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _comparable() -> Generator[None, None, None]:
|
||||
"""Make messages containing numpy arrays comparable.
|
||||
|
||||
Notes:
|
||||
This solution is necessary as numpy.ndarray is not directly patchable.
|
||||
|
||||
"""
|
||||
frombuffer = numpy.frombuffer
|
||||
|
||||
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
|
||||
lhs = self._mock_wraps # pylint: disable=protected-access
|
||||
rhs = getattr(other, '_mock_wraps', other)
|
||||
return (lhs == rhs).all() # type: ignore
|
||||
|
||||
class CNDArray(MagicMock):
|
||||
"""Mock ndarray."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = kwargs['wraps'].dtype
|
||||
self.reshape = kwargs['wraps'].reshape
|
||||
self.__eq__ = arreq # type: ignore
|
||||
|
||||
def byteswap(self, *args: Any) -> CNDArray: # noqa: ANN401
|
||||
"""Wrap return value also in mock."""
|
||||
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
|
||||
|
||||
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: # noqa: ANN401
|
||||
return CNDArray(wraps=frombuffer(*args, **kwargs))
|
||||
|
||||
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize('message', MESSAGES)
|
||||
def test_serde(message: tuple[bytes, str, bool]) -> None:
|
||||
"""Test serialization deserialization roundtrip."""
|
||||
rawdata, typ, is_little = message
|
||||
|
||||
serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little)
|
||||
assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little)
|
||||
assert serdeser == rawdata[:len(serdeser)]
|
||||
assert len(rawdata) - len(serdeser) < 4
|
||||
assert all(x == 0 for x in rawdata[len(serdeser):])
|
||||
|
||||
if rawdata[1] == 1:
|
||||
rawdata = cdr_to_ros1(rawdata, typ)
|
||||
serdeser = serialize_ros1(deserialize_ros1(rawdata, typ), typ)
|
||||
assert serdeser == rawdata
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_deserializer() -> None:
|
||||
"""Test deserializer."""
|
||||
msg = deserialize_cdr(*MSG_POLY[:2])
|
||||
assert msg == deserialize(*MSG_POLY[:2])
|
||||
assert isinstance(msg, Polygon)
|
||||
assert len(msg.points) == 2
|
||||
assert msg.points[0].x == 1
|
||||
assert msg.points[0].y == 2
|
||||
assert msg.points[0].z == 3
|
||||
assert msg.points[1].x == 1.25
|
||||
assert msg.points[1].y == 2.25
|
||||
assert msg.points[1].z == 3.25
|
||||
msg_ros1 = deserialize_ros1(cdr_to_ros1(*MSG_POLY[:2]), MSG_POLY[1])
|
||||
assert msg_ros1 == msg
|
||||
|
||||
msg = deserialize_cdr(*MSG_MAGN[:2])
|
||||
assert msg == deserialize(*MSG_MAGN[:2])
|
||||
assert isinstance(msg, MagneticField)
|
||||
assert 'MagneticField' in repr(msg)
|
||||
assert msg.header.stamp.sec == 708
|
||||
assert msg.header.stamp.nanosec == 256
|
||||
assert msg.header.frame_id == 'foo42'
|
||||
field = msg.magnetic_field
|
||||
assert (field.x, field.y, field.z) == (128., 128., 128.)
|
||||
diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3))
|
||||
assert (diag == [1., 1., 1.]).all()
|
||||
msg_ros1 = deserialize_ros1(cdr_to_ros1(*MSG_MAGN[:2]), MSG_MAGN[1])
|
||||
assert msg_ros1 == msg
|
||||
|
||||
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
|
||||
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
|
||||
assert isinstance(msg_big, MagneticField)
|
||||
assert msg.magnetic_field == msg_big.magnetic_field
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_serializer() -> None:
|
||||
"""Test serializer."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
data = 7
|
||||
|
||||
msg = Foo()
|
||||
ret = serialize_cdr(msg, 'std_msgs/msg/Int8', True)
|
||||
assert ret == serialize(msg, 'std_msgs/msg/Int8', True)
|
||||
assert ret == b'\x00\x01\x00\x00\x07'
|
||||
|
||||
ret = serialize_cdr(msg, 'std_msgs/msg/Int8', False)
|
||||
assert ret == serialize(msg, 'std_msgs/msg/Int8', False)
|
||||
assert ret == b'\x00\x00\x00\x00\x07'
|
||||
|
||||
ret = serialize_cdr(msg, 'std_msgs/msg/Int16', True)
|
||||
assert ret == serialize(msg, 'std_msgs/msg/Int16', True)
|
||||
assert ret == b'\x00\x01\x00\x00\x07\x00'
|
||||
|
||||
ret = serialize_cdr(msg, 'std_msgs/msg/Int16', False)
|
||||
assert ret == serialize(msg, 'std_msgs/msg/Int16', False)
|
||||
assert ret == b'\x00\x00\x00\x00\x00\x07'
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_serializer_errors() -> None:
|
||||
"""Test seralizer with broken messages."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
coef: numpy.ndarray[Any, numpy.dtype[numpy.int_]] = numpy.array([1, 2, 3, 4])
|
||||
|
||||
msg = Foo()
|
||||
ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
|
||||
assert ret == serialize(msg, 'shape_msgs/msg/Plane', True)
|
||||
|
||||
msg.coef = numpy.array([1, 2, 3, 4, 4])
|
||||
with pytest.raises(SerdeError, match='array length'):
|
||||
serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_custom_type() -> None:
|
||||
"""Test custom type."""
|
||||
cname = 'test_msgs/msg/custom'
|
||||
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
|
||||
register_types(dict(get_types_from_msg(STATIC_64_16, 'test_msgs/msg/static_64_16')))
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_64_64, 'test_msgs/msg/dynamic_64_64')))
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_64_B_64, 'test_msgs/msg/dynamic_64_b_64')))
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_64_S, 'test_msgs/msg/dynamic_64_s')))
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
|
||||
register_types(dict(get_types_from_msg(CUSTOM, cname)))
|
||||
|
||||
static_64_64 = get_msgdef('test_msgs/msg/static_64_64', types).cls
|
||||
static_64_16 = get_msgdef('test_msgs/msg/static_64_16', types).cls
|
||||
static_16_64 = get_msgdef('test_msgs/msg/static_16_64', types).cls
|
||||
dynamic_64_64 = get_msgdef('test_msgs/msg/dynamic_64_64', types).cls
|
||||
dynamic_64_b_64 = get_msgdef('test_msgs/msg/dynamic_64_b_64', types).cls
|
||||
dynamic_64_s = get_msgdef('test_msgs/msg/dynamic_64_s', types).cls
|
||||
dynamic_s_64 = get_msgdef('test_msgs/msg/dynamic_s_64', types).cls
|
||||
custom = get_msgdef('test_msgs/msg/custom', types).cls
|
||||
|
||||
msg = custom(
|
||||
'str',
|
||||
1.5,
|
||||
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
|
||||
static_64_16(64, 16),
|
||||
static_16_64(16, 64),
|
||||
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
|
||||
dynamic_64_b_64(64, True, 1.25),
|
||||
dynamic_64_s(64, 's'),
|
||||
dynamic_s_64('s', 64),
|
||||
# arrays
|
||||
['str_1', ''],
|
||||
numpy.array([1.5, 0.75], dtype=numpy.float32),
|
||||
[
|
||||
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
|
||||
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
|
||||
],
|
||||
[static_64_16(64, 16), static_64_16(64, 16)],
|
||||
[static_16_64(16, 64), static_16_64(16, 64)],
|
||||
[
|
||||
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
|
||||
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
|
||||
],
|
||||
[
|
||||
dynamic_64_b_64(64, True, 1.25),
|
||||
dynamic_64_b_64(64, True, 1.25),
|
||||
],
|
||||
[dynamic_64_s(64, 's'), dynamic_64_s(64, 's')],
|
||||
[dynamic_s_64('s', 64), dynamic_s_64('s', 64)],
|
||||
# sequences
|
||||
['str_1', ''],
|
||||
numpy.array([1.5, 0.75], dtype=numpy.float32),
|
||||
[
|
||||
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
|
||||
static_64_64(numpy.array([64, 64], dtype=numpy.uint64)),
|
||||
],
|
||||
[static_64_16(64, 16), static_64_16(64, 16)],
|
||||
[static_16_64(16, 64), static_16_64(16, 64)],
|
||||
[
|
||||
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
|
||||
dynamic_64_64(numpy.array([33, 33], dtype=numpy.uint64)),
|
||||
],
|
||||
[
|
||||
dynamic_64_b_64(64, True, 1.25),
|
||||
dynamic_64_b_64(64, True, 1.25),
|
||||
],
|
||||
[dynamic_64_s(64, 's'), dynamic_64_s(64, 's')],
|
||||
[dynamic_s_64('s', 64), dynamic_s_64('s', 64)],
|
||||
)
|
||||
|
||||
res = deserialize_cdr(serialize_cdr(msg, cname), cname)
|
||||
assert res == deserialize(serialize(msg, cname), cname)
|
||||
assert res == msg
|
||||
|
||||
res = deserialize_ros1(serialize_ros1(msg, cname), cname)
|
||||
assert res == msg
|
||||
|
||||
|
||||
def test_ros1_to_cdr() -> None:
|
||||
"""Test ROS1 to CDR conversion."""
|
||||
msgtype = 'test_msgs/msg/static_16_64'
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, msgtype)))
|
||||
msg_ros = (b'\x01\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x01\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02'
|
||||
)
|
||||
assert ros1_to_cdr(msg_ros, msgtype) == msg_cdr
|
||||
assert serialize_cdr(deserialize_ros1(msg_ros, msgtype), msgtype) == msg_cdr
|
||||
|
||||
msgtype = 'test_msgs/msg/dynamic_s_64'
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_S_64, msgtype)))
|
||||
msg_ros = (b'\x01\x00\x00\x00X'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x02\x00\x00\x00X\x00'
|
||||
b'\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02'
|
||||
)
|
||||
assert ros1_to_cdr(msg_ros, msgtype) == msg_cdr
|
||||
assert serialize_cdr(deserialize_ros1(msg_ros, msgtype), msgtype) == msg_cdr
|
||||
|
||||
|
||||
def test_cdr_to_ros1() -> None:
|
||||
"""Test CDR to ROS1 conversion."""
|
||||
msgtype = 'test_msgs/msg/static_16_64'
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, msgtype)))
|
||||
msg_ros = (b'\x01\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x01\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02'
|
||||
)
|
||||
assert cdr_to_ros1(msg_cdr, msgtype) == msg_ros
|
||||
assert serialize_ros1(deserialize_cdr(msg_cdr, msgtype), msgtype) == msg_ros
|
||||
|
||||
msgtype = 'test_msgs/msg/dynamic_s_64'
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_S_64, msgtype)))
|
||||
msg_ros = (b'\x01\x00\x00\x00X'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x02\x00\x00\x00X\x00'
|
||||
b'\x00\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02'
|
||||
)
|
||||
assert cdr_to_ros1(msg_cdr, msgtype) == msg_ros
|
||||
assert serialize_ros1(deserialize_cdr(msg_cdr, msgtype), msgtype) == msg_ros
|
||||
|
||||
header = Header(stamp=Time(42, 666), frame_id='frame')
|
||||
msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header')
|
||||
assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame'
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_padding_empty_sequence() -> None:
|
||||
"""Test empty sequences do not add item padding."""
|
||||
register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b')))
|
||||
|
||||
su64_b = get_msgdef('test_msgs/msg/su64_b', types).cls
|
||||
msg = su64_b(numpy.array([], dtype=numpy.uint64), True)
|
||||
|
||||
cdr = serialize_cdr(msg, msg.__msgtype__)
|
||||
assert cdr[4:] == b'\x00\x00\x00\x00\x01'
|
||||
|
||||
ros1 = cdr_to_ros1(cdr, msg.__msgtype__)
|
||||
assert ros1 == cdr[4:]
|
||||
|
||||
assert ros1_to_cdr(ros1, msg.__msgtype__) == cdr
|
||||
|
||||
assert deserialize_cdr(cdr, msg.__msgtype__) == msg
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_align_after_empty_sequence() -> None:
|
||||
"""Test alignment after empty sequences."""
|
||||
register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64')))
|
||||
register_types(dict(get_types_from_msg(SMSG_U64, 'test_msgs/msg/smsg_u64')))
|
||||
|
||||
su64_u64 = get_msgdef('test_msgs/msg/su64_u64', types).cls
|
||||
smsg_u64 = get_msgdef('test_msgs/msg/smsg_u64', types).cls
|
||||
msg1 = su64_u64(numpy.array([], dtype=numpy.uint64), 42)
|
||||
msg2 = smsg_u64([], 42)
|
||||
|
||||
cdr = serialize_cdr(msg1, msg1.__msgtype__)
|
||||
assert cdr[4:] == b'\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x00\x00\x00\x00\x00\x00\x00'
|
||||
assert serialize_cdr(msg2, msg2.__msgtype__) == cdr
|
||||
|
||||
ros1 = cdr_to_ros1(cdr, msg1.__msgtype__)
|
||||
assert ros1 == b'\x00\x00\x00\x00\x2a\x00\x00\x00\x00\x00\x00\x00'
|
||||
assert cdr_to_ros1(cdr, msg2.__msgtype__) == ros1
|
||||
|
||||
assert ros1_to_cdr(ros1, msg1.__msgtype__) == cdr
|
||||
|
||||
assert deserialize_cdr(cdr, msg1.__msgtype__) == msg1
|
||||
assert deserialize_cdr(cdr, msg2.__msgtype__) == msg2
|
||||
@@ -0,0 +1,127 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Writer tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag2
|
||||
from rosbags.rosbag2 import Writer, WriterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_writer(tmp_path: Path) -> None:
|
||||
"""Test Writer."""
|
||||
path = tmp_path / 'rosbag2'
|
||||
with Writer(path) as bag:
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'rosbag2.db3').exists()
|
||||
size = (path / 'rosbag2.db3').stat().st_size
|
||||
|
||||
path = tmp_path / 'compress_none'
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'compress_none.db3').exists()
|
||||
assert size == (path / 'compress_none.db3').stat().st_size
|
||||
|
||||
path = tmp_path / 'compress_file'
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert not (path / 'compress_file.db3').exists()
|
||||
assert (path / 'compress_file.db3.zstd').exists()
|
||||
|
||||
path = tmp_path / 'compress_message'
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
connection = bag.add_connection('/test', 'std_msgs/msg/Int8')
|
||||
bag.write(connection, 42, b'\x00')
|
||||
bag.write(connection, 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'compress_message.db3').exists()
|
||||
assert size > (path / 'compress_message.db3').stat().st_size
|
||||
|
||||
path = tmp_path / 'with_custom_data'
|
||||
bag = Writer(path)
|
||||
bag.open()
|
||||
bag.set_custom_data('key1', 'value1')
|
||||
with pytest.raises(WriterError, match='non-string value'):
|
||||
bag.set_custom_data('key1', 42) # type: ignore
|
||||
bag.close()
|
||||
assert b'key1: value1' in (path / 'metadata.yaml').read_bytes()
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path) -> None:
|
||||
"""Test writer failure cases."""
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
Writer(tmp_path)
|
||||
|
||||
bag = Writer(tmp_path / 'race')
|
||||
(tmp_path / 'race').mkdir()
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
bag.open()
|
||||
|
||||
bag = Writer(tmp_path / 'compress_after_open')
|
||||
bag.open()
|
||||
with pytest.raises(WriterError, match='already open'):
|
||||
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
|
||||
|
||||
bag = Writer(tmp_path / 'topic')
|
||||
with pytest.raises(WriterError, match='was not opened'):
|
||||
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
|
||||
|
||||
bag = Writer(tmp_path / 'write')
|
||||
with pytest.raises(WriterError, match='was not opened'):
|
||||
bag.write(
|
||||
Connection(
|
||||
1,
|
||||
'/tf',
|
||||
'tf_msgs/msg/tf2',
|
||||
'',
|
||||
'',
|
||||
0,
|
||||
ConnectionExtRosbag2('cdr', ''),
|
||||
None,
|
||||
),
|
||||
0,
|
||||
b'',
|
||||
)
|
||||
|
||||
bag = Writer(tmp_path / 'topic')
|
||||
bag.open()
|
||||
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
|
||||
with pytest.raises(WriterError, match='only be added once'):
|
||||
bag.add_connection('/tf', 'tf_msgs/msg/tf2')
|
||||
|
||||
bag = Writer(tmp_path / 'notopic')
|
||||
bag.open()
|
||||
connection = Connection(
|
||||
1,
|
||||
'/tf',
|
||||
'tf_msgs/msg/tf2',
|
||||
'',
|
||||
'',
|
||||
0,
|
||||
ConnectionExtRosbag2('cdr', ''),
|
||||
None,
|
||||
)
|
||||
with pytest.raises(WriterError, match='unknown connection'):
|
||||
bag.write(connection, 42, b'\x00')
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Writer tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag1 import Writer, WriterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def test_no_overwrite(tmp_path: Path) -> None:
|
||||
"""Test writer does not touch existing files."""
|
||||
path = tmp_path / 'test.bag'
|
||||
path.write_text('foo')
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
Writer(path).open()
|
||||
path.unlink()
|
||||
|
||||
writer = Writer(path)
|
||||
path.write_text('foo')
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
writer.open()
|
||||
|
||||
|
||||
def test_empty(tmp_path: Path) -> None:
|
||||
"""Test empty bag."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with Writer(path):
|
||||
pass
|
||||
data = path.read_bytes()
|
||||
assert len(data) == 13 + 4096
|
||||
|
||||
|
||||
def test_add_connection(tmp_path: Path) -> None:
|
||||
"""Test adding of connections."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with pytest.raises(WriterError, match='not opened'):
|
||||
Writer(path).add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
|
||||
with Writer(path) as writer:
|
||||
res = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
assert res.id == 0
|
||||
data = path.read_bytes()
|
||||
assert data.count(b'MESSAGE_DEFINITION') == 2
|
||||
assert data.count(b'HASH') == 2
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
res = writer.add_connection('/foo', 'std_msgs/msg/Int8')
|
||||
assert res.id == 0
|
||||
data = path.read_bytes()
|
||||
assert data.count(b'int8 data') == 2
|
||||
assert data.count(b'27ffa0c9c4b8fb8492252bcad9e5c57b') == 2
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
with pytest.raises(WriterError, match='can only be added once'):
|
||||
writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
res1 = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
res2 = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
callerid='src',
|
||||
)
|
||||
res3 = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
assert (res1.id, res2.id, res3.id) == (0, 1, 2)
|
||||
|
||||
|
||||
def test_write_errors(tmp_path: Path) -> None:
|
||||
"""Test write errors."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with pytest.raises(WriterError, match='not opened'):
|
||||
Writer(path).write(Mock(), 42, b'DEADBEEF')
|
||||
|
||||
with Writer(path) as writer, \
|
||||
pytest.raises(WriterError, match='is no connection'):
|
||||
writer.write(Mock(), 42, b'DEADBEEF')
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_write_simple(tmp_path: Path) -> None:
|
||||
"""Test writing of messages."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with Writer(path) as writer:
|
||||
conn_foo = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
conn_latching = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
conn_bar = writer.add_connection(
|
||||
'/bar',
|
||||
'test_msgs/msg/Bar',
|
||||
'OTHER_DEFINITION',
|
||||
'HASH',
|
||||
callerid='src',
|
||||
)
|
||||
writer.add_connection('/baz', 'test_msgs/msg/Baz', 'NEVER_WRITTEN', 'HASH')
|
||||
|
||||
writer.write(conn_foo, 42, b'DEADBEEF')
|
||||
writer.write(conn_latching, 42, b'DEADBEEF')
|
||||
writer.write(conn_bar, 43, b'SECRET')
|
||||
writer.write(conn_bar, 43, b'SUBSEQUENT')
|
||||
|
||||
res = path.read_bytes()
|
||||
assert res.count(b'op=\x05') == 1
|
||||
assert res.count(b'op=\x06') == 1
|
||||
assert res.count(b'MESSAGE_DEFINITION') == 4
|
||||
assert res.count(b'latching=1') == 2
|
||||
assert res.count(b'OTHER_DEFINITION') == 2
|
||||
assert res.count(b'callerid=src') == 2
|
||||
assert res.count(b'NEVER_WRITTEN') == 2
|
||||
assert res.count(b'DEADBEEF') == 2
|
||||
assert res.count(b'SECRET') == 1
|
||||
assert res.count(b'SUBSEQUENT') == 1
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
writer.chunk_threshold = 256
|
||||
conn_foo = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
conn_latching = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
conn_bar = writer.add_connection(
|
||||
'/bar',
|
||||
'test_msgs/msg/Bar',
|
||||
'OTHER_DEFINITION',
|
||||
'HASH',
|
||||
callerid='src',
|
||||
)
|
||||
writer.add_connection('/baz', 'test_msgs/msg/Baz', 'NEVER_WRITTEN', 'HASH')
|
||||
|
||||
writer.write(conn_foo, 42, b'DEADBEEF')
|
||||
writer.write(conn_latching, 42, b'DEADBEEF')
|
||||
writer.write(conn_bar, 43, b'SECRET')
|
||||
writer.write(conn_bar, 43, b'SUBSEQUENT')
|
||||
|
||||
res = path.read_bytes()
|
||||
assert res.count(b'op=\x05') == 2
|
||||
assert res.count(b'op=\x06') == 2
|
||||
assert res.count(b'MESSAGE_DEFINITION') == 4
|
||||
assert res.count(b'latching=1') == 2
|
||||
assert res.count(b'OTHER_DEFINITION') == 2
|
||||
assert res.count(b'callerid=src') == 2
|
||||
assert res.count(b'NEVER_WRITTEN') == 2
|
||||
assert res.count(b'DEADBEEF') == 2
|
||||
assert res.count(b'SECRET') == 1
|
||||
assert res.count(b'SUBSEQUENT') == 1
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_compression_errors(tmp_path: Path) -> None:
|
||||
"""Test compression modes."""
|
||||
path = tmp_path / 'test.bag'
|
||||
with Writer(path) as writer, \
|
||||
pytest.raises(WriterError, match='already open'):
|
||||
writer.set_compression(writer.CompressionFormat.BZ2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
|
||||
"""Test compression modes."""
|
||||
path = tmp_path / 'test.bag'
|
||||
writer = Writer(path)
|
||||
if fmt:
|
||||
writer.set_compression(fmt)
|
||||
with writer:
|
||||
conn = writer.add_connection('/foo', 'std_msgs/msg/Int8')
|
||||
writer.write(conn, 42, b'\x42')
|
||||
data = path.read_bytes()
|
||||
assert data.count(f'compression={fmt.name.lower() if fmt else "none"}'.encode()) == 1
|
||||
Reference in New Issue
Block a user