Add rosbag1 support
This commit is contained in:
parent
ffacb7602c
commit
4de0c99274
6
docs/api/rosbags.rosbag1.rst
Normal file
6
docs/api/rosbags.rosbag1.rst
Normal file
@ -0,0 +1,6 @@
|
||||
rosbags.rosbag1
|
||||
===============
|
||||
|
||||
.. automodule:: rosbags.rosbag1
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@ -4,6 +4,7 @@ Rosbags namespace
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
rosbags.rosbag1
|
||||
rosbags.rosbag2
|
||||
rosbags.serde
|
||||
rosbags.typesys
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
topics/typesys
|
||||
topics/serde
|
||||
topics/rosbag2
|
||||
topics/rosbag1
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
||||
27
docs/topics/rosbag1.rst
Normal file
27
docs/topics/rosbag1.rst
Normal file
@ -0,0 +1,27 @@
|
||||
Rosbag1
|
||||
=======
|
||||
|
||||
The :py:mod:`rosbags.rosbag1` package provides fast read-only access to raw messages stored in the legacy bag format. The rosbag1 support is built for a ROS2 world and some APIs and values perform normalizations to mimic ROS2 behavior and make messages originating from rosbag1 and rosbag2 behave identically. Most notably message types are internally renamed to match their ROS2 counterparts.
|
||||
|
||||
Reading rosbag1
|
||||
---------------
|
||||
Instances of the :py:class:`Reader <rosbags.rosbag2.Reader>` class are typically used as context managers and provide access to bag metadata and contents after the bag has been opened. The following example shows the typical usage pattern:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from rosbags.rosbag1 import Reader
|
||||
|
||||
# create reader instance
|
||||
with Reader('/home/ros/rosbag_2020_03_24.bag') as reader:
|
||||
# topic and msgtype information is available on .topics dictionary
|
||||
for topic, info in reader.topics.items():
|
||||
print(topic, info)
|
||||
|
||||
# iterate over messages
|
||||
for topic, msgtype, rawdata, timestamp in reader.messages():
|
||||
if topic == '/imu_raw/Imu':
|
||||
print(timestamp)
|
||||
|
||||
# messages() accepts topic filters
|
||||
for topic, msgtype, rawdata, timestamp in reader.messages(['/imu_raw/Imu']):
|
||||
print(timestamp)
|
||||
18
src/rosbags/rosbag1/__init__.py
Normal file
18
src/rosbags/rosbag1/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbags support for rosbag1 files.
|
||||
|
||||
Reader provides access to metadata and raw message content saved in the
|
||||
rosbag1 format.
|
||||
|
||||
Supported versions:
|
||||
- Rosbag1 v2.0
|
||||
|
||||
"""
|
||||
|
||||
from .reader import Reader, ReaderError
|
||||
|
||||
__all__ = [
|
||||
'Reader',
|
||||
'ReaderError',
|
||||
]
|
||||
645
src/rosbags/rosbag1/reader.py
Normal file
645
src/rosbags/rosbag1/reader.py
Normal file
@ -0,0 +1,645 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag1 v2.0 reader."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import heapq
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
from bz2 import decompress as bz2_decompress
|
||||
from enum import Enum, IntEnum
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
from lz4.frame import decompress as lz4_decompress # type: ignore
|
||||
|
||||
from rosbags.typesys.msg import normalize_msgtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
BinaryIO,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
"""Reader Error."""
|
||||
|
||||
|
||||
class Compression(Enum):
|
||||
"""Compression mode."""
|
||||
|
||||
NONE = 'none'
|
||||
BZ2 = 'bz2'
|
||||
LZ4 = 'lz4'
|
||||
|
||||
|
||||
class RecordType(IntEnum):
|
||||
"""Record type."""
|
||||
|
||||
MSGDATA = 2
|
||||
BAGHEADER = 3
|
||||
IDXDATA = 4
|
||||
CHUNK = 5
|
||||
CHUNK_INFO = 6
|
||||
CONNECTION = 7
|
||||
|
||||
|
||||
class Connection(NamedTuple):
|
||||
"""Connection information."""
|
||||
|
||||
cid: int
|
||||
topic: str
|
||||
msgtype: str
|
||||
md5sum: str
|
||||
msgdef: str
|
||||
indexes: List
|
||||
|
||||
|
||||
class ChunkInfo(NamedTuple):
|
||||
"""Chunk information."""
|
||||
|
||||
pos: int
|
||||
start_time: int
|
||||
end_time: int
|
||||
connection_counts: Dict[int, int]
|
||||
|
||||
|
||||
class Chunk(NamedTuple):
|
||||
"""Chunk metadata."""
|
||||
|
||||
datasize: int
|
||||
datapos: int
|
||||
decompressor: Callable
|
||||
|
||||
|
||||
class TopicInfo(NamedTuple):
|
||||
"""Topic information."""
|
||||
|
||||
conn_count: int
|
||||
msgcount: int
|
||||
msgdef: str
|
||||
msgtype: str
|
||||
|
||||
|
||||
class IndexData(NamedTuple):
|
||||
"""Index data."""
|
||||
|
||||
time: int
|
||||
chunk_pos: int
|
||||
offset: int
|
||||
|
||||
def __lt__(self, other: Tuple[int, ...]) -> bool:
|
||||
"""Compare by time only."""
|
||||
return self.time < other[0]
|
||||
|
||||
def __le__(self, other: Tuple[int, ...]) -> bool:
|
||||
"""Compare by time only."""
|
||||
return self.time <= other[0]
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare by time only."""
|
||||
if not isinstance(other, IndexData): # pragma: no cover
|
||||
return NotImplemented
|
||||
return self.time == other[0]
|
||||
|
||||
def __ge__(self, other: Tuple[int, ...]) -> bool:
|
||||
"""Compare by time only."""
|
||||
return self.time >= other[0]
|
||||
|
||||
def __gt__(self, other: Tuple[int, ...]) -> bool:
|
||||
"""Compare by time only."""
|
||||
return self.time > other[0]
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
"""Compare by time only."""
|
||||
if not isinstance(other, IndexData): # pragma: no cover
|
||||
return NotImplemented
|
||||
return self.time != other[0]
|
||||
|
||||
|
||||
deserialize_uint8 = struct.Struct('<B').unpack
|
||||
deserialize_uint32 = struct.Struct('<L').unpack
|
||||
deserialize_uint64 = struct.Struct('<Q').unpack
|
||||
|
||||
|
||||
def deserialize_time(val: bytes) -> int:
|
||||
"""Deserialize time value.
|
||||
|
||||
Args:
|
||||
val: Serialized bytes.
|
||||
|
||||
Returns:
|
||||
Deserialized value.
|
||||
|
||||
"""
|
||||
sec, nsec = struct.unpack('<LL', val)
|
||||
return sec * 10**9 + nsec
|
||||
|
||||
|
||||
class Header(dict):
|
||||
"""Record header."""
|
||||
|
||||
def get_uint8(self, name: str) -> int:
|
||||
"""Get uint8 value from field.
|
||||
|
||||
Args:
|
||||
name: Name of field.
|
||||
|
||||
Returns:
|
||||
Deserialized value.
|
||||
|
||||
Raises:
|
||||
ReaderError: Field not present or not deserializable.
|
||||
|
||||
"""
|
||||
try:
|
||||
return deserialize_uint8(self[name])[0]
|
||||
except (KeyError, struct.error) as err:
|
||||
raise ReaderError(f'Could not read uint8 field {name!r}.') from err
|
||||
|
||||
def get_uint32(self, name: str) -> int:
|
||||
"""Get uint32 value from field.
|
||||
|
||||
Args:
|
||||
name: Name of field.
|
||||
|
||||
Returns:
|
||||
Deserialized value.
|
||||
|
||||
Raises:
|
||||
ReaderError: Field not present or not deserializable.
|
||||
|
||||
"""
|
||||
try:
|
||||
return deserialize_uint32(self[name])[0]
|
||||
except (KeyError, struct.error) as err:
|
||||
raise ReaderError(f'Could not read uint32 field {name!r}.') from err
|
||||
|
||||
def get_uint64(self, name: str) -> int:
|
||||
"""Get uint64 value from field.
|
||||
|
||||
Args:
|
||||
name: Name of field.
|
||||
|
||||
Returns:
|
||||
Deserialized value.
|
||||
|
||||
Raises:
|
||||
ReaderError: Field not present or not deserializable.
|
||||
|
||||
"""
|
||||
try:
|
||||
return deserialize_uint64(self[name])[0]
|
||||
except (KeyError, struct.error) as err:
|
||||
raise ReaderError(f'Could not read uint64 field {name!r}.') from err
|
||||
|
||||
def get_string(self, name: str) -> str:
|
||||
"""Get string value from field.
|
||||
|
||||
Args:
|
||||
name: Name of field.
|
||||
|
||||
Returns:
|
||||
Deserialized value.
|
||||
|
||||
Raises:
|
||||
ReaderError: Field not present or not deserializable.
|
||||
|
||||
"""
|
||||
try:
|
||||
return self[name].decode()
|
||||
except (KeyError, ValueError) as err:
|
||||
raise ReaderError(f'Could not read string field {name!r}.') from err
|
||||
|
||||
def get_time(self, name: str) -> int:
|
||||
"""Get time value from field.
|
||||
|
||||
Args:
|
||||
name: Name of field.
|
||||
|
||||
Returns:
|
||||
Deserialized value.
|
||||
|
||||
Raises:
|
||||
ReaderError: Field not present or not deserializable.
|
||||
|
||||
"""
|
||||
try:
|
||||
return deserialize_time(self[name])
|
||||
except (KeyError, struct.error) as err:
|
||||
raise ReaderError(f'Could not read time field {name!r}.') from err
|
||||
|
||||
@classmethod
|
||||
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> 'Header':
|
||||
"""Read header from file handle.
|
||||
|
||||
Args:
|
||||
src: File handle.
|
||||
expect: Expected record op.
|
||||
|
||||
Returns:
|
||||
Header object.
|
||||
|
||||
Raises:
|
||||
ReaderError: Header could not parsed.
|
||||
|
||||
"""
|
||||
try:
|
||||
binary = read_bytes(src, read_uint32(src))
|
||||
except ReaderError as err:
|
||||
raise ReaderError('Header could not be read from file.') from err
|
||||
|
||||
header = cls()
|
||||
pos = 0
|
||||
length = len(binary)
|
||||
while pos < length:
|
||||
try:
|
||||
size = deserialize_uint32(binary[pos:pos + 4])[0]
|
||||
except struct.error as err:
|
||||
raise ReaderError('Header field size could not be read.') from err
|
||||
pos += 4
|
||||
|
||||
if pos + size > length:
|
||||
raise ReaderError('Declared field size is too large for header.')
|
||||
|
||||
name, sep, value = binary[pos:pos + size].partition(b'=')
|
||||
if not sep:
|
||||
raise ReaderError('Header field could not be parsed.')
|
||||
pos += size
|
||||
|
||||
header[name.decode()] = value
|
||||
|
||||
if expect:
|
||||
have = header.get_uint8('op')
|
||||
if expect != have:
|
||||
raise ReaderError(f'Record of type {RecordType(have).name!r} is unexpected.')
|
||||
|
||||
return header
|
||||
|
||||
|
||||
def read_uint32(src: BinaryIO) -> int:
|
||||
"""Read uint32 from source.
|
||||
|
||||
Args:
|
||||
src: File handle.
|
||||
|
||||
Returns:
|
||||
Uint32 value.
|
||||
|
||||
Raises:
|
||||
ReaderError: Value unreadable or not deserializable.
|
||||
|
||||
"""
|
||||
try:
|
||||
return deserialize_uint32(src.read(4))[0]
|
||||
except struct.error as err:
|
||||
raise ReaderError('Could not read uint32.') from err
|
||||
|
||||
|
||||
def read_bytes(src: BinaryIO, size: int) -> bytes:
|
||||
"""Read bytes from source.
|
||||
|
||||
Args:
|
||||
src: File handle.
|
||||
size: Number of bytes to read.
|
||||
|
||||
Returns:
|
||||
Read bytes.
|
||||
|
||||
Raises:
|
||||
ReaderError: Not enough bytes available.
|
||||
|
||||
"""
|
||||
data = src.read(size)
|
||||
if len(data) != size:
|
||||
raise ReaderError(f'Got only {len(data)} of requested {size} bytes.')
|
||||
return data
|
||||
|
||||
|
||||
def normalize(name: str) -> str:
|
||||
"""Normalize topic name.
|
||||
|
||||
Args:
|
||||
name: Topic name.
|
||||
|
||||
Returns:
|
||||
Normalized name.
|
||||
|
||||
"""
|
||||
return f'{"/" * (name[0] == "/")}{"/".join(x for x in name.split("/") if x)}'
|
||||
|
||||
|
||||
class Reader:
|
||||
"""Rosbag 1 version 2.0 reader.
|
||||
|
||||
This class is designed for a ROS2 world, it will automatically normalize
|
||||
message type names to be in line with their ROS2 counterparts.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[str, Path]):
|
||||
"""Initialize.
|
||||
|
||||
Args:
|
||||
path: Filesystem path to bag.
|
||||
|
||||
Raises:
|
||||
ReaderError: Path does not exist.
|
||||
|
||||
"""
|
||||
self.path = Path(path)
|
||||
if not self.path.exists():
|
||||
raise ReaderError(f'File {str(self.path)!r} does not exist.')
|
||||
|
||||
self.bio: Optional[BinaryIO] = None
|
||||
self.connections: Dict[int, Connection] = {}
|
||||
self.chunk_infos: List[ChunkInfo] = []
|
||||
self.chunks: Dict[int, Chunk] = {}
|
||||
self.current_chunk = (-1, BytesIO())
|
||||
self.topics: Dict[str, TopicInfo] = {}
|
||||
|
||||
def open(self): # pylint: disable=too-many-branches,too-many-locals
|
||||
"""Open rosbag and read metadata."""
|
||||
try:
|
||||
self.bio = self.path.open('rb')
|
||||
except OSError as err:
|
||||
raise ReaderError(f'Could not open file {str(self.path)!r}: {err.strerror}.') from err
|
||||
|
||||
try:
|
||||
magic = self.bio.readline().decode()
|
||||
if not magic:
|
||||
raise ReaderError(f'File {str(self.path)!r} seems to be empty.')
|
||||
|
||||
matches = re.match(r'#ROSBAG V(\d+).(\d+)\n', magic)
|
||||
if not matches:
|
||||
raise ReaderError('File magic is invalid.')
|
||||
major, minor = matches.groups()
|
||||
version = int(major) * 100 + int(minor)
|
||||
if version != 200:
|
||||
raise ReaderError(f'Bag version {version!r} is not supported.')
|
||||
|
||||
header = Header.read(self.bio, RecordType.BAGHEADER)
|
||||
index_pos = header.get_uint64('index_pos')
|
||||
conn_count = header.get_uint32('conn_count')
|
||||
chunk_count = header.get_uint32('chunk_count')
|
||||
try:
|
||||
encryptor = header.get_string('encryptor')
|
||||
if encryptor:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
|
||||
except ReaderError:
|
||||
pass
|
||||
|
||||
if index_pos == 0:
|
||||
raise ReaderError('Bag is not indexed, reindex before reading.')
|
||||
|
||||
self.bio.seek(index_pos)
|
||||
self.connections = dict(self.read_connection() for _ in range(conn_count))
|
||||
self.chunk_infos = [self.read_chunk_info() for _ in range(chunk_count)]
|
||||
self.chunks = {}
|
||||
for chunk_info in self.chunk_infos:
|
||||
self.bio.seek(chunk_info.pos)
|
||||
self.chunks[chunk_info.pos] = self.read_chunk()
|
||||
|
||||
for _ in range(len(chunk_info.connection_counts)):
|
||||
cid, index = self.read_index_data(chunk_info.pos)
|
||||
self.connections[cid].indexes.append(index)
|
||||
|
||||
for connection in self.connections.values():
|
||||
connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
|
||||
assert connection.indexes
|
||||
|
||||
self.topics = {}
|
||||
for topic, connections in groupby(
|
||||
sorted(self.connections.values(), key=lambda x: x.topic),
|
||||
key=lambda x: x.topic,
|
||||
):
|
||||
connections = list(connections)
|
||||
count = reduce(
|
||||
lambda x, y: x + y,
|
||||
(
|
||||
y.connection_counts.get(x.cid, 0)
|
||||
for x in connections
|
||||
for y in self.chunk_infos
|
||||
),
|
||||
)
|
||||
|
||||
self.topics[topic] = TopicInfo(
|
||||
len(connections),
|
||||
count,
|
||||
connections[0].msgdef,
|
||||
connections[0].msgtype,
|
||||
)
|
||||
except ReaderError:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Close rosbag."""
|
||||
assert self.bio
|
||||
self.bio.close()
|
||||
self.bio = None
|
||||
|
||||
@property
|
||||
def duration(self) -> int:
|
||||
"""Duration in nanoseconds between earliest and latest messages."""
|
||||
return self.end_time - self.start_time
|
||||
|
||||
@property
|
||||
def start_time(self) -> int:
|
||||
"""Timestamp in nanoseconds of the earliest message."""
|
||||
return min(x.start_time for x in self.chunk_infos)
|
||||
|
||||
@property
|
||||
def end_time(self) -> int:
|
||||
"""Timestamp in nanoseconds of the latest message."""
|
||||
return max(x.end_time for x in self.chunk_infos)
|
||||
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""Total message count."""
|
||||
return reduce(lambda x, y: x + y, (x.msgcount for x in self.topics.values()), 0)
|
||||
|
||||
def read_connection(self) -> Tuple[int, Connection]:
|
||||
"""Read connection record from current position."""
|
||||
assert self.bio
|
||||
header = Header.read(self.bio, RecordType.CONNECTION)
|
||||
conn = header.get_uint32('conn')
|
||||
topic = normalize(header.get_string('topic'))
|
||||
|
||||
header = Header.read(self.bio)
|
||||
typ = header.get_string('type')
|
||||
md5sum = header.get_string('md5sum')
|
||||
msgdef = header.get_string('message_definition')
|
||||
|
||||
return conn, Connection(conn, topic, normalize_msgtype(typ), md5sum, msgdef, [])
|
||||
|
||||
def read_chunk_info(self) -> ChunkInfo:
|
||||
"""Read chunk info record from current position."""
|
||||
assert self.bio
|
||||
header = Header.read(self.bio, RecordType.CHUNK_INFO)
|
||||
|
||||
ver = header.get_uint32('ver')
|
||||
if ver != 1:
|
||||
raise ReaderError(f'CHUNK_INFO version {ver} is not supported.')
|
||||
|
||||
chunk_pos = header.get_uint64('chunk_pos')
|
||||
start_time = header.get_time('start_time')
|
||||
end_time = header.get_time('end_time')
|
||||
count = header.get_uint32('count')
|
||||
|
||||
self.bio.seek(4, os.SEEK_CUR)
|
||||
|
||||
return ChunkInfo(
|
||||
chunk_pos,
|
||||
start_time,
|
||||
end_time,
|
||||
{read_uint32(self.bio): read_uint32(self.bio) for _ in range(count)},
|
||||
)
|
||||
|
||||
def read_chunk(self) -> Chunk:
|
||||
"""Read chunk record header from current position."""
|
||||
assert self.bio
|
||||
header = Header.read(self.bio, RecordType.CHUNK)
|
||||
compression = header.get_string('compression')
|
||||
datasize = read_uint32(self.bio)
|
||||
datapos = self.bio.tell()
|
||||
self.bio.seek(datasize, os.SEEK_CUR)
|
||||
try:
|
||||
decompressor = {
|
||||
Compression.NONE.value: lambda x: x,
|
||||
Compression.BZ2.value: bz2_decompress,
|
||||
Compression.LZ4.value: lz4_decompress,
|
||||
}[compression]
|
||||
except KeyError:
|
||||
raise ReaderError(f'Compression {compression!r} is not supported.') from None
|
||||
|
||||
return Chunk(
|
||||
datasize,
|
||||
datapos,
|
||||
decompressor,
|
||||
)
|
||||
|
||||
def read_index_data(self, pos: int) -> Tuple[int, List[IndexData]]:
|
||||
"""Read index data from position.
|
||||
|
||||
Args:
|
||||
pos: Seek position.
|
||||
|
||||
Returns:
|
||||
Connection id and list of index data.
|
||||
|
||||
Raises:
|
||||
ReaderError: Record unreadable.
|
||||
|
||||
"""
|
||||
assert self.bio
|
||||
header = Header.read(self.bio, RecordType.IDXDATA)
|
||||
|
||||
ver = header.get_uint32('ver')
|
||||
if ver != 1:
|
||||
raise ReaderError(f'IDXDATA version {ver} is not supported.')
|
||||
conn = header.get_uint32('conn')
|
||||
count = header.get_uint32('count')
|
||||
|
||||
self.bio.seek(4, os.SEEK_CUR)
|
||||
|
||||
index: List[IndexData] = []
|
||||
for _ in range(count):
|
||||
time = deserialize_time(self.bio.read(8))
|
||||
offset = read_uint32(self.bio)
|
||||
bisect.insort(index, IndexData(time, pos, offset))
|
||||
return conn, index
|
||||
|
||||
def messages(
|
||||
self,
|
||||
topics: Optional[Iterable[str]] = None,
|
||||
start: Optional[int] = None,
|
||||
stop: Optional[int] = None,
|
||||
) -> Generator[Tuple[str, str, int, bytes], None, None]:
|
||||
"""Read messages from bag.
|
||||
|
||||
Args:
|
||||
topics: Iterable with topic names to filter for. An empty iterable
|
||||
yields all messages.
|
||||
start: Yield only messages at or after this timestamp (ns).
|
||||
stop: Yield only messages before this timestamp (ns).
|
||||
|
||||
Yields:
|
||||
Tuples of topic name, type, timestamp (ns), and rawdata.
|
||||
|
||||
Raises:
|
||||
ReaderError: Bag not open or data corrupt.
|
||||
|
||||
"""
|
||||
if not self.bio:
|
||||
raise ReaderError('Rosbag is not open.')
|
||||
|
||||
indexes = [x.indexes for x in self.connections.values() if not topics or x.topic in topics]
|
||||
for entry in heapq.merge(*indexes):
|
||||
if start and entry.time < start:
|
||||
continue
|
||||
if stop and entry.time >= stop:
|
||||
return
|
||||
|
||||
if self.current_chunk[0] != entry.chunk_pos:
|
||||
self.current_chunk[1].close()
|
||||
|
||||
chunk_header = self.chunks[entry.chunk_pos]
|
||||
self.bio.seek(chunk_header.datapos)
|
||||
chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
||||
self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
|
||||
|
||||
chunk = self.current_chunk[1]
|
||||
chunk.seek(entry.offset)
|
||||
|
||||
while True:
|
||||
header = Header.read(chunk)
|
||||
have = header.get_uint8('op')
|
||||
if have != RecordType.CONNECTION:
|
||||
break
|
||||
chunk.seek(read_uint32(chunk), os.SEEK_CUR)
|
||||
|
||||
if have != RecordType.MSGDATA:
|
||||
raise ReaderError('Expected to find message data.')
|
||||
|
||||
connection = self.connections[header.get_uint32('conn')]
|
||||
time = header.get_time('time')
|
||||
|
||||
data = read_bytes(chunk, read_uint32(chunk))
|
||||
|
||||
assert entry.time == time
|
||||
yield connection.topic, connection.msgtype, time, data
|
||||
|
||||
def __enter__(self) -> Reader:
|
||||
"""Open rosbag1 when entering contextmanager."""
|
||||
self.open()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Close rosbag1 when exiting contextmanager."""
|
||||
self.close()
|
||||
return False
|
||||
394
tests/test_reader1.py
Normal file
394
tests/test_reader1.py
Normal file
@ -0,0 +1,394 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reader tests."""
|
||||
|
||||
from collections import defaultdict
|
||||
from struct import pack
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from rosbags.rosbag1 import Reader, ReaderError
|
||||
from rosbags.rosbag1.reader import IndexData
|
||||
|
||||
|
||||
def ser(data):
|
||||
"""Serialize record header."""
|
||||
if isinstance(data, dict):
|
||||
fields = []
|
||||
for key, value in data.items():
|
||||
field = b'='.join([key.encode(), value])
|
||||
fields.append(pack('<L', len(field)) + field)
|
||||
data = b''.join(fields)
|
||||
return pack('<L', len(data)) + data
|
||||
|
||||
|
||||
def create_default_header():
|
||||
"""Create empty rosbag header."""
|
||||
return {
|
||||
'op': b'\x03',
|
||||
'conn_count': pack('<L', 0),
|
||||
'chunk_count': pack('<L', 0),
|
||||
}
|
||||
|
||||
|
||||
def create_connection(cid=1, topic=0, typ=0):
|
||||
"""Create connection record."""
|
||||
return {
|
||||
'op': b'\x07',
|
||||
'conn': pack('<L', cid),
|
||||
'topic': f'/topic{topic}'.encode(),
|
||||
}, {
|
||||
'type': f'foo_msgs/msg/Foo{typ}'.encode(),
|
||||
'md5sum': b'AAAA',
|
||||
'message_definition': b'MSGDEF',
|
||||
}
|
||||
|
||||
|
||||
def create_message(cid=1, time=0, msg=0):
|
||||
"""Create message record."""
|
||||
return {
|
||||
'op': b'\x02',
|
||||
'conn': cid,
|
||||
'time': time,
|
||||
}, f'MSGCONTENT{msg}'.encode()
|
||||
|
||||
|
||||
def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements
|
||||
"""Write bag file."""
|
||||
magic = b'#ROSBAG V2.0\n'
|
||||
|
||||
pos = 13 + 4096
|
||||
conn_count = 0
|
||||
chunk_count = len(chunks or [])
|
||||
|
||||
chunks_bytes = b''
|
||||
connections = b''
|
||||
chunkinfos = b''
|
||||
if chunks:
|
||||
for chunk in chunks:
|
||||
chunk_bytes = b''
|
||||
start_time = 2**32 - 1
|
||||
end_time = 0
|
||||
counts = defaultdict(int)
|
||||
index = {}
|
||||
offset = 0
|
||||
|
||||
for head, data in chunk:
|
||||
if head.get('op') == b'\x07':
|
||||
conn_count += 1
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
connections += add
|
||||
elif head.get('op') == b'\x02':
|
||||
time = head['time']
|
||||
head['time'] = pack('<LL', head['time'], 0)
|
||||
conn = head['conn']
|
||||
head['conn'] = pack('<L', head['conn'])
|
||||
|
||||
start_time = min([start_time, time])
|
||||
end_time = max([end_time, time])
|
||||
|
||||
counts[conn] += 1
|
||||
if conn not in index:
|
||||
index[conn] = {
|
||||
'count': 0,
|
||||
'msgs': b'',
|
||||
}
|
||||
index[conn]['count'] += 1
|
||||
index[conn]['msgs'] += pack('<LLL', time, 0, offset)
|
||||
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
offset = len(chunk_bytes)
|
||||
else:
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
|
||||
chunk_bytes = ser(
|
||||
{
|
||||
'op': b'\x05',
|
||||
'compression': b'none',
|
||||
'size': pack('<L', len(chunk_bytes))
|
||||
}
|
||||
) + ser(chunk_bytes)
|
||||
for conn, data in index.items():
|
||||
chunk_bytes += ser(
|
||||
{
|
||||
'op': b'\x04',
|
||||
'ver': pack('<L', 1),
|
||||
'conn': pack('<L', conn),
|
||||
'count': pack('<L', data['count']),
|
||||
}
|
||||
) + ser(data['msgs'])
|
||||
|
||||
chunks_bytes += chunk_bytes
|
||||
chunkinfos += ser(
|
||||
{
|
||||
'op': b'\x06',
|
||||
'ver': pack('<L', 1),
|
||||
'chunk_pos': pack('<Q', pos),
|
||||
'start_time': pack('<LL', start_time, 0),
|
||||
'end_time': pack('<LL', end_time, 0),
|
||||
'count': pack('<L', len(counts.keys())),
|
||||
}
|
||||
) + ser(b''.join([pack('<LL', x, y) for x, y in counts.items()]))
|
||||
pos += len(chunk_bytes)
|
||||
|
||||
header['conn_count'] = pack('<L', conn_count)
|
||||
header['chunk_count'] = pack('<L', chunk_count)
|
||||
if 'index_pos' not in header:
|
||||
header['index_pos'] = pack('<Q', pos)
|
||||
|
||||
header = ser(header)
|
||||
header += b'\x00' * (4096 - len(header))
|
||||
|
||||
bag.write_bytes(b''.join([
|
||||
magic,
|
||||
header,
|
||||
chunks_bytes,
|
||||
connections,
|
||||
chunkinfos,
|
||||
]))
|
||||
|
||||
|
||||
def test_indexdata():
|
||||
"""Test IndexData sort sorder."""
|
||||
x42_1_0 = IndexData(42, 1, 0)
|
||||
x42_2_0 = IndexData(42, 2, 0)
|
||||
x43_3_0 = IndexData(43, 3, 0)
|
||||
|
||||
# flake8: noqa
|
||||
# pylint: disable=unneeded-not
|
||||
assert not x42_1_0 < x42_2_0
|
||||
assert x42_1_0 <= x42_2_0
|
||||
assert x42_1_0 == x42_2_0
|
||||
assert not x42_1_0 != x42_2_0
|
||||
assert x42_1_0 >= x42_2_0
|
||||
assert not x42_1_0 > x42_2_0
|
||||
|
||||
assert x42_1_0 < x43_3_0
|
||||
assert x42_1_0 <= x43_3_0
|
||||
assert not x42_1_0 == x43_3_0
|
||||
assert x42_1_0 != x43_3_0
|
||||
assert not x42_1_0 >= x43_3_0
|
||||
assert not x42_1_0 > x43_3_0
|
||||
|
||||
|
||||
def test_reader(tmp_path): # pylint: disable=too-many-statements
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
# empty bag
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, create_default_header())
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 0
|
||||
|
||||
# empty bag, explicit encryptor
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, {**create_default_header(), 'encryptor': b''})
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 0
|
||||
|
||||
# single message
|
||||
write_bag(
|
||||
bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(time=42),
|
||||
]]
|
||||
)
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 1
|
||||
assert reader.duration == 0
|
||||
assert reader.start_time == 42 * 10**9
|
||||
assert reader.end_time == 42 * 10**9
|
||||
assert len(reader.topics.keys()) == 1
|
||||
assert reader.topics['/topic0'].msgcount == 1
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 1
|
||||
|
||||
# sorts by time on same topic
|
||||
write_bag(
|
||||
bag,
|
||||
create_default_header(),
|
||||
chunks=[
|
||||
[
|
||||
create_connection(),
|
||||
create_message(time=10, msg=10),
|
||||
create_message(time=5, msg=5),
|
||||
]
|
||||
]
|
||||
)
|
||||
with Reader(bag) as reader:
|
||||
assert reader.message_count == 2
|
||||
assert reader.duration == 5 * 10**9
|
||||
assert reader.start_time == 5 * 10**9
|
||||
assert reader.end_time == 10 * 10**9
|
||||
assert len(reader.topics.keys()) == 1
|
||||
assert reader.topics['/topic0'].msgcount == 2
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0][3] == b'MSGCONTENT5'
|
||||
assert msgs[1][3] == b'MSGCONTENT10'
|
||||
|
||||
# sorts by time on different topic
|
||||
write_bag(
|
||||
bag,
|
||||
create_default_header(),
|
||||
chunks=[
|
||||
[
|
||||
create_connection(),
|
||||
create_message(time=10, msg=10),
|
||||
create_connection(cid=2, topic=2),
|
||||
create_message(cid=2, time=5, msg=5),
|
||||
]
|
||||
]
|
||||
)
|
||||
with Reader(bag) as reader:
|
||||
assert len(reader.topics.keys()) == 2
|
||||
assert reader.topics['/topic0'].msgcount == 1
|
||||
assert reader.topics['/topic2'].msgcount == 1
|
||||
msgs = list(reader.messages())
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0][3] == b'MSGCONTENT5'
|
||||
assert msgs[1][3] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(['/topic0']))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][3] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(start=7 * 10**9))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][3] == b'MSGCONTENT10'
|
||||
|
||||
msgs = list(reader.messages(stop=7 * 10**9))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0][3] == b'MSGCONTENT5'
|
||||
|
||||
|
||||
def test_user_errors(tmp_path):
|
||||
"""Test user errors."""
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
|
||||
reader = Reader(bag)
|
||||
with pytest.raises(ReaderError, match='is not open'):
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path): # pylint: disable=too-many-statements
|
||||
"""Test failure cases."""
|
||||
bag = tmp_path / 'test.bag'
|
||||
with pytest.raises(ReaderError, match='does not exist'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_text('')
|
||||
with patch('pathlib.Path.open', side_effect=IOError), \
|
||||
pytest.raises(ReaderError, match='not open'):
|
||||
Reader(bag).open()
|
||||
|
||||
with pytest.raises(ReaderError, match='empty'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_text('#BADMAGIC')
|
||||
with pytest.raises(ReaderError, match='magic is invalid'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_text('#ROSBAG V3.0\n')
|
||||
with pytest.raises(ReaderError, match='Bag version 300 is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x00')
|
||||
with pytest.raises(ReaderError, match='Header could not be read from file.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00')
|
||||
with pytest.raises(ReaderError, match='Header could not be read from file.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x01\x00\x00\x00\x01')
|
||||
with pytest.raises(ReaderError, match='Header field size could not be read.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x04\x00\x00\x00\x01\x00\x00\x00')
|
||||
with pytest.raises(ReaderError, match='Declared field size is too large for header.'):
|
||||
Reader(bag).open()
|
||||
|
||||
bag.write_bytes(b'#ROSBAG V2.0\x0a\x05\x00\x00\x00\x01\x00\x00\x00x')
|
||||
with pytest.raises(ReaderError, match='Header field could not be parsed.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, {'encryptor': b'enc', **create_default_header()})
|
||||
with pytest.raises(ReaderError, match='is not supported'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, {**create_default_header(), 'index_pos': pack('<Q', 0)})
|
||||
with pytest.raises(ReaderError, match='Bag is not indexed'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'none', b'COMP'))
|
||||
with pytest.raises(ReaderError, match='Compression \'COMP\' is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'ver=\x01', b'ver=\x02'))
|
||||
with pytest.raises(ReaderError, match='CHUNK_INFO version 2 is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'ver=\x01', b'ver=\x02', 1))
|
||||
with pytest.raises(ReaderError, match='IDXDATA version 2 is not supported.'):
|
||||
Reader(bag).open()
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'op=\x02', b'op=\x00', 1))
|
||||
with Reader(bag) as reader, \
|
||||
pytest.raises(ReaderError, match='Expected to find message data.'):
|
||||
next(reader.messages())
|
||||
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]])
|
||||
bag.write_bytes(bag.read_bytes().replace(b'op=\x03', b'op=\x02', 1))
|
||||
with pytest.raises(ReaderError, match='Record of type \'MSGDATA\' is unexpected.'):
|
||||
Reader(bag).open()
|
||||
|
||||
# bad uint8 field
|
||||
write_bag(
|
||||
bag, create_default_header(), chunks=[[
|
||||
({}, {}),
|
||||
create_connection(),
|
||||
create_message(),
|
||||
]]
|
||||
)
|
||||
with Reader(bag) as reader, \
|
||||
pytest.raises(ReaderError, match='field \'op\''):
|
||||
next(reader.messages())
|
||||
|
||||
# bad uint32, uint64, time field
|
||||
for name in ('conn_count', 'chunk_pos', 'time'):
|
||||
write_bag(bag, create_default_header(), chunks=[[create_connection(), create_message()]])
|
||||
bag.write_bytes(bag.read_bytes().replace(name.encode(), b'x' * len(name), 1))
|
||||
if name == 'time':
|
||||
with pytest.raises(ReaderError, match=f'field \'{name}\''), \
|
||||
Reader(bag) as reader:
|
||||
next(reader.messages())
|
||||
else:
|
||||
with pytest.raises(ReaderError, match=f'field \'{name}\''):
|
||||
Reader(bag).open()
|
||||
Loading…
x
Reference in New Issue
Block a user