2022-04-10 23:32:25 +02:00
|
|
|
# Copyright 2020-2022 Ternaris.
|
2021-08-06 09:39:55 +02:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
"""Rosbag1 writer."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import struct
|
|
|
|
|
from bz2 import compress as bz2_compress
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from enum import IntEnum, auto
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
from pathlib import Path
|
2021-11-25 14:26:17 +01:00
|
|
|
from typing import TYPE_CHECKING, Any, Dict
|
2021-08-06 09:39:55 +02:00
|
|
|
|
2021-11-25 14:26:17 +01:00
|
|
|
from lz4.frame import compress as lz4_compress
|
2021-08-06 09:39:55 +02:00
|
|
|
|
2022-04-13 09:40:22 +02:00
|
|
|
from rosbags.interfaces import Connection, ConnectionExtRosbag1
|
2021-08-06 09:39:55 +02:00
|
|
|
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
|
|
|
|
|
2022-04-13 09:40:22 +02:00
|
|
|
from .reader import RecordType
|
2021-08-06 09:39:55 +02:00
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from types import TracebackType
|
2021-11-25 14:26:17 +01:00
|
|
|
from typing import BinaryIO, Callable, Literal, Optional, Type, Union
|
2021-08-06 09:39:55 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class WriterError(Exception):
    """Error raised by the rosbag1 writer."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class WriteChunk:
    """State of a chunk that is still being assembled in memory."""

    # In-memory buffer receiving the chunk's records.
    data: BytesIO
    # File offset where the chunk was written; the writer uses -1 for "not written yet".
    pos: int
    # Smallest message timestamp seen so far (writer initialises it to 2**64).
    start: int
    # Largest message timestamp seen so far (writer initialises it to 0).
    end: int
    # Per connection id: list of (timestamp, offset inside ``data``) index entries.
    connections: dict[int, list[tuple[int, int]]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Little-endian packers for unsigned 8, 32 and 64 bit integers. Each name is the
# bound ``pack`` method of a precompiled ``struct.Struct``.
serialize_uint8, serialize_uint32, serialize_uint64 = (
    struct.Struct(fmt).pack for fmt in ('<B', '<L', '<Q')
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def serialize_time(val: int) -> bytes:
    """Serialize a time value as two little-endian uint32 values.

    The value is split at 10**9 into a whole part and a remainder, i.e. a
    nanosecond count becomes seconds followed by nanoseconds.

    Args:
        val: Time value.

    Returns:
        Eight bytes: quotient first, remainder second.

    """
    seconds, nanoseconds = divmod(val, 10**9)
    return struct.pack('<LL', seconds, nanoseconds)
|
|
|
|
|
|
|
|
|
|
|
2021-11-25 14:26:17 +01:00
|
|
|
class Header(Dict[str, Any]):
    """Record header: a dict from field name to already-serialized bytes."""

    def set_uint32(self, name: str, value: int) -> None:
        """Store a field as a little-endian uint32.

        Args:
            name: Field name.
            value: Field value.

        """
        self[name] = serialize_uint32(value)

    def set_uint64(self, name: str, value: int) -> None:
        """Store a field as a little-endian uint64.

        Args:
            name: Field name.
            value: Field value.

        """
        self[name] = serialize_uint64(value)

    def set_string(self, name: str, value: str) -> None:
        """Store a field as an encoded string.

        Args:
            name: Field name.
            value: Field value.

        """
        self[name] = value.encode()

    def set_time(self, name: str, value: int) -> None:
        """Store a field as a serialized time value.

        Args:
            name: Field name.
            value: Field value.

        """
        self[name] = serialize_time(value)

    def write(self, dst: BinaryIO, opcode: Optional[RecordType] = None) -> int:
        """Serialize all fields and write them to a file handle.

        Each field is emitted as a uint32 length followed by ``name=value``;
        the whole header is itself prefixed with its uint32 length.

        Args:
            dst: File handle.
            opcode: Record type code, written first as the ``op`` field when given.

        Returns:
            Bytes written, including the 4 byte length prefix.

        """
        fields = []
        if opcode:
            fields.append(b'op=' + serialize_uint8(opcode))
        fields.extend(f'{key}='.encode() + value for key, value in self.items())

        payload = b''.join(serialize_uint32(len(field)) + field for field in fields)
        dst.write(serialize_uint32(len(payload)) + payload)
        return len(payload) + 4
|
|
|
|
|
|
|
|
|
|
|
2022-04-11 00:07:53 +02:00
|
|
|
class Writer:
    """Rosbag1 writer.

    This class implements writing of rosbag1 files in version 2.0. It should be
    used as a contextmanager.

    """

    class CompressionFormat(IntEnum):
        """Compression formats."""

        BZ2 = auto()
        LZ4 = auto()

    def __init__(self, path: Union[Path, str]):
        """Initialize writer.

        Args:
            path: Filesystem path to bag.

        Raises:
            WriterError: Target path exists already, Writer can only create new rosbags.

        """
        path = Path(path)
        self.path = path
        if path.exists():
            raise WriterError(f'{path} exists already, not overwriting.')
        self.bio: Optional[BinaryIO] = None
        # Identity "compressor" until set_compression() selects a real one.
        self.compressor: Callable[[bytes], bytes] = lambda x: x
        self.compression_format = 'none'
        self.connections: dict[int, Connection] = {}
        # The last entry is always the chunk currently being filled.
        self.chunks: list[WriteChunk] = [
            WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)),
        ]
        # Flush the current chunk once its buffer grows beyond 1 MiB.
        self.chunk_threshold = 1 * (1 << 20)

    def set_compression(self, fmt: CompressionFormat) -> None:
        """Enable compression on rosbag1.

        This function has to be called before opening.

        Args:
            fmt: Compressor to use, bz2 or lz4

        Raises:
            WriterError: Bag already open.

        """
        if self.bio:
            raise WriterError(f'Cannot set compression, bag {self.path} already open.')

        self.compression_format = fmt.name.lower()

        bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
        lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16)  # type: ignore
        self.compressor = {
            'bz2': bz2,
            'lz4': lz4,
        }[self.compression_format]

    def _write_bagheader(self, index_pos: int, conn_count: int, chunk_count: int) -> None:
        """Write the bag header record, padded to 4096 bytes, at the current file position."""
        assert self.bio
        header = Header()
        header.set_uint64('index_pos', index_pos)
        header.set_uint32('conn_count', conn_count)
        header.set_uint32('chunk_count', chunk_count)
        size = header.write(self.bio, RecordType.BAGHEADER)
        # Header bytes + 4 byte data length + padding add up to exactly 4096,
        # so close() can rewrite the record in place.
        padsize = 4096 - 4 - size
        self.bio.write(serialize_uint32(padsize) + b' ' * padsize)

    def open(self) -> None:
        """Open rosbag1 for writing.

        Writes the version line and a placeholder bag header that close()
        later overwrites with the final values.

        Raises:
            WriterError: Target path exists already.

        """
        try:
            # Exclusive creation: fails instead of overwriting an existing file.
            self.bio = self.path.open('xb')
        except FileExistsError:
            raise WriterError(f'{self.path} exists already, not overwriting.') from None

        assert self.bio
        self.bio.write(b'#ROSBAG V2.0\n')
        self._write_bagheader(0, 0, 0)

    def add_connection(  # pylint: disable=too-many-arguments
        self,
        topic: str,
        msgtype: str,
        msgdef: Optional[str] = None,
        md5sum: Optional[str] = None,
        callerid: Optional[str] = None,
        latching: Optional[int] = None,
        **_kw: Any,  # noqa: ANN401
    ) -> Connection:
        """Add a connection.

        This function can only be called after opening a bag.

        Args:
            topic: Topic name.
            msgtype: Message type.
            msgdef: Message definition.
            md5sum: Message hash.
            callerid: Caller id.
            latching: Latching information.
            _kw: Ignored to allow consuming dicts from connection objects.

        Returns:
            The newly registered connection object.

        Raises:
            WriterError: Bag not open or identical topic previously registered.

        """
        if not self.bio:
            raise WriterError('Bag was not opened.')

        # Both values are regenerated together when either one is missing.
        if msgdef is None or md5sum is None:
            msgdef, md5sum = generate_msgdef(msgtype)
        assert msgdef
        assert md5sum

        connection = Connection(
            len(self.connections),
            topic,
            denormalize_msgtype(msgtype),
            msgdef,
            md5sum,
            -1,
            ConnectionExtRosbag1(
                callerid,
                latching,
            ),
            self,
        )

        # Compare everything except the id (first element) to detect duplicates.
        if any(x[1:] == connection[1:] for x in self.connections.values()):
            raise WriterError(
                f'Connections can only be added once with same arguments: {connection!r}.',
            )

        # The connection record goes into the chunk that is currently open.
        bio = self.chunks[-1].data
        self.write_connection(connection, bio)

        self.connections[connection.id] = connection
        return connection

    def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
        """Write message to rosbag1.

        Args:
            connection: Connection to write message to.
            timestamp: Message timestamp (ns).
            data: Serialized message data.

        Raises:
            WriterError: Bag not open or connection not registered.

        """
        if not self.bio:
            raise WriterError('Bag was not opened.')

        if connection not in self.connections.values():
            raise WriterError(f'There is no connection {connection!r}.') from None

        chunk = self.chunks[-1]
        # Record the index entry before writing: the offset is where this
        # message record starts inside the chunk buffer.
        chunk.connections[connection.id].append((timestamp, chunk.data.tell()))

        if timestamp < chunk.start:
            chunk.start = timestamp

        if timestamp > chunk.end:
            chunk.end = timestamp

        header = Header()
        header.set_uint32('conn', connection.id)
        header.set_time('time', timestamp)

        header.write(chunk.data, RecordType.MSGDATA)
        chunk.data.write(serialize_uint32(len(data)))
        chunk.data.write(data)
        if chunk.data.tell() > self.chunk_threshold:
            self.write_chunk(chunk)

    @staticmethod
    def write_connection(connection: Connection, bio: BinaryIO) -> None:
        """Write connection record.

        Args:
            connection: Connection to serialize.
            bio: File handle receiving the record header and its data header.

        """
        header = Header()
        header.set_uint32('conn', connection.id)
        header.set_string('topic', connection.topic)
        header.write(bio, RecordType.CONNECTION)

        # Second header, written without an opcode, carries the connection details.
        header = Header()
        header.set_string('topic', connection.topic)
        header.set_string('type', connection.msgtype)
        header.set_string('md5sum', connection.md5sum)
        header.set_string('message_definition', connection.msgdef)
        assert isinstance(connection.ext, ConnectionExtRosbag1)
        if connection.ext.callerid is not None:
            header.set_string('callerid', connection.ext.callerid)
        if connection.ext.latching is not None:
            header.set_string('latching', str(connection.ext.latching))
        header.write(bio)

    def write_chunk(self, chunk: WriteChunk) -> None:
        """Write open chunk to file.

        Empty chunks are left untouched. Otherwise the chunk record and one
        index record per connection are written, the chunk buffer is closed
        and a fresh empty chunk is appended to ``self.chunks``.

        Args:
            chunk: Chunk to flush.

        """
        assert self.bio

        # Take the size first and test it separately. The former
        # ``if size := chunk.data.tell() > 0`` bound ``size`` to the boolean
        # comparison result, so every chunk header was written with size=1.
        size = chunk.data.tell()
        if size <= 0:
            return

        chunk.pos = self.bio.tell()

        header = Header()
        header.set_string('compression', self.compression_format)
        # Size of the chunk buffer before compression.
        header.set_uint32('size', size)
        header.write(self.bio, RecordType.CHUNK)
        data = self.compressor(chunk.data.getvalue())
        self.bio.write(serialize_uint32(len(data)))
        self.bio.write(data)

        for cid, items in chunk.connections.items():
            header = Header()
            header.set_uint32('ver', 1)
            header.set_uint32('conn', cid)
            header.set_uint32('count', len(items))
            header.write(self.bio, RecordType.IDXDATA)
            # Each index entry is 8 bytes of time plus a 4 byte offset.
            self.bio.write(serialize_uint32(len(items) * 12))
            for timestamp, offset in items:
                self.bio.write(serialize_time(timestamp) + serialize_uint32(offset))

        chunk.data.close()
        self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))

    def close(self) -> None:
        """Close rosbag1 after writing.

        Closes open chunks and writes index.

        """
        assert self.bio

        # Iterate over a snapshot: write_chunk() appends a new chunk to
        # self.chunks for every chunk it flushes.
        for chunk in list(self.chunks):
            if chunk.pos == -1:
                self.write_chunk(chunk)

        index_pos = self.bio.tell()

        for connection in self.connections.values():
            self.write_connection(connection, self.bio)

        # Only chunks that were actually written to the file get an info record.
        written = [chunk for chunk in self.chunks if chunk.pos != -1]
        for chunk in written:
            header = Header()
            header.set_uint32('ver', 1)
            header.set_uint64('chunk_pos', chunk.pos)
            header.set_time('start_time', 0 if chunk.start == 2**64 else chunk.start)
            header.set_time('end_time', chunk.end)
            header.set_uint32('count', len(chunk.connections))
            header.write(self.bio, RecordType.CHUNK_INFO)
            # One (connection id, message count) uint32 pair per connection.
            self.bio.write(serialize_uint32(len(chunk.connections) * 8))
            for cid, items in chunk.connections.items():
                self.bio.write(serialize_uint32(cid) + serialize_uint32(len(items)))

        # Jump back behind the 13 byte version line and overwrite the
        # placeholder bag header written by open().
        self.bio.seek(13)
        self._write_bagheader(index_pos, len(self.connections), len(written))

        self.bio.close()

    def __enter__(self) -> Writer:
        """Open rosbag1 when entering contextmanager."""
        self.open()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """Close rosbag1 when exiting contextmanager."""
        self.close()
        return False
|