Add rosbag1 writer
This commit is contained in:
parent
5bd1bcbd83
commit
cc96973be3
@ -14,7 +14,7 @@ Rosbags
|
||||
Rosbags is the **pure python** library for everything rosbag. It contains:
|
||||
|
||||
- **rosbag2** reader and writer,
|
||||
- **rosbag1** reader for raw messages,
|
||||
- **rosbag1** reader and writer,
|
||||
- **extensible** type system with serializers and deserializers,
|
||||
- **efficient converter** between rosbag1 and rosbag2,
|
||||
- and more.
|
||||
|
||||
@ -3,6 +3,27 @@ Rosbag1
|
||||
|
||||
The :py:mod:`rosbags.rosbag1` package provides fast read-only access to raw messages stored in the legacy bag format. The rosbag1 support is built for a ROS2 world and some APIs and values perform normalizations to mimic ROS2 behavior and make messages originating from rosbag1 and rosbag2 behave identically. Most notably message types are internally renamed to match their ROS2 counterparts.
|
||||
|
||||
Writing rosbag1
|
||||
---------------
|
||||
Instances of the :py:class:`Writer <rosbags.rosbag1.Writer>` class can create and write to new rosbag1 files. It is usually used as a context manager. Before the first message of a topic can be written, its topic must first be added to the bag. The following example shows the typical usage pattern:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from rosbags.rosbag1 import Writer
|
||||
from rosbags.serde import cdr_to_ros1, serialize_cdr
|
||||
from rosbags.typesys.types import std_msgs__msg__String as String
|
||||
|
||||
# create writer instance and open for writing
|
||||
with Writer('/home/ros/rosbag_2020_03_24.bag') as writer:
|
||||
# add new connection
|
||||
topic = '/chatter'
|
||||
msgtype = String.__msgtype__
|
||||
connection = writer.add_connection(topic, msgtype, latching=True)
|
||||
|
||||
# serialize and write message
|
||||
message = String('hello world')
|
||||
writer.write(connection, timestamp, cdr_to_ros1(serialize_cdr(message, msgtype), msgtype))
|
||||
|
||||
Reading rosbag1
|
||||
---------------
|
||||
Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically used as context managers and provide access to bag metadata and contents after the bag has been opened. The following example shows the typical usage pattern:
|
||||
@ -10,6 +31,8 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically
|
||||
.. code-block:: python
|
||||
|
||||
from rosbags.rosbag1 import Reader
|
||||
from rosbags.serde import deserialize_cdr, ros1_to_cdr
|
||||
|
||||
|
||||
# create reader instance
|
||||
with Reader('/home/ros/rosbag_2020_03_24.bag') as reader:
|
||||
@ -20,9 +43,11 @@ Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically
|
||||
# iterate over messages
|
||||
for connection, timestamp, rawdata in reader.messages():
|
||||
if connection.topic == '/imu_raw/Imu':
|
||||
print(timestamp)
|
||||
msg = deserialize_cdr(ros1_to_cdr(rawdata, connection.msgtype), connection.msgtype)
|
||||
print(msg.header.frame_id)
|
||||
|
||||
# messages() accepts connection filters
|
||||
connections = [x for x in reader.connections.values() if x.topic == '/imu_raw/Imu']
|
||||
for connection, timestamp, rawdata in reader.messages(connections=connections):
|
||||
print(timestamp)
|
||||
msg = deserialize_cdr(ros1_to_cdr(rawdata, connection.msgtype), connection.msgtype)
|
||||
print(msg.header.frame_id)
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbags support for rosbag1 files.
|
||||
|
||||
Reader provides access to metadata and raw message content saved in the
|
||||
rosbag1 format.
|
||||
Readers and writers provide access to metadata and raw message content saved
|
||||
in the rosbag1 format.
|
||||
|
||||
Supported versions:
|
||||
- Rosbag1 v2.0
|
||||
@ -11,8 +11,11 @@ Supported versions:
|
||||
"""
|
||||
|
||||
from .reader import Reader, ReaderError
|
||||
from .writer import Writer, WriterError
|
||||
|
||||
__all__ = [
|
||||
'Reader',
|
||||
'ReaderError',
|
||||
'Writer',
|
||||
'WriterError',
|
||||
]
|
||||
|
||||
403
src/rosbags/rosbag1/writer.py
Normal file
403
src/rosbags/rosbag1/writer.py
Normal file
@ -0,0 +1,403 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag1 writer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from bz2 import compress as bz2_compress
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lz4.frame import compress as lz4_compress # type: ignore
|
||||
|
||||
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
||||
|
||||
from .reader import Connection, RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
|
||||
|
||||
|
||||
class WriterError(Exception):
|
||||
"""Writer Error."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WriteChunk:
|
||||
"""In progress chunk."""
|
||||
data: BytesIO
|
||||
pos: int
|
||||
start: int
|
||||
end: int
|
||||
connections: dict[int, list[tuple[int, int]]]
|
||||
|
||||
|
||||
serialize_uint8 = struct.Struct('<B').pack
|
||||
serialize_uint32 = struct.Struct('<L').pack
|
||||
serialize_uint64 = struct.Struct('<Q').pack
|
||||
|
||||
|
||||
def serialize_time(val: int) -> bytes:
|
||||
"""Serialize time value.
|
||||
|
||||
Args:
|
||||
val: Time value.
|
||||
|
||||
Returns:
|
||||
Serialized bytes.
|
||||
|
||||
"""
|
||||
sec, nsec = val // 10**9, val % 10**9
|
||||
return struct.pack('<LL', sec, nsec)
|
||||
|
||||
|
||||
class Header(dict):
|
||||
"""Record header."""
|
||||
|
||||
def set_uint32(self, name: str, value: int):
|
||||
"""Set field to uint32 value.
|
||||
|
||||
Args:
|
||||
name: Field name.
|
||||
value: Field value.
|
||||
|
||||
"""
|
||||
self[name] = serialize_uint32(value)
|
||||
|
||||
def set_uint64(self, name: str, value: int):
|
||||
"""Set field to uint64 value.
|
||||
|
||||
Args:
|
||||
name: Field name.
|
||||
value: Field value.
|
||||
|
||||
"""
|
||||
self[name] = serialize_uint64(value)
|
||||
|
||||
def set_string(self, name: str, value: str):
|
||||
"""Set field to string value.
|
||||
|
||||
Args:
|
||||
name: Field name.
|
||||
value: Field value.
|
||||
|
||||
"""
|
||||
self[name] = value.encode()
|
||||
|
||||
def set_time(self, name: str, value: int):
|
||||
"""Set field to time value.
|
||||
|
||||
Args:
|
||||
name: Field name.
|
||||
value: Field value.
|
||||
|
||||
"""
|
||||
self[name] = serialize_time(value)
|
||||
|
||||
def write(self, dst: BinaryIO, opcode: Optional[RecordType] = None) -> int:
|
||||
"""Write to file handle.
|
||||
|
||||
Args:
|
||||
dst: File handle.
|
||||
opcode: Record type code.
|
||||
|
||||
Returns:
|
||||
Bytes written.
|
||||
|
||||
"""
|
||||
data = b''
|
||||
|
||||
if opcode:
|
||||
keqv = 'op='.encode() + serialize_uint8(opcode)
|
||||
data += serialize_uint32(len(keqv)) + keqv
|
||||
|
||||
for key, value in self.items():
|
||||
keqv = f'{key}='.encode() + value
|
||||
data += serialize_uint32(len(keqv)) + keqv
|
||||
|
||||
size = len(data)
|
||||
dst.write(serialize_uint32(size) + data)
|
||||
return size + 4
|
||||
|
||||
|
||||
class Writer: # pylint: disable=too-many-instance-attributes
|
||||
"""Rosbag1 writer.
|
||||
|
||||
This class implements writing of rosbag1 files in version 2.0. It should be
|
||||
used as a contextmanager.
|
||||
|
||||
"""
|
||||
|
||||
class CompressionFormat(IntEnum):
|
||||
"""Compession formats."""
|
||||
|
||||
BZ2 = auto()
|
||||
LZ4 = auto()
|
||||
|
||||
def __init__(self, path: Union[Path, str]):
|
||||
"""Initialize writer.
|
||||
|
||||
Args:
|
||||
path: Filesystem path to bag.
|
||||
|
||||
Raises:
|
||||
WriterError: Target path exisits already, Writer can only create new rosbags.
|
||||
|
||||
"""
|
||||
path = Path(path)
|
||||
self.path = path
|
||||
if path.exists():
|
||||
raise WriterError(f'{path} exists already, not overwriting.')
|
||||
self.bio: Optional[BinaryIO] = None
|
||||
self.compressor: Callable[[bytes], bytes] = lambda x: x
|
||||
self.compression_format = 'none'
|
||||
self.connections: dict[int, Connection] = {}
|
||||
self.chunks: list[WriteChunk] = [
|
||||
WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)),
|
||||
]
|
||||
self.chunk_threshold = 1 * (1 << 20)
|
||||
|
||||
def set_compression(self, fmt: CompressionFormat):
|
||||
"""Enable compression on rosbag1.
|
||||
|
||||
This function has to be called before opening.
|
||||
|
||||
Args:
|
||||
fmt: Compressor to use, bz2 or lz4
|
||||
|
||||
Raises:
|
||||
WriterError: Bag already open.
|
||||
|
||||
"""
|
||||
if self.bio:
|
||||
raise WriterError(f'Cannot set compression, bag {self.path} already open.')
|
||||
|
||||
self.compression_format = fmt.name.lower()
|
||||
|
||||
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
|
||||
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
|
||||
self.compressor = {
|
||||
'bz2': bz2,
|
||||
'lz4': lz4,
|
||||
}[self.compression_format]
|
||||
|
||||
def open(self):
|
||||
"""Open rosbag1 for writing."""
|
||||
try:
|
||||
self.bio = self.path.open('xb')
|
||||
except FileExistsError:
|
||||
raise WriterError(f'{self.path} exists already, not overwriting.') from None
|
||||
|
||||
self.bio.write(b'#ROSBAG V2.0\n')
|
||||
header = Header()
|
||||
header.set_uint64('index_pos', 0)
|
||||
header.set_uint32('conn_count', 0)
|
||||
header.set_uint32('chunk_count', 0)
|
||||
size = header.write(self.bio, RecordType.BAGHEADER)
|
||||
padsize = 4096 - 4 - size
|
||||
self.bio.write(serialize_uint32(padsize) + b' ' * padsize)
|
||||
|
||||
def add_connection( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
topic: str,
|
||||
msgtype: str,
|
||||
msgdef: Optional[str] = None,
|
||||
md5sum: Optional[str] = None,
|
||||
callerid: Optional[str] = None,
|
||||
latching: Optional[int] = None,
|
||||
**_kw: Any,
|
||||
) -> Connection:
|
||||
"""Add a connection.
|
||||
|
||||
This function can only be called after opening a bag.
|
||||
|
||||
Args:
|
||||
topic: Topic name.
|
||||
msgtype: Message type.
|
||||
msgdef: Message definiton.
|
||||
md5sum: Message hash.
|
||||
callerid: Caller id.
|
||||
latching: Latching information.
|
||||
_kw: Ignored to allow consuming dicts from connection objects.
|
||||
|
||||
Returns:
|
||||
Connection id.
|
||||
|
||||
Raises:
|
||||
WriterError: Bag not open or identical topic previously registered.
|
||||
|
||||
"""
|
||||
if not self.bio:
|
||||
raise WriterError('Bag was not opened.')
|
||||
|
||||
if msgdef is None or md5sum is None:
|
||||
msgdef, md5sum = generate_msgdef(msgtype)
|
||||
assert msgdef
|
||||
assert md5sum
|
||||
|
||||
connection = Connection(
|
||||
len(self.connections),
|
||||
topic,
|
||||
denormalize_msgtype(msgtype),
|
||||
md5sum,
|
||||
msgdef,
|
||||
callerid,
|
||||
latching,
|
||||
[],
|
||||
)
|
||||
|
||||
if any(x[1:] == connection[1:] for x in self.connections.values()):
|
||||
raise WriterError(
|
||||
f'Connections can only be added once with same arguments: {connection!r}.',
|
||||
)
|
||||
|
||||
bio = self.chunks[-1].data
|
||||
self.write_connection(connection, bio)
|
||||
|
||||
self.connections[connection.cid] = connection
|
||||
return connection
|
||||
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
||||
"""Write message to rosbag1.
|
||||
|
||||
Args:
|
||||
connection: Connection to write message to.
|
||||
timestamp: Message timestamp (ns).
|
||||
data: Serialized message data.
|
||||
|
||||
Raises:
|
||||
WriterError: Bag not open or connection not registered.
|
||||
|
||||
"""
|
||||
if not self.bio:
|
||||
raise WriterError('Bag was not opened.')
|
||||
|
||||
if connection not in self.connections.values():
|
||||
raise WriterError(f'There is no connection {connection!r}.') from None
|
||||
|
||||
chunk = self.chunks[-1]
|
||||
chunk.connections[connection.cid].append((timestamp, chunk.data.tell()))
|
||||
|
||||
if timestamp < chunk.start:
|
||||
chunk.start = timestamp
|
||||
|
||||
if timestamp > chunk.end:
|
||||
chunk.end = timestamp
|
||||
|
||||
header = Header()
|
||||
header.set_uint32('conn', connection.cid)
|
||||
header.set_time('time', timestamp)
|
||||
|
||||
header.write(chunk.data, RecordType.MSGDATA)
|
||||
chunk.data.write(serialize_uint32(len(data)))
|
||||
chunk.data.write(data)
|
||||
if chunk.data.tell() > self.chunk_threshold:
|
||||
self.write_chunk(chunk)
|
||||
|
||||
@staticmethod
|
||||
def write_connection(connection: Connection, bio: BytesIO):
|
||||
"""Write connection record."""
|
||||
header = Header()
|
||||
header.set_uint32('conn', connection.cid)
|
||||
header.set_string('topic', connection.topic)
|
||||
header.write(bio, RecordType.CONNECTION)
|
||||
|
||||
header = Header()
|
||||
header.set_string('topic', connection.topic)
|
||||
header.set_string('type', connection.msgtype)
|
||||
header.set_string('md5sum', connection.md5sum)
|
||||
header.set_string('message_definition', connection.msgdef)
|
||||
if connection.callerid is not None:
|
||||
header.set_string('callerid', connection.callerid)
|
||||
if connection.latching is not None:
|
||||
header.set_string('latching', str(connection.latching))
|
||||
header.write(bio)
|
||||
|
||||
def write_chunk(self, chunk: WriteChunk):
|
||||
"""Write open chunk to file."""
|
||||
assert self.bio
|
||||
|
||||
if size := chunk.data.tell() > 0:
|
||||
chunk.pos = self.bio.tell()
|
||||
|
||||
header = Header()
|
||||
header.set_string('compression', self.compression_format)
|
||||
header.set_uint32('size', size)
|
||||
header.write(self.bio, RecordType.CHUNK)
|
||||
data = self.compressor(chunk.data.getvalue())
|
||||
self.bio.write(serialize_uint32(len(data)))
|
||||
self.bio.write(data)
|
||||
|
||||
for cid, items in chunk.connections.items():
|
||||
header = Header()
|
||||
header.set_uint32('ver', 1)
|
||||
header.set_uint32('conn', cid)
|
||||
header.set_uint32('count', len(items))
|
||||
header.write(self.bio, RecordType.IDXDATA)
|
||||
self.bio.write(serialize_uint32(len(items) * 12))
|
||||
for time, offset in items:
|
||||
self.bio.write(serialize_time(time) + serialize_uint32(offset))
|
||||
|
||||
chunk.data.close()
|
||||
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
|
||||
|
||||
def close(self):
|
||||
"""Close rosbag1 after writing.
|
||||
|
||||
Closes open chunks and writes index.
|
||||
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
if chunk.pos == -1:
|
||||
self.write_chunk(chunk)
|
||||
|
||||
index_pos = self.bio.tell()
|
||||
|
||||
for connection in self.connections.values():
|
||||
self.write_connection(connection, self.bio)
|
||||
|
||||
for chunk in self.chunks:
|
||||
if chunk.pos == -1:
|
||||
continue
|
||||
header = Header()
|
||||
header.set_uint32('ver', 1)
|
||||
header.set_uint64('chunk_pos', chunk.pos)
|
||||
header.set_time('start_time', 0 if chunk.start == 2**64 else chunk.start)
|
||||
header.set_time('end_time', chunk.end)
|
||||
header.set_uint32('count', len(chunk.connections))
|
||||
header.write(self.bio, RecordType.CHUNK_INFO)
|
||||
self.bio.write(serialize_uint32(len(chunk.connections) * 8))
|
||||
for cid, items in chunk.connections.items():
|
||||
self.bio.write(serialize_uint32(cid) + serialize_uint32(len(items)))
|
||||
|
||||
self.bio.seek(13)
|
||||
header = Header()
|
||||
header.set_uint64('index_pos', index_pos)
|
||||
header.set_uint32('conn_count', len(self.connections))
|
||||
header.set_uint32('chunk_count', len([x for x in self.chunks if x.pos != -1]))
|
||||
size = header.write(self.bio, RecordType.BAGHEADER)
|
||||
padsize = 4096 - 4 - size
|
||||
self.bio.write(serialize_uint32(padsize) + b' ' * padsize)
|
||||
|
||||
self.bio.close()
|
||||
|
||||
def __enter__(self) -> Writer:
|
||||
"""Open rosbag1 when entering contextmanager."""
|
||||
self.open()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Close rosbag1 when exiting contextmanager."""
|
||||
self.close()
|
||||
return False
|
||||
44
tests/test_roundtrip1.py
Normal file
44
tests/test_roundtrip1.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag1 import Reader, Writer
|
||||
from rosbags.serde import cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
data = 1.25
|
||||
|
||||
path = tmp_path / 'test.bag'
|
||||
wbag = Writer(path)
|
||||
if fmt:
|
||||
wbag.set_compression(fmt)
|
||||
with wbag:
|
||||
msgtype = 'std_msgs/msg/Float64'
|
||||
conn = wbag.add_connection('/test', msgtype)
|
||||
wbag.write(conn, 42, cdr_to_ros1(serialize_cdr(Foo, msgtype), msgtype))
|
||||
|
||||
rbag = Reader(path)
|
||||
with rbag:
|
||||
gen = rbag.messages()
|
||||
connection, _, raw = next(gen)
|
||||
msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype)
|
||||
assert msg.data == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
201
tests/test_writer1.py
Normal file
201
tests/test_writer1.py
Normal file
@ -0,0 +1,201 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Writer tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag1 import Writer, WriterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def test_no_overwrite(tmp_path: Path):
|
||||
"""Test writer does not touch existing files."""
|
||||
path = tmp_path / 'test.bag'
|
||||
path.write_text('foo')
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
Writer(path).open()
|
||||
path.unlink()
|
||||
|
||||
writer = Writer(path)
|
||||
path.write_text('foo')
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
writer.open()
|
||||
|
||||
|
||||
def test_empty(tmp_path: Path):
|
||||
"""Test empty bag."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with Writer(path):
|
||||
pass
|
||||
data = path.read_bytes()
|
||||
assert len(data) == 13 + 4096
|
||||
|
||||
|
||||
def test_add_connection(tmp_path: Path):
|
||||
"""Test adding of connections."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with pytest.raises(WriterError, match='not opened'):
|
||||
Writer(path).add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
|
||||
with Writer(path) as writer:
|
||||
res = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
assert res.cid == 0
|
||||
data = path.read_bytes()
|
||||
assert data.count(b'MESSAGE_DEFINITION') == 2
|
||||
assert data.count(b'HASH') == 2
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
res = writer.add_connection('/foo', 'std_msgs/msg/Int8')
|
||||
assert res.cid == 0
|
||||
data = path.read_bytes()
|
||||
assert data.count(b'int8 data') == 2
|
||||
assert data.count(b'27ffa0c9c4b8fb8492252bcad9e5c57b') == 2
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
with pytest.raises(WriterError, match='can only be added once'):
|
||||
writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
res1 = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
res2 = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
callerid='src',
|
||||
)
|
||||
res3 = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
|
||||
|
||||
|
||||
def test_write_errors(tmp_path: Path):
|
||||
"""Test write errors."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with pytest.raises(WriterError, match='not opened'):
|
||||
Writer(path).write(Mock(), 42, b'DEADBEEF')
|
||||
|
||||
with Writer(path) as writer, \
|
||||
pytest.raises(WriterError, match='is no connection'):
|
||||
writer.write(Mock(), 42, b'DEADBEEF')
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_write_simple(tmp_path: Path):
|
||||
"""Test writing of messages."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
with Writer(path) as writer:
|
||||
conn_foo = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
conn_latching = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
conn_bar = writer.add_connection(
|
||||
'/bar',
|
||||
'test_msgs/msg/Bar',
|
||||
'OTHER_DEFINITION',
|
||||
'HASH',
|
||||
callerid='src',
|
||||
)
|
||||
writer.add_connection('/baz', 'test_msgs/msg/Baz', 'NEVER_WRITTEN', 'HASH')
|
||||
|
||||
writer.write(conn_foo, 42, b'DEADBEEF')
|
||||
writer.write(conn_latching, 42, b'DEADBEEF')
|
||||
writer.write(conn_bar, 43, b'SECRET')
|
||||
writer.write(conn_bar, 43, b'SUBSEQUENT')
|
||||
|
||||
res = path.read_bytes()
|
||||
assert res.count(b'op=\x05') == 1
|
||||
assert res.count(b'op=\x06') == 1
|
||||
assert res.count(b'MESSAGE_DEFINITION') == 4
|
||||
assert res.count(b'latching=1') == 2
|
||||
assert res.count(b'OTHER_DEFINITION') == 2
|
||||
assert res.count(b'callerid=src') == 2
|
||||
assert res.count(b'NEVER_WRITTEN') == 2
|
||||
assert res.count(b'DEADBEEF') == 2
|
||||
assert res.count(b'SECRET') == 1
|
||||
assert res.count(b'SUBSEQUENT') == 1
|
||||
path.unlink()
|
||||
|
||||
with Writer(path) as writer:
|
||||
writer.chunk_threshold = 256
|
||||
conn_foo = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
conn_latching = writer.add_connection(
|
||||
'/foo',
|
||||
'test_msgs/msg/Test',
|
||||
'MESSAGE_DEFINITION',
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
conn_bar = writer.add_connection(
|
||||
'/bar',
|
||||
'test_msgs/msg/Bar',
|
||||
'OTHER_DEFINITION',
|
||||
'HASH',
|
||||
callerid='src',
|
||||
)
|
||||
writer.add_connection('/baz', 'test_msgs/msg/Baz', 'NEVER_WRITTEN', 'HASH')
|
||||
|
||||
writer.write(conn_foo, 42, b'DEADBEEF')
|
||||
writer.write(conn_latching, 42, b'DEADBEEF')
|
||||
writer.write(conn_bar, 43, b'SECRET')
|
||||
writer.write(conn_bar, 43, b'SUBSEQUENT')
|
||||
|
||||
res = path.read_bytes()
|
||||
assert res.count(b'op=\x05') == 2
|
||||
assert res.count(b'op=\x06') == 2
|
||||
assert res.count(b'MESSAGE_DEFINITION') == 4
|
||||
assert res.count(b'latching=1') == 2
|
||||
assert res.count(b'OTHER_DEFINITION') == 2
|
||||
assert res.count(b'callerid=src') == 2
|
||||
assert res.count(b'NEVER_WRITTEN') == 2
|
||||
assert res.count(b'DEADBEEF') == 2
|
||||
assert res.count(b'SECRET') == 1
|
||||
assert res.count(b'SUBSEQUENT') == 1
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_compression_errors(tmp_path: Path):
|
||||
"""Test compression modes."""
|
||||
path = tmp_path / 'test.bag'
|
||||
with Writer(path) as writer, \
|
||||
pytest.raises(WriterError, match='already open'):
|
||||
writer.set_compression(writer.CompressionFormat.BZ2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
|
||||
"""Test compression modes."""
|
||||
path = tmp_path / 'test.bag'
|
||||
writer = Writer(path)
|
||||
if fmt:
|
||||
writer.set_compression(fmt)
|
||||
with writer:
|
||||
conn = writer.add_connection('/foo', 'std_msgs/msg/Int8')
|
||||
writer.write(conn, 42, b'\x42')
|
||||
data = path.read_bytes()
|
||||
assert data.count(f'compression={fmt.name.lower() if fmt else "none"}'.encode()) == 1
|
||||
Loading…
x
Reference in New Issue
Block a user