Add rosbag2 support
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reader tests."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import zstandard
|
||||
|
||||
from rosbags.rosbag2 import Reader, ReaderError, Writer
|
||||
|
||||
from .test_serde import MSG_JOINT, MSG_MAGN, MSG_MAGN_BIG, MSG_POLY
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
METADATA = """
|
||||
rosbag2_bagfile_information:
|
||||
version: 4
|
||||
storage_identifier: sqlite3
|
||||
relative_file_paths:
|
||||
- db.db3{extension}
|
||||
duration:
|
||||
nanoseconds: 42
|
||||
starting_time:
|
||||
nanoseconds_since_epoch: 666
|
||||
message_count: 4
|
||||
topics_with_message_count:
|
||||
- topic_metadata:
|
||||
name: /poly
|
||||
type: geometry_msgs/msg/Polygon
|
||||
serialization_format: cdr
|
||||
offered_qos_profiles: ""
|
||||
message_count: 1
|
||||
- topic_metadata:
|
||||
name: /magn
|
||||
type: sensor_msgs/msg/MagneticField
|
||||
serialization_format: cdr
|
||||
offered_qos_profiles: ""
|
||||
message_count: 2
|
||||
- topic_metadata:
|
||||
name: /joint
|
||||
type: trajectory_msgs/msg/JointTrajectory
|
||||
serialization_format: cdr
|
||||
offered_qos_profiles: ""
|
||||
message_count: 1
|
||||
compression_format: {compression_format}
|
||||
compression_mode: {compression_mode}
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture(params=['none', 'file', 'message'])
|
||||
def bag(request: SubRequest, tmp_path: Path) -> Path:
|
||||
"""Manually contruct bag."""
|
||||
(tmp_path / 'metadata.yaml').write_text(
|
||||
METADATA.format(
|
||||
extension='' if request.param != 'file' else '.zstd',
|
||||
compression_format='""' if request.param == 'none' else 'zstd',
|
||||
compression_mode='""' if request.param == 'none' else request.param.upper(),
|
||||
),
|
||||
)
|
||||
|
||||
comp = zstandard.ZstdCompressor()
|
||||
|
||||
dbpath = tmp_path / 'db.db3'
|
||||
dbh = sqlite3.connect(dbpath)
|
||||
dbh.executescript(Writer.SQLITE_SCHEMA)
|
||||
|
||||
cur = dbh.cursor()
|
||||
cur.execute(
|
||||
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
|
||||
(1, '/poly', 'geometry_msgs/msg/Polygon', 'cdr', ''),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
|
||||
(2, '/magn', 'sensor_msgs/msg/MagneticField', 'cdr', ''),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO topics VALUES(?, ?, ?, ?, ?)',
|
||||
(3, '/joint', 'trajectory_msgs/msg/JointTrajectory', 'cdr', ''),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(1, 1, 666, MSG_POLY[0] if request.param != 'message' else comp.compress(MSG_POLY[0])),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(2, 2, 708, MSG_MAGN[0] if request.param != 'message' else comp.compress(MSG_MAGN[0])),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(
|
||||
3,
|
||||
2,
|
||||
708,
|
||||
MSG_MAGN_BIG[0] if request.param != 'message' else comp.compress(MSG_MAGN_BIG[0]),
|
||||
),
|
||||
)
|
||||
cur.execute(
|
||||
'INSERT INTO messages VALUES(?, ?, ?, ?)',
|
||||
(4, 3, 708, MSG_JOINT[0] if request.param != 'message' else comp.compress(MSG_JOINT[0])),
|
||||
)
|
||||
dbh.commit()
|
||||
|
||||
if request.param == 'file':
|
||||
with dbpath.open('rb') as ifh, (tmp_path / 'db.db3.zstd').open('wb') as ofh:
|
||||
comp.copy_stream(ifh, ofh)
|
||||
dbpath.unlink()
|
||||
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_reader(bag: Path):
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
with Reader(bag) as reader:
|
||||
assert reader.duration == 42
|
||||
assert reader.start_time == 666
|
||||
assert reader.end_time == 708
|
||||
assert reader.message_count == 4
|
||||
if reader.compression_mode:
|
||||
assert reader.compression_format == 'zstd'
|
||||
|
||||
gen = reader.messages()
|
||||
|
||||
topic, msgtype, timestamp, rawdata = next(gen)
|
||||
assert topic == '/poly'
|
||||
assert msgtype == 'geometry_msgs/msg/Polygon'
|
||||
assert timestamp == 666
|
||||
assert rawdata == MSG_POLY[0]
|
||||
|
||||
for idx in range(2):
|
||||
topic, msgtype, timestamp, rawdata = next(gen)
|
||||
assert topic == '/magn'
|
||||
assert msgtype == 'sensor_msgs/msg/MagneticField'
|
||||
assert timestamp == 708
|
||||
assert rawdata == [MSG_MAGN, MSG_MAGN_BIG][idx][0]
|
||||
|
||||
topic, msgtype, timestamp, rawdata = next(gen)
|
||||
assert topic == '/joint'
|
||||
assert msgtype == 'trajectory_msgs/msg/JointTrajectory'
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_message_filters(bag: Path):
|
||||
"""Test reader filters messages."""
|
||||
with Reader(bag) as reader:
|
||||
|
||||
gen = reader.messages(['/magn'])
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=667)
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/magn'
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/joint'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(stop=667)
|
||||
topic, _, _, _ = next(gen)
|
||||
assert topic == '/poly'
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(['/magn'], stop=667)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
gen = reader.messages(start=666, stop=666)
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_user_errors(bag: Path):
|
||||
"""Test user errors."""
|
||||
reader = Reader(bag)
|
||||
with pytest.raises(ReaderError, match='Rosbag is not open'):
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path):
|
||||
"""Test bags with broken fs layout."""
|
||||
with pytest.raises(ReaderError, match='not read metadata'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata = tmp_path / 'metadata.yaml'
|
||||
|
||||
metadata.write_text('')
|
||||
with pytest.raises(ReaderError, match='not read'), \
|
||||
mock.patch.object(Path, 'read_text', side_effect=PermissionError):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(' invalid:\nthis is not yaml')
|
||||
with pytest.raises(ReaderError, match='not load YAML from'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text('foo:')
|
||||
with pytest.raises(ReaderError, match='key is missing'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('version: 4', 'version: 999'),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='version 999'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('sqlite3', 'hdf5'),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='Storage plugin'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='files are missing'):
|
||||
Reader(tmp_path)
|
||||
|
||||
(tmp_path / 'db.db3').write_text('')
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
).replace('cdr', 'bson'),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='Serialization format'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='"gz"',
|
||||
compression_mode='"file"',
|
||||
),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='Compression format'):
|
||||
Reader(tmp_path)
|
||||
|
||||
metadata.write_text(
|
||||
METADATA.format(
|
||||
extension='',
|
||||
compression_format='""',
|
||||
compression_mode='""',
|
||||
),
|
||||
)
|
||||
with pytest.raises(ReaderError, match='not open database'), \
|
||||
Reader(tmp_path) as reader:
|
||||
next(reader.messages())
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag2 import Reader, Writer
|
||||
from rosbags.serde import deserialize_cdr, serialize_cdr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
|
||||
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
data = 1.25
|
||||
|
||||
path = tmp_path / 'rosbag2'
|
||||
wbag = Writer(path)
|
||||
wbag.set_compression(mode, wbag.CompressionFormat.ZSTD)
|
||||
with wbag:
|
||||
msgtype = 'std_msgs/msg/Float64'
|
||||
wbag.add_topic('/test', msgtype)
|
||||
wbag.write('/test', 42, serialize_cdr(Foo, msgtype))
|
||||
|
||||
rbag = Reader(path)
|
||||
with rbag:
|
||||
gen = rbag.messages()
|
||||
_, msgtype, _, raw = next(gen)
|
||||
msg = deserialize_cdr(raw, msgtype)
|
||||
assert msg.data == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Writer tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag2 import Writer, WriterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_writer(tmp_path: Path):
|
||||
"""Test Writer."""
|
||||
path = (tmp_path / 'rosbag2')
|
||||
with Writer(path) as bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'rosbag2.db3').exists()
|
||||
size = (path / 'rosbag2.db3').stat().st_size
|
||||
|
||||
path = (tmp_path / 'compress_none')
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.NONE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'compress_none.db3').exists()
|
||||
assert size == (path / 'compress_none.db3').stat().st_size
|
||||
|
||||
path = (tmp_path / 'compress_file')
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert not (path / 'compress_file.db3').exists()
|
||||
assert (path / 'compress_file.db3.zstd').exists()
|
||||
|
||||
path = (tmp_path / 'compress_message')
|
||||
bag = Writer(path)
|
||||
bag.set_compression(bag.CompressionMode.MESSAGE, bag.CompressionFormat.ZSTD)
|
||||
with bag:
|
||||
bag.add_topic('/test', 'std_msgs/msg/Int8')
|
||||
bag.write('/test', 42, b'\x00')
|
||||
bag.write('/test', 666, b'\x01' * 4096)
|
||||
assert (path / 'metadata.yaml').exists()
|
||||
assert (path / 'compress_message.db3').exists()
|
||||
assert size > (path / 'compress_message.db3').stat().st_size
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path):
|
||||
"""Test writer failure cases."""
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
Writer(tmp_path)
|
||||
|
||||
bag = Writer(tmp_path / 'race')
|
||||
(tmp_path / 'race').mkdir()
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
bag.open()
|
||||
|
||||
bag = Writer(tmp_path / 'compress_after_open')
|
||||
bag.open()
|
||||
with pytest.raises(WriterError, match='already open'):
|
||||
bag.set_compression(bag.CompressionMode.FILE, bag.CompressionFormat.ZSTD)
|
||||
|
||||
bag = Writer(tmp_path / 'topic')
|
||||
with pytest.raises(WriterError, match='was not opened'):
|
||||
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
|
||||
|
||||
bag = Writer(tmp_path / 'write')
|
||||
with pytest.raises(WriterError, match='was not opened'):
|
||||
bag.write('/tf', 0, b'')
|
||||
|
||||
bag = Writer(tmp_path / 'topic')
|
||||
bag.open()
|
||||
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
|
||||
with pytest.raises(WriterError, match='only be added once'):
|
||||
bag.add_topic('/tf', 'tf_msgs/msg/tf2')
|
||||
|
||||
bag = Writer(tmp_path / 'notopic')
|
||||
bag.open()
|
||||
with pytest.raises(WriterError, match='unknown topic'):
|
||||
bag.write('/test', 42, b'\x00')
|
||||
Reference in New Issue
Block a user