Type generics and missing return types
This commit is contained in:
parent
ac704bd890
commit
52480e2bad
65
setup.cfg
65
setup.cfg
@ -30,6 +30,7 @@ classifiers =
|
|||||||
Programming Language :: Python :: 3 :: Only
|
Programming Language :: Python :: 3 :: Only
|
||||||
Programming Language :: Python :: 3.8
|
Programming Language :: Python :: 3.8
|
||||||
Programming Language :: Python :: 3.9
|
Programming Language :: Python :: 3.9
|
||||||
|
Programming Language :: Python :: 3.10
|
||||||
Topic :: Scientific/Engineering
|
Topic :: Scientific/Engineering
|
||||||
Typing :: Typed
|
Typing :: Typed
|
||||||
project_urls =
|
project_urls =
|
||||||
@ -109,7 +110,7 @@ avoid-escape = False
|
|||||||
docstring_convention = google
|
docstring_convention = google
|
||||||
docstring_style = google
|
docstring_style = google
|
||||||
extend-exclude = venv*,.venv*
|
extend-exclude = venv*,.venv*
|
||||||
extend-select =
|
extend-select =
|
||||||
# docstrings
|
# docstrings
|
||||||
D204,
|
D204,
|
||||||
D400,
|
D400,
|
||||||
@ -119,8 +120,10 @@ extend-select =
|
|||||||
ignore =
|
ignore =
|
||||||
# do not require annotation of `self`
|
# do not require annotation of `self`
|
||||||
ANN101,
|
ANN101,
|
||||||
# allow line break before binary operator
|
# handled by B001
|
||||||
W503,
|
E722,
|
||||||
|
# allow line break after binary operator
|
||||||
|
W504,
|
||||||
max-line-length = 100
|
max-line-length = 100
|
||||||
strictness = long
|
strictness = long
|
||||||
suppress-none-returning = True
|
suppress-none-returning = True
|
||||||
@ -134,10 +137,14 @@ multi_line_output = 3
|
|||||||
explicit_package_bases = True
|
explicit_package_bases = True
|
||||||
mypy_path = $MYPY_CONFIG_FILE_DIR/src
|
mypy_path = $MYPY_CONFIG_FILE_DIR/src
|
||||||
namespace_packages = True
|
namespace_packages = True
|
||||||
|
strict = True
|
||||||
|
|
||||||
[mypy-ruamel]
|
[mypy-lz4.frame]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-ruamel.yaml]
|
||||||
|
implicit_reexport = True
|
||||||
|
|
||||||
[pydocstyle]
|
[pydocstyle]
|
||||||
convention = google
|
convention = google
|
||||||
add-select = D204,D400,D401,D404,D413
|
add-select = D204,D400,D401,D404,D413
|
||||||
@ -146,9 +153,37 @@ add-select = D204,D400,D401,D404,D413
|
|||||||
max-line-length = 100
|
max-line-length = 100
|
||||||
|
|
||||||
[pylint.'MESSAGES CONTROL']
|
[pylint.'MESSAGES CONTROL']
|
||||||
|
enable = all
|
||||||
disable =
|
disable =
|
||||||
duplicate-code,
|
duplicate-code,
|
||||||
ungrouped-imports,
|
ungrouped-imports,
|
||||||
|
# isort (pylint FAQ)
|
||||||
|
wrong-import-order,
|
||||||
|
# mccabe (pylint FAQ)
|
||||||
|
too-many-branches,
|
||||||
|
# fixme
|
||||||
|
fixme,
|
||||||
|
# pep8-naming (pylint FAQ, keep: invalid-name)
|
||||||
|
bad-classmethod-argument,
|
||||||
|
bad-mcs-classmethod-argument,
|
||||||
|
no-self-argument
|
||||||
|
# pycodestyle (pylint FAQ)
|
||||||
|
bad-indentation,
|
||||||
|
bare-except,
|
||||||
|
line-too-long,
|
||||||
|
missing-final-newline,
|
||||||
|
multiple-statements,
|
||||||
|
trailing-whitespace,
|
||||||
|
unnecessary-semicolon,
|
||||||
|
unneeded-not,
|
||||||
|
# pydocstyle (pylint FAQ)
|
||||||
|
missing-class-docstring,
|
||||||
|
missing-function-docstring,
|
||||||
|
missing-module-docstring,
|
||||||
|
# pyflakes (pylint FAQ)
|
||||||
|
undefined-variable,
|
||||||
|
unused-import,
|
||||||
|
unused-variable,
|
||||||
|
|
||||||
[yapf]
|
[yapf]
|
||||||
based_on_style = google
|
based_on_style = google
|
||||||
@ -159,15 +194,15 @@ indent_dictionary_value = false
|
|||||||
|
|
||||||
[tool:pytest]
|
[tool:pytest]
|
||||||
addopts =
|
addopts =
|
||||||
-v
|
-v
|
||||||
--flake8
|
--flake8
|
||||||
--mypy
|
--mypy
|
||||||
--pylint
|
--pylint
|
||||||
--yapf
|
--yapf
|
||||||
--cov=src
|
--cov=src
|
||||||
--cov-branch
|
--cov-branch
|
||||||
--cov-report=html
|
--cov-report=html
|
||||||
--cov-report=term
|
--cov-report=term
|
||||||
--no-cov-on-fail
|
--no-cov-on-fail
|
||||||
--junitxml=report.xml
|
--junitxml=report.xml
|
||||||
junit_family=xunit2
|
junit_family=xunit2
|
||||||
|
|||||||
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
def pathtype(exists: bool = True) -> Callable:
|
def pathtype(exists: bool = True) -> Callable[[str], Path]:
|
||||||
"""Path argument for argparse.
|
"""Path argument for argparse.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -10,14 +10,15 @@ import os
|
|||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
from bz2 import decompress as bz2_decompress
|
from bz2 import decompress as bz2_decompress
|
||||||
|
from collections import defaultdict
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, NamedTuple
|
from typing import TYPE_CHECKING, Any, Dict, NamedTuple
|
||||||
|
|
||||||
from lz4.frame import decompress as lz4_decompress # type: ignore
|
from lz4.frame import decompress as lz4_decompress
|
||||||
|
|
||||||
from rosbags.typesys.msg import normalize_msgtype
|
from rosbags.typesys.msg import normalize_msgtype
|
||||||
|
|
||||||
@ -59,7 +60,7 @@ class Connection(NamedTuple):
|
|||||||
md5sum: str
|
md5sum: str
|
||||||
callerid: Optional[str]
|
callerid: Optional[str]
|
||||||
latching: Optional[int]
|
latching: Optional[int]
|
||||||
indexes: list
|
indexes: list[IndexData]
|
||||||
|
|
||||||
|
|
||||||
class ChunkInfo(NamedTuple):
|
class ChunkInfo(NamedTuple):
|
||||||
@ -76,7 +77,7 @@ class Chunk(NamedTuple):
|
|||||||
|
|
||||||
datasize: int
|
datasize: int
|
||||||
datapos: int
|
datapos: int
|
||||||
decompressor: Callable
|
decompressor: Callable[[bytes], bytes]
|
||||||
|
|
||||||
|
|
||||||
class TopicInfo(NamedTuple):
|
class TopicInfo(NamedTuple):
|
||||||
@ -124,9 +125,9 @@ class IndexData(NamedTuple):
|
|||||||
return self.time != other[0]
|
return self.time != other[0]
|
||||||
|
|
||||||
|
|
||||||
deserialize_uint8 = struct.Struct('<B').unpack
|
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
|
||||||
deserialize_uint32 = struct.Struct('<L').unpack
|
deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack # type: ignore
|
||||||
deserialize_uint64 = struct.Struct('<Q').unpack
|
deserialize_uint64: Callable[[bytes], tuple[int]] = struct.Struct('<Q').unpack # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def deserialize_time(val: bytes) -> int:
|
def deserialize_time(val: bytes) -> int:
|
||||||
@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int:
|
|||||||
Deserialized value.
|
Deserialized value.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
sec, nsec = struct.unpack('<LL', val)
|
unpacked: tuple[int, int] = struct.unpack('<LL', val) # type: ignore
|
||||||
|
sec, nsec = unpacked
|
||||||
return sec * 10**9 + nsec
|
return sec * 10**9 + nsec
|
||||||
|
|
||||||
|
|
||||||
class Header(dict):
|
class Header(Dict[str, Any]):
|
||||||
"""Record header."""
|
"""Record header."""
|
||||||
|
|
||||||
def get_uint8(self, name: str) -> int:
|
def get_uint8(self, name: str) -> int:
|
||||||
@ -214,7 +216,9 @@ class Header(dict):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self[name].decode()
|
value = self[name]
|
||||||
|
assert isinstance(value, bytes)
|
||||||
|
return value.decode()
|
||||||
except (KeyError, ValueError) as err:
|
except (KeyError, ValueError) as err:
|
||||||
raise ReaderError(f'Could not read string field {name!r}.') from err
|
raise ReaderError(f'Could not read string field {name!r}.') from err
|
||||||
|
|
||||||
@ -237,7 +241,7 @@ class Header(dict):
|
|||||||
raise ReaderError(f'Could not read time field {name!r}.') from err
|
raise ReaderError(f'Could not read time field {name!r}.') from err
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
||||||
"""Read header from file handle.
|
"""Read header from file handle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -362,10 +366,10 @@ class Reader:
|
|||||||
self.connections: dict[int, Connection] = {}
|
self.connections: dict[int, Connection] = {}
|
||||||
self.chunk_infos: list[ChunkInfo] = []
|
self.chunk_infos: list[ChunkInfo] = []
|
||||||
self.chunks: dict[int, Chunk] = {}
|
self.chunks: dict[int, Chunk] = {}
|
||||||
self.current_chunk = (-1, BytesIO())
|
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
|
||||||
self.topics: dict[str, TopicInfo] = {}
|
self.topics: dict[str, TopicInfo] = {}
|
||||||
|
|
||||||
def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||||
"""Open rosbag and read metadata."""
|
"""Open rosbag and read metadata."""
|
||||||
try:
|
try:
|
||||||
self.bio = self.path.open('rb')
|
self.bio = self.path.open('rb')
|
||||||
@ -409,24 +413,25 @@ class Reader:
|
|||||||
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
|
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
|
||||||
|
|
||||||
self.chunks = {}
|
self.chunks = {}
|
||||||
|
indexes: dict[int, list[list[IndexData]]] = defaultdict(list)
|
||||||
for chunk_info in self.chunk_infos:
|
for chunk_info in self.chunk_infos:
|
||||||
self.bio.seek(chunk_info.pos)
|
self.bio.seek(chunk_info.pos)
|
||||||
self.chunks[chunk_info.pos] = self.read_chunk()
|
self.chunks[chunk_info.pos] = self.read_chunk()
|
||||||
|
|
||||||
for _ in range(len(chunk_info.connection_counts)):
|
for _ in range(len(chunk_info.connection_counts)):
|
||||||
cid, index = self.read_index_data(chunk_info.pos)
|
cid, index = self.read_index_data(chunk_info.pos)
|
||||||
self.connections[cid].indexes.append(index)
|
indexes[cid].append(index)
|
||||||
|
|
||||||
for connection in self.connections.values():
|
for cid, connection in self.connections.items():
|
||||||
connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
|
connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time))
|
||||||
assert connection.indexes
|
assert connection.indexes
|
||||||
|
|
||||||
self.topics = {}
|
self.topics = {}
|
||||||
for topic, connections in groupby(
|
for topic, group in groupby(
|
||||||
sorted(self.connections.values(), key=lambda x: x.topic),
|
sorted(self.connections.values(), key=lambda x: x.topic),
|
||||||
key=lambda x: x.topic,
|
key=lambda x: x.topic,
|
||||||
):
|
):
|
||||||
connections = list(connections)
|
connections = list(group)
|
||||||
count = reduce(
|
count = reduce(
|
||||||
lambda x, y: x + y,
|
lambda x, y: x + y,
|
||||||
(
|
(
|
||||||
@ -446,7 +451,7 @@ class Reader:
|
|||||||
self.close()
|
self.close()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""Close rosbag."""
|
"""Close rosbag."""
|
||||||
assert self.bio
|
assert self.bio
|
||||||
self.bio.close()
|
self.bio.close()
|
||||||
@ -614,8 +619,8 @@ class Reader:
|
|||||||
|
|
||||||
chunk_header = self.chunks[entry.chunk_pos]
|
chunk_header = self.chunks[entry.chunk_pos]
|
||||||
self.bio.seek(chunk_header.datapos)
|
self.bio.seek(chunk_header.datapos)
|
||||||
chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
||||||
self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
|
self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes))
|
||||||
|
|
||||||
chunk = self.current_chunk[1]
|
chunk = self.current_chunk[1]
|
||||||
chunk.seek(entry.offset)
|
chunk.seek(entry.offset)
|
||||||
|
|||||||
@ -11,9 +11,9 @@ from dataclasses import dataclass
|
|||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
from lz4.frame import compress as lz4_compress # type: ignore
|
from lz4.frame import compress as lz4_compress
|
||||||
|
|
||||||
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ from .reader import Connection, RecordType
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
|
from typing import BinaryIO, Callable, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
|
||||||
class WriterError(Exception):
|
class WriterError(Exception):
|
||||||
@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes:
|
|||||||
return struct.pack('<LL', sec, nsec)
|
return struct.pack('<LL', sec, nsec)
|
||||||
|
|
||||||
|
|
||||||
class Header(dict):
|
class Header(Dict[str, Any]):
|
||||||
"""Record header."""
|
"""Record header."""
|
||||||
|
|
||||||
def set_uint32(self, name: str, value: int):
|
def set_uint32(self, name: str, value: int) -> None:
|
||||||
"""Set field to uint32 value.
|
"""Set field to uint32 value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -70,7 +70,7 @@ class Header(dict):
|
|||||||
"""
|
"""
|
||||||
self[name] = serialize_uint32(value)
|
self[name] = serialize_uint32(value)
|
||||||
|
|
||||||
def set_uint64(self, name: str, value: int):
|
def set_uint64(self, name: str, value: int) -> None:
|
||||||
"""Set field to uint64 value.
|
"""Set field to uint64 value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -80,7 +80,7 @@ class Header(dict):
|
|||||||
"""
|
"""
|
||||||
self[name] = serialize_uint64(value)
|
self[name] = serialize_uint64(value)
|
||||||
|
|
||||||
def set_string(self, name: str, value: str):
|
def set_string(self, name: str, value: str) -> None:
|
||||||
"""Set field to string value.
|
"""Set field to string value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -90,7 +90,7 @@ class Header(dict):
|
|||||||
"""
|
"""
|
||||||
self[name] = value.encode()
|
self[name] = value.encode()
|
||||||
|
|
||||||
def set_time(self, name: str, value: int):
|
def set_time(self, name: str, value: int) -> None:
|
||||||
"""Set field to time value.
|
"""Set field to time value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
]
|
]
|
||||||
self.chunk_threshold = 1 * (1 << 20)
|
self.chunk_threshold = 1 * (1 << 20)
|
||||||
|
|
||||||
def set_compression(self, fmt: CompressionFormat):
|
def set_compression(self, fmt: CompressionFormat) -> None:
|
||||||
"""Enable compression on rosbag1.
|
"""Enable compression on rosbag1.
|
||||||
|
|
||||||
This function has to be called before opening.
|
This function has to be called before opening.
|
||||||
@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
self.compression_format = fmt.name.lower()
|
self.compression_format = fmt.name.lower()
|
||||||
|
|
||||||
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
|
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
|
||||||
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
|
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore
|
||||||
self.compressor = {
|
self.compressor = {
|
||||||
'bz2': bz2,
|
'bz2': bz2,
|
||||||
'lz4': lz4,
|
'lz4': lz4,
|
||||||
}[self.compression_format]
|
}[self.compression_format]
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
"""Open rosbag1 for writing."""
|
"""Open rosbag1 for writing."""
|
||||||
try:
|
try:
|
||||||
self.bio = self.path.open('xb')
|
self.bio = self.path.open('xb')
|
||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
raise WriterError(f'{self.path} exists already, not overwriting.') from None
|
raise WriterError(f'{self.path} exists already, not overwriting.') from None
|
||||||
|
|
||||||
|
assert self.bio
|
||||||
self.bio.write(b'#ROSBAG V2.0\n')
|
self.bio.write(b'#ROSBAG V2.0\n')
|
||||||
header = Header()
|
header = Header()
|
||||||
header.set_uint64('index_pos', 0)
|
header.set_uint64('index_pos', 0)
|
||||||
@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.connections[connection.cid] = connection
|
self.connections[connection.cid] = connection
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
|
||||||
"""Write message to rosbag1.
|
"""Write message to rosbag1.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.write_chunk(chunk)
|
self.write_chunk(chunk)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_connection(connection: Connection, bio: BytesIO):
|
def write_connection(connection: Connection, bio: BinaryIO) -> None:
|
||||||
"""Write connection record."""
|
"""Write connection record."""
|
||||||
header = Header()
|
header = Header()
|
||||||
header.set_uint32('conn', connection.cid)
|
header.set_uint32('conn', connection.cid)
|
||||||
@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
header.set_string('latching', str(connection.latching))
|
header.set_string('latching', str(connection.latching))
|
||||||
header.write(bio)
|
header.write(bio)
|
||||||
|
|
||||||
def write_chunk(self, chunk: WriteChunk):
|
def write_chunk(self, chunk: WriteChunk) -> None:
|
||||||
"""Write open chunk to file."""
|
"""Write open chunk to file."""
|
||||||
assert self.bio
|
assert self.bio
|
||||||
|
|
||||||
@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
chunk.data.close()
|
chunk.data.close()
|
||||||
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
|
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""Close rosbag1 after writing.
|
"""Close rosbag1 after writing.
|
||||||
|
|
||||||
Closes open chunks and writes index.
|
Closes open chunks and writes index.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert self.bio
|
||||||
for chunk in self.chunks:
|
for chunk in self.chunks:
|
||||||
if chunk.pos == -1:
|
if chunk.pos == -1:
|
||||||
self.write_chunk(chunk)
|
self.write_chunk(chunk)
|
||||||
|
|||||||
@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import zstandard
|
import zstandard
|
||||||
from ruamel.yaml import YAML, YAMLError
|
from ruamel.yaml import YAML
|
||||||
|
from ruamel.yaml.error import YAMLError
|
||||||
|
|
||||||
from .connection import Connection
|
from .connection import Connection
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union
|
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
|
||||||
class ReaderError(Exception):
|
class ReaderError(Exception):
|
||||||
@ -25,7 +26,7 @@ class ReaderError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def decompress(path: Path, do_decompress: bool):
|
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
|
||||||
"""Transparent rosbag2 database decompression context.
|
"""Transparent rosbag2 database decompression context.
|
||||||
|
|
||||||
This context manager will yield a path to the decompressed file contents.
|
This context manager will yield a path to the decompressed file contents.
|
||||||
@ -119,12 +120,12 @@ class Reader:
|
|||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
"""Open rosbag2."""
|
"""Open rosbag2."""
|
||||||
# Future storage formats will require file handles.
|
# Future storage formats will require file handles.
|
||||||
self.bio = True
|
self.bio = True
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""Close rosbag2."""
|
"""Close rosbag2."""
|
||||||
# Future storage formats will require file handles.
|
# Future storage formats will require file handles.
|
||||||
assert self.bio
|
assert self.bio
|
||||||
@ -133,12 +134,14 @@ class Reader:
|
|||||||
@property
|
@property
|
||||||
def duration(self) -> int:
|
def duration(self) -> int:
|
||||||
"""Duration in nanoseconds between earliest and latest messages."""
|
"""Duration in nanoseconds between earliest and latest messages."""
|
||||||
return self.metadata['duration']['nanoseconds'] + 1
|
nsecs: int = self.metadata['duration']['nanoseconds']
|
||||||
|
return nsecs + 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_time(self) -> int:
|
def start_time(self) -> int:
|
||||||
"""Timestamp in nanoseconds of the earliest message."""
|
"""Timestamp in nanoseconds of the earliest message."""
|
||||||
return self.metadata['starting_time']['nanoseconds_since_epoch']
|
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||||
|
return nsecs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def end_time(self) -> int:
|
def end_time(self) -> int:
|
||||||
@ -148,7 +151,8 @@ class Reader:
|
|||||||
@property
|
@property
|
||||||
def message_count(self) -> int:
|
def message_count(self) -> int:
|
||||||
"""Total message count."""
|
"""Total message count."""
|
||||||
return self.metadata['message_count']
|
count: int = self.metadata['message_count']
|
||||||
|
return count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def compression_format(self) -> Optional[str]:
|
def compression_format(self) -> Optional[str]:
|
||||||
@ -233,7 +237,7 @@ class Reader:
|
|||||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||||
|
|
||||||
cur.execute('SELECT name,id FROM topics')
|
cur.execute('SELECT name,id FROM topics')
|
||||||
connmap: Dict[int, Connection] = {
|
connmap: dict[int, Connection] = {
|
||||||
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
|
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
|
||||||
None) # type: ignore
|
None) # type: ignore
|
||||||
for row in cur
|
for row in cur
|
||||||
|
|||||||
@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.compression_format = ''
|
self.compression_format = ''
|
||||||
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
||||||
self.connections: dict[int, Connection] = {}
|
self.connections: dict[int, Connection] = {}
|
||||||
self.conn = None
|
self.conn: Optional[sqlite3.Connection] = None
|
||||||
self.cursor: Optional[sqlite3.Cursor] = None
|
self.cursor: Optional[sqlite3.Cursor] = None
|
||||||
|
|
||||||
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
|
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None:
|
||||||
"""Enable compression on bag.
|
"""Enable compression on bag.
|
||||||
|
|
||||||
This function has to be called before opening.
|
This function has to be called before opening.
|
||||||
@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.compression_format = fmt.name.lower()
|
self.compression_format = fmt.name.lower()
|
||||||
self.compressor = zstandard.ZstdCompressor()
|
self.compressor = zstandard.ZstdCompressor()
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> None:
|
||||||
"""Open rosbag2 for writing.
|
"""Open rosbag2 for writing.
|
||||||
|
|
||||||
Create base directory and open database connection.
|
Create base directory and open database connection.
|
||||||
@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
|
||||||
"""Write message to rosbag2.
|
"""Write message to rosbag2.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
)
|
)
|
||||||
connection.count += 1
|
connection.count += 1
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
"""Close rosbag2 after writing.
|
"""Close rosbag2 after writing.
|
||||||
|
|
||||||
Closes open database transactions and writes metadata.yaml.
|
Closes open database transactions and writes metadata.yaml.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert self.cursor
|
||||||
|
assert self.conn
|
||||||
self.cursor.close()
|
self.cursor.close()
|
||||||
self.cursor = None
|
self.cursor = None
|
||||||
|
|
||||||
@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
|||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
|
||||||
if self.compression_mode == 'file':
|
if self.compression_mode == 'file':
|
||||||
|
assert self.compressor
|
||||||
src = self.dbpath
|
src = self.dbpath
|
||||||
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
|
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
|
||||||
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:
|
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:
|
||||||
|
|||||||
@ -19,10 +19,10 @@ from .typing import Field
|
|||||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Callable
|
from .typing import CDRDeser, CDRSer, CDRSerSize
|
||||||
|
|
||||||
|
|
||||||
def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
|
def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
|
||||||
"""Generate cdr size calculation function.
|
"""Generate cdr size calculation function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
|
|||||||
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
|
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
|
||||||
"""Generate cdr serialization function.
|
"""Generate cdr serialization function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
|||||||
return compile_lines(lines).serialize_cdr # type: ignore
|
return compile_lines(lines).serialize_cdr # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||||
"""Generate cdr deserialization function.
|
"""Generate cdr deserialization function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef:
|
|||||||
generate_serialize_cdr(fields, 'be'),
|
generate_serialize_cdr(fields, 'be'),
|
||||||
generate_deserialize_cdr(fields, 'le'),
|
generate_deserialize_cdr(fields, 'le'),
|
||||||
generate_deserialize_cdr(fields, 'be'),
|
generate_deserialize_cdr(fields, 'be'),
|
||||||
generate_ros1_to_cdr(fields, typename, False),
|
generate_ros1_to_cdr(fields, typename, False), # type: ignore
|
||||||
generate_ros1_to_cdr(fields, typename, True),
|
generate_ros1_to_cdr(fields, typename, True), # type: ignore
|
||||||
generate_cdr_to_ros1(fields, typename, False),
|
generate_cdr_to_ros1(fields, typename, False), # type: ignore
|
||||||
generate_cdr_to_ros1(fields, typename, True),
|
generate_cdr_to_ros1(fields, typename, True), # type: ignore
|
||||||
)
|
)
|
||||||
return MSGDEFCACHE[typename]
|
return MSGDEFCACHE[typename]
|
||||||
|
|||||||
@ -18,10 +18,16 @@ from .typing import Field
|
|||||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Callable # pylint: disable=ungrouped-imports
|
from typing import Union # pylint: disable=ungrouped-imports
|
||||||
|
|
||||||
|
from .typing import Bitcvt, BitcvtSize
|
||||||
|
|
||||||
|
|
||||||
def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
|
def generate_ros1_to_cdr(
|
||||||
|
fields: list[Field],
|
||||||
|
typename: str,
|
||||||
|
copy: bool,
|
||||||
|
) -> Union[Bitcvt, BitcvtSize]:
|
||||||
"""Generate ROS1 to CDR conversion function.
|
"""Generate ROS1 to CDR conversion function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call
|
|||||||
aligned = anext
|
aligned = anext
|
||||||
|
|
||||||
lines.append(' return ipos, opos')
|
lines.append(' return ipos, opos')
|
||||||
return getattr(compile_lines(lines), funcname)
|
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
|
def generate_cdr_to_ros1(
|
||||||
|
fields: list[Field],
|
||||||
|
typename: str,
|
||||||
|
copy: bool,
|
||||||
|
) -> Union[Bitcvt, BitcvtSize]:
|
||||||
"""Generate CDR to ROS1 conversion function.
|
"""Generate CDR to ROS1 conversion function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call
|
|||||||
aligned = anext
|
aligned = anext
|
||||||
|
|
||||||
lines.append(' return ipos, opos')
|
lines.append(' return ipos, opos')
|
||||||
return getattr(compile_lines(lines), funcname)
|
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||||
|
|||||||
@ -7,7 +7,14 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, NamedTuple
|
from typing import TYPE_CHECKING, NamedTuple
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, Callable, List
|
from typing import Any, Callable, Tuple
|
||||||
|
|
||||||
|
Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]]
|
||||||
|
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
|
||||||
|
|
||||||
|
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
|
||||||
|
CDRSer = Callable[[bytes, int, type], int]
|
||||||
|
CDRSerSize = Callable[[int, type], int]
|
||||||
|
|
||||||
|
|
||||||
class Descriptor(NamedTuple):
|
class Descriptor(NamedTuple):
|
||||||
@ -28,15 +35,15 @@ class Msgdef(NamedTuple):
|
|||||||
"""Metadata of a message."""
|
"""Metadata of a message."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
fields: List[Field]
|
fields: list[Field]
|
||||||
cls: Any
|
cls: Any
|
||||||
size_cdr: int
|
size_cdr: int
|
||||||
getsize_cdr: Callable
|
getsize_cdr: CDRSerSize
|
||||||
serialize_cdr_le: Callable
|
serialize_cdr_le: CDRSer
|
||||||
serialize_cdr_be: Callable
|
serialize_cdr_be: CDRSer
|
||||||
deserialize_cdr_le: Callable
|
deserialize_cdr_le: CDRDeser
|
||||||
deserialize_cdr_be: Callable
|
deserialize_cdr_be: CDRDeser
|
||||||
getsize_ros1_to_cdr: Callable
|
getsize_ros1_to_cdr: BitcvtSize
|
||||||
ros1_to_cdr: Callable
|
ros1_to_cdr: Bitcvt
|
||||||
getsize_cdr_to_ros1: Callable
|
getsize_cdr_to_ros1: BitcvtSize
|
||||||
cdr_to_ros1: Callable
|
cdr_to_ros1: Bitcvt
|
||||||
|
|||||||
@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
|
|||||||
pos = rule.skip_ws(text, 0)
|
pos = rule.skip_ws(text, 0)
|
||||||
npos, trees = rule.parse(text, pos)
|
npos, trees = rule.parse(text, pos)
|
||||||
assert npos == len(text), f'Could not parse: {text!r}'
|
assert npos == len(text), f'Could not parse: {text!r}'
|
||||||
return visitor.visit(trees)
|
return visitor.visit(trees) # type: ignore
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
raise TypesysError(f'Could not parse: {text!r}') from err
|
raise TypesysError(f'Could not parse: {text!r}') from err
|
||||||
|
|||||||
@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
|||||||
|
|
||||||
RULES = parse_grammar(GRAMMAR_IDL)
|
RULES = parse_grammar(GRAMMAR_IDL)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.typedefs = {}
|
self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {}
|
||||||
|
|
||||||
def visit_specification(self, children: Any) -> Typesdict:
|
def visit_specification(self, children: Any) -> Typesdict:
|
||||||
"""Process start symbol, return only children of modules."""
|
"""Process start symbol, return only children of modules."""
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class Rule:
|
|||||||
}
|
}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def parse(self, text: str, pos: int):
|
def parse(self, text: str, pos: int) -> tuple[int, Any]:
|
||||||
"""Apply rule at position."""
|
"""Apply rule at position."""
|
||||||
raise NotImplementedError # pragma: no cover
|
raise NotImplementedError # pragma: no cover
|
||||||
|
|
||||||
@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
RULES: dict[str, Rule] = {}
|
RULES: dict[str, Rule] = {}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
|
|
||||||
def visit(self, tree: Any) -> Any:
|
def visit(self, tree: Any) -> Any:
|
||||||
|
|||||||
@ -13,12 +13,14 @@ from . import types
|
|||||||
from .base import Nodetype, TypesysError
|
from .base import Nodetype, TypesysError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from .base import Typesdict
|
from .base import Typesdict
|
||||||
|
|
||||||
INTLIKE = re.compile('^u?(bool|int|float)')
|
INTLIKE = re.compile('^u?(bool|int|float)')
|
||||||
|
|
||||||
|
|
||||||
def get_typehint(desc: tuple) -> str:
|
def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str:
|
||||||
"""Get python type hint for field.
|
"""Get python type hint for field.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if desc[0] == Nodetype.BASE:
|
if desc[0] == Nodetype.BASE:
|
||||||
if match := INTLIKE.match(desc[1]):
|
if match := INTLIKE.match(desc[1]): # type: ignore
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
return 'str'
|
return 'str'
|
||||||
|
|
||||||
if desc[0] == Nodetype.NAME:
|
if desc[0] == Nodetype.NAME:
|
||||||
|
assert isinstance(desc[1], str)
|
||||||
return desc[1].replace('/', '__')
|
return desc[1].replace('/', '__')
|
||||||
|
|
||||||
sub = desc[1][0]
|
sub = desc[1][0]
|
||||||
if INTLIKE.match(sub[1]):
|
if INTLIKE.match(sub[1]):
|
||||||
typ = 'bool8' if sub[1] == 'bool' else sub[1]
|
typ = 'bool8' if sub[1] == 'bool' else sub[1]
|
||||||
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
|
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
|
||||||
return f'list[{get_typehint(sub)}]'
|
return f'list[{get_typehint(sub)}]' # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def generate_python_code(typs: Typesdict) -> str:
|
def generate_python_code(typs: Typesdict) -> str:
|
||||||
@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str:
|
|||||||
'',
|
'',
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_ftype(ftype: tuple) -> tuple:
|
def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]:
|
||||||
if ftype[0] <= 2:
|
if ftype[0] <= 2:
|
||||||
return int(ftype[0]), ftype[1]
|
return int(ftype[0]), ftype[1]
|
||||||
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
|
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from struct import Struct, pack_into, unpack_from
|
|||||||
from typing import TYPE_CHECKING, Dict, List, Union, cast
|
from typing import TYPE_CHECKING, Dict, List, Union, cast
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from rosbags.serde.messages import SerdeError, get_msgdef
|
from rosbags.serde.messages import SerdeError, get_msgdef
|
||||||
from rosbags.serde.typing import Msgdef
|
from rosbags.serde.typing import Msgdef
|
||||||
@ -116,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
|
|||||||
|
|
||||||
size = SIZEMAP[desc.args]
|
size = SIZEMAP[desc.args]
|
||||||
pos = (pos + size - 1) & -size
|
pos = (pos + size - 1) & -size
|
||||||
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
|
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
|
||||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||||
ndarr = ndarr.byteswap() # no inplace on readonly array
|
ndarr = ndarr.byteswap() # no inplace on readonly array
|
||||||
return ndarr, pos + num * SIZEMAP[desc.args]
|
return ndarr, pos + num * SIZEMAP[desc.args]
|
||||||
@ -278,7 +279,7 @@ def serialize_array(
|
|||||||
size = SIZEMAP[desc.args]
|
size = SIZEMAP[desc.args]
|
||||||
pos = (pos + size - 1) & -size
|
pos = (pos + size - 1) & -size
|
||||||
size *= len(val)
|
size *= len(val)
|
||||||
val = cast(numpy.ndarray, val)
|
val = cast(NDArray[numpy.int_], val)
|
||||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||||
val = val.byteswap() # no inplace on readonly array
|
val = val.byteswap() # no inplace on readonly array
|
||||||
rawdata[pos:pos + size] = memoryview(val.tobytes())
|
rawdata[pos:pos + size] = memoryview(val.tobytes())
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from rosbags.rosbag1 import ReaderError
|
|||||||
from rosbags.rosbag2 import WriterError
|
from rosbags.rosbag2 import WriterError
|
||||||
|
|
||||||
|
|
||||||
def test_cliwrapper(tmp_path: Path):
|
def test_cliwrapper(tmp_path: Path) -> None:
|
||||||
"""Test cli wrapper."""
|
"""Test cli wrapper."""
|
||||||
(tmp_path / 'subdir').mkdir()
|
(tmp_path / 'subdir').mkdir()
|
||||||
(tmp_path / 'ros1.bag').write_text('')
|
(tmp_path / 'ros1.bag').write_text('')
|
||||||
@ -62,7 +62,7 @@ def test_cliwrapper(tmp_path: Path):
|
|||||||
mock_print.assert_called_with('ERROR: exc')
|
mock_print.assert_called_with('ERROR: exc')
|
||||||
|
|
||||||
|
|
||||||
def test_convert(tmp_path: Path):
|
def test_convert(tmp_path: Path) -> None:
|
||||||
"""Test conversion function."""
|
"""Test conversion function."""
|
||||||
(tmp_path / 'subdir').mkdir()
|
(tmp_path / 'subdir').mkdir()
|
||||||
(tmp_path / 'foo.bag').write_text('')
|
(tmp_path / 'foo.bag').write_text('')
|
||||||
|
|||||||
@ -142,13 +142,13 @@ module test_msgs {
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def test_parse_empty_msg():
|
def test_parse_empty_msg() -> None:
|
||||||
"""Test msg parser with empty message."""
|
"""Test msg parser with empty message."""
|
||||||
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
|
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
|
||||||
assert ret == {'std_msgs/msg/Empty': ([], [])}
|
assert ret == {'std_msgs/msg/Empty': ([], [])}
|
||||||
|
|
||||||
|
|
||||||
def test_parse_bounds_msg():
|
def test_parse_bounds_msg() -> None:
|
||||||
"""Test msg parser."""
|
"""Test msg parser."""
|
||||||
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
|
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
|
||||||
assert ret == {
|
assert ret == {
|
||||||
@ -168,7 +168,7 @@ def test_parse_bounds_msg():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_parse_defaults_msg():
|
def test_parse_defaults_msg() -> None:
|
||||||
"""Test msg parser."""
|
"""Test msg parser."""
|
||||||
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
|
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
|
||||||
assert ret == {
|
assert ret == {
|
||||||
@ -188,7 +188,7 @@ def test_parse_defaults_msg():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_parse_msg():
|
def test_parse_msg() -> None:
|
||||||
"""Test msg parser."""
|
"""Test msg parser."""
|
||||||
with pytest.raises(TypesysError, match='Could not parse'):
|
with pytest.raises(TypesysError, match='Could not parse'):
|
||||||
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
|
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
|
||||||
@ -208,7 +208,7 @@ def test_parse_msg():
|
|||||||
assert fields[6][1][0] == Nodetype.ARRAY
|
assert fields[6][1][0] == Nodetype.ARRAY
|
||||||
|
|
||||||
|
|
||||||
def test_parse_multi_msg():
|
def test_parse_multi_msg() -> None:
|
||||||
"""Test multi msg parser."""
|
"""Test multi msg parser."""
|
||||||
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
|
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
|
||||||
assert len(ret) == 3
|
assert len(ret) == 3
|
||||||
@ -223,7 +223,7 @@ def test_parse_multi_msg():
|
|||||||
assert consts == [('static', 'uint32', 42)]
|
assert consts == [('static', 'uint32', 42)]
|
||||||
|
|
||||||
|
|
||||||
def test_parse_cstring_confusion():
|
def test_parse_cstring_confusion() -> None:
|
||||||
"""Test if msg separator is confused with const string."""
|
"""Test if msg separator is confused with const string."""
|
||||||
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
|
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
|
||||||
assert len(ret) == 2
|
assert len(ret) == 2
|
||||||
@ -235,7 +235,7 @@ def test_parse_cstring_confusion():
|
|||||||
assert fields[1][1][1] == 'string'
|
assert fields[1][1][1] == 'string'
|
||||||
|
|
||||||
|
|
||||||
def test_parse_relative_siblings_msg():
|
def test_parse_relative_siblings_msg() -> None:
|
||||||
"""Test relative siblings with msg parser."""
|
"""Test relative siblings with msg parser."""
|
||||||
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
|
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
|
||||||
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
||||||
@ -246,7 +246,7 @@ def test_parse_relative_siblings_msg():
|
|||||||
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
|
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
|
||||||
|
|
||||||
|
|
||||||
def test_parse_idl():
|
def test_parse_idl() -> None:
|
||||||
"""Test idl parser."""
|
"""Test idl parser."""
|
||||||
ret = get_types_from_idl(IDL_LANG)
|
ret = get_types_from_idl(IDL_LANG)
|
||||||
assert ret == {}
|
assert ret == {}
|
||||||
@ -267,21 +267,21 @@ def test_parse_idl():
|
|||||||
assert fields[6][1][0] == Nodetype.ARRAY
|
assert fields[6][1][0] == Nodetype.ARRAY
|
||||||
|
|
||||||
|
|
||||||
def test_register_types():
|
def test_register_types() -> None:
|
||||||
"""Test type registeration."""
|
"""Test type registeration."""
|
||||||
assert 'foo' not in FIELDDEFS
|
assert 'foo' not in FIELDDEFS
|
||||||
register_types({})
|
register_types({})
|
||||||
register_types({'foo': [[], [('b', (1, 'bool'))]]})
|
register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
|
||||||
assert 'foo' in FIELDDEFS
|
assert 'foo' in FIELDDEFS
|
||||||
|
|
||||||
register_types({'std_msgs/msg/Header': [[], []]})
|
register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
|
||||||
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
|
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
|
||||||
|
|
||||||
with pytest.raises(TypesysError, match='different definition'):
|
with pytest.raises(TypesysError, match='different definition'):
|
||||||
register_types({'foo': [[], [('x', (1, 'bool'))]]})
|
register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_generate_msgdef():
|
def test_generate_msgdef() -> None:
|
||||||
"""Test message definition generator."""
|
"""Test message definition generator."""
|
||||||
res = generate_msgdef('std_msgs/msg/Header')
|
res = generate_msgdef('std_msgs/msg/Header')
|
||||||
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')
|
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')
|
||||||
|
|||||||
@ -117,7 +117,7 @@ def bag(request: SubRequest, tmp_path: Path) -> Path:
|
|||||||
return tmp_path
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
def test_reader(bag: Path):
|
def test_reader(bag: Path) -> None:
|
||||||
"""Test reader and deserializer on simple bag."""
|
"""Test reader and deserializer on simple bag."""
|
||||||
with Reader(bag) as reader:
|
with Reader(bag) as reader:
|
||||||
assert reader.duration == 43
|
assert reader.duration == 43
|
||||||
@ -151,7 +151,7 @@ def test_reader(bag: Path):
|
|||||||
next(gen)
|
next(gen)
|
||||||
|
|
||||||
|
|
||||||
def test_message_filters(bag: Path):
|
def test_message_filters(bag: Path) -> None:
|
||||||
"""Test reader filters messages."""
|
"""Test reader filters messages."""
|
||||||
with Reader(bag) as reader:
|
with Reader(bag) as reader:
|
||||||
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
|
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
|
||||||
@ -188,14 +188,14 @@ def test_message_filters(bag: Path):
|
|||||||
next(gen)
|
next(gen)
|
||||||
|
|
||||||
|
|
||||||
def test_user_errors(bag: Path):
|
def test_user_errors(bag: Path) -> None:
|
||||||
"""Test user errors."""
|
"""Test user errors."""
|
||||||
reader = Reader(bag)
|
reader = Reader(bag)
|
||||||
with pytest.raises(ReaderError, match='Rosbag is not open'):
|
with pytest.raises(ReaderError, match='Rosbag is not open'):
|
||||||
next(reader.messages())
|
next(reader.messages())
|
||||||
|
|
||||||
|
|
||||||
def test_failure_cases(tmp_path: Path):
|
def test_failure_cases(tmp_path: Path) -> None:
|
||||||
"""Test bags with broken fs layout."""
|
"""Test bags with broken fs layout."""
|
||||||
with pytest.raises(ReaderError, match='not read metadata'):
|
with pytest.raises(ReaderError, match='not read metadata'):
|
||||||
Reader(tmp_path)
|
Reader(tmp_path)
|
||||||
|
|||||||
@ -2,8 +2,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Reader tests."""
|
"""Reader tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from struct import pack
|
from struct import pack
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -11,8 +14,12 @@ import pytest
|
|||||||
from rosbags.rosbag1 import Reader, ReaderError
|
from rosbags.rosbag1 import Reader, ReaderError
|
||||||
from rosbags.rosbag1.reader import IndexData
|
from rosbags.rosbag1.reader import IndexData
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Sequence, Union
|
||||||
|
|
||||||
def ser(data):
|
|
||||||
|
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
|
||||||
"""Serialize record header."""
|
"""Serialize record header."""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
fields = []
|
fields = []
|
||||||
@ -23,7 +30,7 @@ def ser(data):
|
|||||||
return pack('<L', len(data)) + data
|
return pack('<L', len(data)) + data
|
||||||
|
|
||||||
|
|
||||||
def create_default_header():
|
def create_default_header() -> dict[str, bytes]:
|
||||||
"""Create empty rosbag header."""
|
"""Create empty rosbag header."""
|
||||||
return {
|
return {
|
||||||
'op': b'\x03',
|
'op': b'\x03',
|
||||||
@ -32,7 +39,11 @@ def create_default_header():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_connection(cid=1, topic=0, typ=0):
|
def create_connection(
|
||||||
|
cid: int = 1,
|
||||||
|
topic: int = 0,
|
||||||
|
typ: int = 0,
|
||||||
|
) -> tuple[dict[str, bytes], dict[str, bytes]]:
|
||||||
"""Create connection record."""
|
"""Create connection record."""
|
||||||
return {
|
return {
|
||||||
'op': b'\x07',
|
'op': b'\x07',
|
||||||
@ -45,7 +56,11 @@ def create_connection(cid=1, topic=0, typ=0):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_message(cid=1, time=0, msg=0):
|
def create_message(
|
||||||
|
cid: int = 1,
|
||||||
|
time: int = 0,
|
||||||
|
msg: int = 0,
|
||||||
|
) -> tuple[dict[str, Union[bytes, int]], bytes]:
|
||||||
"""Create message record."""
|
"""Create message record."""
|
||||||
return {
|
return {
|
||||||
'op': b'\x02',
|
'op': b'\x02',
|
||||||
@ -54,7 +69,12 @@ def create_message(cid=1, time=0, msg=0):
|
|||||||
}, f'MSGCONTENT{msg}'.encode()
|
}, f'MSGCONTENT{msg}'.encode()
|
||||||
|
|
||||||
|
|
||||||
def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements
|
def write_bag( # pylint: disable=too-many-locals,too-many-statements
|
||||||
|
|
||||||
|
bag: Path,
|
||||||
|
header: dict[str, bytes],
|
||||||
|
chunks: Sequence[Any] = (),
|
||||||
|
) -> None:
|
||||||
"""Write bag file."""
|
"""Write bag file."""
|
||||||
magic = b'#ROSBAG V2.0\n'
|
magic = b'#ROSBAG V2.0\n'
|
||||||
|
|
||||||
@ -70,7 +90,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
|||||||
chunk_bytes = b''
|
chunk_bytes = b''
|
||||||
start_time = 2**32 - 1
|
start_time = 2**32 - 1
|
||||||
end_time = 0
|
end_time = 0
|
||||||
counts = defaultdict(int)
|
counts: dict[int, int] = defaultdict(int)
|
||||||
index = {}
|
index = {}
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
||||||
@ -95,8 +115,8 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
|||||||
'count': 0,
|
'count': 0,
|
||||||
'msgs': b'',
|
'msgs': b'',
|
||||||
}
|
}
|
||||||
index[conn]['count'] += 1
|
index[conn]['count'] += 1 # type: ignore
|
||||||
index[conn]['msgs'] += pack('<LLL', time, 0, offset)
|
index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
|
||||||
|
|
||||||
add = ser(head) + ser(data)
|
add = ser(head) + ser(data)
|
||||||
chunk_bytes += add
|
chunk_bytes += add
|
||||||
@ -140,19 +160,19 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
|||||||
if 'index_pos' not in header:
|
if 'index_pos' not in header:
|
||||||
header['index_pos'] = pack('<Q', pos)
|
header['index_pos'] = pack('<Q', pos)
|
||||||
|
|
||||||
header = ser(header)
|
header_bytes = ser(header)
|
||||||
header += b'\x20' * (4096 - len(header))
|
header_bytes += b'\x20' * (4096 - len(header_bytes))
|
||||||
|
|
||||||
bag.write_bytes(b''.join([
|
bag.write_bytes(b''.join([
|
||||||
magic,
|
magic,
|
||||||
header,
|
header_bytes,
|
||||||
chunks_bytes,
|
chunks_bytes,
|
||||||
connections,
|
connections,
|
||||||
chunkinfos,
|
chunkinfos,
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
|
||||||
def test_indexdata():
|
def test_indexdata() -> None:
|
||||||
"""Test IndexData sort sorder."""
|
"""Test IndexData sort sorder."""
|
||||||
x42_1_0 = IndexData(42, 1, 0)
|
x42_1_0 = IndexData(42, 1, 0)
|
||||||
x42_2_0 = IndexData(42, 2, 0)
|
x42_2_0 = IndexData(42, 2, 0)
|
||||||
@ -175,7 +195,7 @@ def test_indexdata():
|
|||||||
assert not x42_1_0 > x43_3_0
|
assert not x42_1_0 > x43_3_0
|
||||||
|
|
||||||
|
|
||||||
def test_reader(tmp_path): # pylint: disable=too-many-statements
|
def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
|
||||||
"""Test reader and deserializer on simple bag."""
|
"""Test reader and deserializer on simple bag."""
|
||||||
# empty bag
|
# empty bag
|
||||||
bag = tmp_path / 'test.bag'
|
bag = tmp_path / 'test.bag'
|
||||||
@ -268,7 +288,7 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
|
|||||||
assert msgs[0][2] == b'MSGCONTENT5'
|
assert msgs[0][2] == b'MSGCONTENT5'
|
||||||
|
|
||||||
|
|
||||||
def test_user_errors(tmp_path):
|
def test_user_errors(tmp_path: Path) -> None:
|
||||||
"""Test user errors."""
|
"""Test user errors."""
|
||||||
bag = tmp_path / 'test.bag'
|
bag = tmp_path / 'test.bag'
|
||||||
write_bag(bag, create_default_header(), chunks=[[
|
write_bag(bag, create_default_header(), chunks=[[
|
||||||
@ -281,7 +301,7 @@ def test_user_errors(tmp_path):
|
|||||||
next(reader.messages())
|
next(reader.messages())
|
||||||
|
|
||||||
|
|
||||||
def test_failure_cases(tmp_path): # pylint: disable=too-many-statements
|
def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
|
||||||
"""Test failure cases."""
|
"""Test failure cases."""
|
||||||
bag = tmp_path / 'test.bag'
|
bag = tmp_path / 'test.bag'
|
||||||
with pytest.raises(ReaderError, match='does not exist'):
|
with pytest.raises(ReaderError, match='does not exist'):
|
||||||
|
|||||||
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
|
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
|
||||||
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
|
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
|
||||||
"""Test full data roundtrip."""
|
"""Test full data roundtrip."""
|
||||||
|
|
||||||
class Foo: # pylint: disable=too-few-public-methods
|
class Foo: # pylint: disable=too-few-public-methods
|
||||||
|
|||||||
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||||
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
|
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
|
||||||
"""Test full data roundtrip."""
|
"""Test full data roundtrip."""
|
||||||
|
|
||||||
class Foo: # pylint: disable=too-few-public-methods
|
class Foo: # pylint: disable=too-few-public-methods
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__
|
|||||||
from .cdr import deserialize, serialize
|
from .cdr import deserialize, serialize
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, Tuple, Union
|
from typing import Any, Generator, Union
|
||||||
|
|
||||||
MSG_POLY = (
|
MSG_POLY = (
|
||||||
(
|
(
|
||||||
@ -169,7 +169,7 @@ test_msgs/msg/dynamic_s_64[] seq_msg_ds6
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def _comparable():
|
def _comparable() -> Generator[None, None, None]:
|
||||||
"""Make messages containing numpy arrays comparable.
|
"""Make messages containing numpy arrays comparable.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
@ -180,7 +180,7 @@ def _comparable():
|
|||||||
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
|
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
|
||||||
lhs = self._mock_wraps # pylint: disable=protected-access
|
lhs = self._mock_wraps # pylint: disable=protected-access
|
||||||
rhs = getattr(other, '_mock_wraps', other)
|
rhs = getattr(other, '_mock_wraps', other)
|
||||||
return (lhs == rhs).all()
|
return (lhs == rhs).all() # type: ignore
|
||||||
|
|
||||||
class CNDArray(MagicMock):
|
class CNDArray(MagicMock):
|
||||||
"""Mock ndarray."""
|
"""Mock ndarray."""
|
||||||
@ -194,14 +194,14 @@ def _comparable():
|
|||||||
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
|
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
|
||||||
|
|
||||||
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
|
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
|
||||||
return CNDArray(wraps=frombuffer(*args, **kwargs))
|
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
|
||||||
|
|
||||||
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
|
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('message', MESSAGES)
|
@pytest.mark.parametrize('message', MESSAGES)
|
||||||
def test_serde(message: Tuple[bytes, str, bool]):
|
def test_serde(message: tuple[bytes, str, bool]) -> None:
|
||||||
"""Test serialization deserialization roundtrip."""
|
"""Test serialization deserialization roundtrip."""
|
||||||
rawdata, typ, is_little = message
|
rawdata, typ, is_little = message
|
||||||
|
|
||||||
@ -213,7 +213,7 @@ def test_serde(message: Tuple[bytes, str, bool]):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('_comparable')
|
@pytest.mark.usefixtures('_comparable')
|
||||||
def test_deserializer():
|
def test_deserializer() -> None:
|
||||||
"""Test deserializer."""
|
"""Test deserializer."""
|
||||||
msg = deserialize_cdr(*MSG_POLY[:2])
|
msg = deserialize_cdr(*MSG_POLY[:2])
|
||||||
assert msg == deserialize(*MSG_POLY[:2])
|
assert msg == deserialize(*MSG_POLY[:2])
|
||||||
@ -233,7 +233,8 @@ def test_deserializer():
|
|||||||
assert msg.header.frame_id == 'foo42'
|
assert msg.header.frame_id == 'foo42'
|
||||||
field = msg.magnetic_field
|
field = msg.magnetic_field
|
||||||
assert (field.x, field.y, field.z) == (128., 128., 128.)
|
assert (field.x, field.y, field.z) == (128., 128., 128.)
|
||||||
assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all()
|
diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) # type: ignore
|
||||||
|
assert (diag == [1., 1., 1.]).all()
|
||||||
|
|
||||||
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
|
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
|
||||||
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
|
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
|
||||||
@ -241,7 +242,7 @@ def test_deserializer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('_comparable')
|
@pytest.mark.usefixtures('_comparable')
|
||||||
def test_serializer():
|
def test_serializer() -> None:
|
||||||
"""Test serializer."""
|
"""Test serializer."""
|
||||||
|
|
||||||
class Foo: # pylint: disable=too-few-public-methods
|
class Foo: # pylint: disable=too-few-public-methods
|
||||||
@ -268,7 +269,7 @@ def test_serializer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('_comparable')
|
@pytest.mark.usefixtures('_comparable')
|
||||||
def test_serializer_errors():
|
def test_serializer_errors() -> None:
|
||||||
"""Test seralizer with broken messages."""
|
"""Test seralizer with broken messages."""
|
||||||
|
|
||||||
class Foo: # pylint: disable=too-few-public-methods
|
class Foo: # pylint: disable=too-few-public-methods
|
||||||
@ -286,7 +287,7 @@ def test_serializer_errors():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('_comparable')
|
@pytest.mark.usefixtures('_comparable')
|
||||||
def test_custom_type():
|
def test_custom_type() -> None:
|
||||||
"""Test custom type."""
|
"""Test custom type."""
|
||||||
cname = 'test_msgs/msg/custom'
|
cname = 'test_msgs/msg/custom'
|
||||||
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
|
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
|
||||||
@ -362,7 +363,7 @@ def test_custom_type():
|
|||||||
assert res == msg
|
assert res == msg
|
||||||
|
|
||||||
|
|
||||||
def test_ros1_to_cdr():
|
def test_ros1_to_cdr() -> None:
|
||||||
"""Test ROS1 to CDR conversion."""
|
"""Test ROS1 to CDR conversion."""
|
||||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||||
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||||
@ -385,7 +386,7 @@ def test_ros1_to_cdr():
|
|||||||
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr
|
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr
|
||||||
|
|
||||||
|
|
||||||
def test_cdr_to_ros1():
|
def test_cdr_to_ros1() -> None:
|
||||||
"""Test CDR to ROS1 conversion."""
|
"""Test CDR to ROS1 conversion."""
|
||||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||||
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||||
|
|||||||
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def test_writer(tmp_path: Path):
|
def test_writer(tmp_path: Path) -> None:
|
||||||
"""Test Writer."""
|
"""Test Writer."""
|
||||||
path = (tmp_path / 'rosbag2')
|
path = (tmp_path / 'rosbag2')
|
||||||
with Writer(path) as bag:
|
with Writer(path) as bag:
|
||||||
@ -60,7 +60,7 @@ def test_writer(tmp_path: Path):
|
|||||||
assert size > (path / 'compress_message.db3').stat().st_size
|
assert size > (path / 'compress_message.db3').stat().st_size
|
||||||
|
|
||||||
|
|
||||||
def test_failure_cases(tmp_path: Path):
|
def test_failure_cases(tmp_path: Path) -> None:
|
||||||
"""Test writer failure cases."""
|
"""Test writer failure cases."""
|
||||||
with pytest.raises(WriterError, match='exists'):
|
with pytest.raises(WriterError, match='exists'):
|
||||||
Writer(tmp_path)
|
Writer(tmp_path)
|
||||||
|
|||||||
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
def test_no_overwrite(tmp_path: Path):
|
def test_no_overwrite(tmp_path: Path) -> None:
|
||||||
"""Test writer does not touch existing files."""
|
"""Test writer does not touch existing files."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
path.write_text('foo')
|
path.write_text('foo')
|
||||||
@ -30,7 +30,7 @@ def test_no_overwrite(tmp_path: Path):
|
|||||||
writer.open()
|
writer.open()
|
||||||
|
|
||||||
|
|
||||||
def test_empty(tmp_path: Path):
|
def test_empty(tmp_path: Path) -> None:
|
||||||
"""Test empty bag."""
|
"""Test empty bag."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ def test_empty(tmp_path: Path):
|
|||||||
assert len(data) == 13 + 4096
|
assert len(data) == 13 + 4096
|
||||||
|
|
||||||
|
|
||||||
def test_add_connection(tmp_path: Path):
|
def test_add_connection(tmp_path: Path) -> None:
|
||||||
"""Test adding of connections."""
|
"""Test adding of connections."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
|
|
||||||
@ -88,7 +88,7 @@ def test_add_connection(tmp_path: Path):
|
|||||||
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
|
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
def test_write_errors(tmp_path: Path):
|
def test_write_errors(tmp_path: Path) -> None:
|
||||||
"""Test write errors."""
|
"""Test write errors."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ def test_write_errors(tmp_path: Path):
|
|||||||
path.unlink()
|
path.unlink()
|
||||||
|
|
||||||
|
|
||||||
def test_write_simple(tmp_path: Path):
|
def test_write_simple(tmp_path: Path) -> None:
|
||||||
"""Test writing of messages."""
|
"""Test writing of messages."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
|
|
||||||
@ -179,7 +179,7 @@ def test_write_simple(tmp_path: Path):
|
|||||||
path.unlink()
|
path.unlink()
|
||||||
|
|
||||||
|
|
||||||
def test_compression_errors(tmp_path: Path):
|
def test_compression_errors(tmp_path: Path) -> None:
|
||||||
"""Test compression modes."""
|
"""Test compression modes."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
with Writer(path) as writer, \
|
with Writer(path) as writer, \
|
||||||
@ -188,7 +188,7 @@ def test_compression_errors(tmp_path: Path):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||||
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
|
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
|
||||||
"""Test compression modes."""
|
"""Test compression modes."""
|
||||||
path = tmp_path / 'test.bag'
|
path = tmp_path / 'test.bag'
|
||||||
writer = Writer(path)
|
writer = Writer(path)
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from rosbags.rosbag2 import Reader
|
|||||||
from rosbags.serde import deserialize_cdr
|
from rosbags.serde import deserialize_cdr
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
|
||||||
class ReaderPy: # pylint: disable=too-few-public-methods
|
class ReaderPy: # pylint: disable=too-few-public-methods
|
||||||
@ -35,7 +35,7 @@ class ReaderPy: # pylint: disable=too-few-public-methods
|
|||||||
self.reader.open(soptions, coptions)
|
self.reader.open(soptions, coptions)
|
||||||
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
|
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
|
||||||
|
|
||||||
def messages(self):
|
def messages(self) -> Generator[tuple[str, str, int, bytes], None, None]:
|
||||||
"""Expose rosbag2 like generator behavior."""
|
"""Expose rosbag2 like generator behavior."""
|
||||||
while self.reader.has_next():
|
while self.reader.has_next():
|
||||||
topic, data, timestamp = self.reader.read_next()
|
topic, data, timestamp = self.reader.read_next()
|
||||||
@ -48,7 +48,7 @@ def deserialize_py(data: bytes, msgtype: str) -> Any:
|
|||||||
return deserialize_message(data, pytype)
|
return deserialize_message(data, pytype)
|
||||||
|
|
||||||
|
|
||||||
def compare_msg(lite: Any, native: Any):
|
def compare_msg(lite: Any, native: Any) -> None:
|
||||||
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
|
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -79,7 +79,7 @@ def compare_msg(lite: Any, native: Any):
|
|||||||
assert native_val == lite_val, f'{fieldname}: {native_val} != {lite_val}'
|
assert native_val == lite_val, f'{fieldname}: {native_val} != {lite_val}'
|
||||||
|
|
||||||
|
|
||||||
def compare(path: Path):
|
def compare(path: Path) -> None:
|
||||||
"""Compare raw and deserialized messages."""
|
"""Compare raw and deserialized messages."""
|
||||||
with Reader(path) as reader:
|
with Reader(path) as reader:
|
||||||
gens = (reader.messages(), ReaderPy(path).messages())
|
gens = (reader.messages(), ReaderPy(path).messages())
|
||||||
@ -100,7 +100,7 @@ def compare(path: Path):
|
|||||||
assert len(list(gens[1])) == 0
|
assert len(list(gens[1])) == 0
|
||||||
|
|
||||||
|
|
||||||
def read_deser_rosbag2_py(path: Path):
|
def read_deser_rosbag2_py(path: Path) -> None:
|
||||||
"""Read testbag with rosbag2_py."""
|
"""Read testbag with rosbag2_py."""
|
||||||
soptions = StorageOptions(str(path), 'sqlite3')
|
soptions = StorageOptions(str(path), 'sqlite3')
|
||||||
coptions = ConverterOptions('', '')
|
coptions = ConverterOptions('', '')
|
||||||
@ -115,14 +115,14 @@ def read_deser_rosbag2_py(path: Path):
|
|||||||
deserialize_message(rawdata, pytype)
|
deserialize_message(rawdata, pytype)
|
||||||
|
|
||||||
|
|
||||||
def read_deser_rosbag2(path: Path):
|
def read_deser_rosbag2(path: Path) -> None:
|
||||||
"""Read testbag with rosbag2lite."""
|
"""Read testbag with rosbag2lite."""
|
||||||
with Reader(path) as reader:
|
with Reader(path) as reader:
|
||||||
for connection, _, data in reader.messages():
|
for connection, _, data in reader.messages():
|
||||||
deserialize_cdr(data, connection.msgtype)
|
deserialize_cdr(data, connection.msgtype)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
"""Benchmark rosbag2 against rosbag2_py."""
|
"""Benchmark rosbag2 against rosbag2_py."""
|
||||||
path = Path(sys.argv[1])
|
path = Path(sys.argv[1])
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -25,7 +25,7 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
|
|||||||
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
|
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, List, Union
|
from typing import Any, Generator, List, Union
|
||||||
|
|
||||||
from rosbag.bag import _Connection_Info
|
from rosbag.bag import _Connection_Info
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ class Reader: # pylint: disable=too-few-public-methods
|
|||||||
self.reader.open(StorageOptions(path, 'sqlite3'), ConverterOptions('', ''))
|
self.reader.open(StorageOptions(path, 'sqlite3'), ConverterOptions('', ''))
|
||||||
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
|
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
|
||||||
|
|
||||||
def messages(self):
|
def messages(self) -> Generator[tuple[str, int, bytes], None, None]:
|
||||||
"""Expose rosbag2 like generator behavior."""
|
"""Expose rosbag2 like generator behavior."""
|
||||||
while self.reader.has_next():
|
while self.reader.has_next():
|
||||||
topic, data, timestamp = self.reader.read_next()
|
topic, data, timestamp = self.reader.read_next()
|
||||||
@ -47,7 +47,7 @@ class Reader: # pylint: disable=too-few-public-methods
|
|||||||
yield topic, timestamp, deserialize_message(data, pytype)
|
yield topic, timestamp, deserialize_message(data, pytype)
|
||||||
|
|
||||||
|
|
||||||
def fixup_ros1(conns: List[_Connection_Info]):
|
def fixup_ros1(conns: List[_Connection_Info]) -> None:
|
||||||
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
|
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -69,16 +69,13 @@ def fixup_ros1(conns: List[_Connection_Info]):
|
|||||||
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
|
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
|
||||||
|
|
||||||
|
|
||||||
def compare(ref: Any, msg: Any):
|
def compare(ref: Any, msg: Any) -> None:
|
||||||
"""Compare message to its reference.
|
"""Compare message to its reference.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ref: Reference ROS1 message.
|
ref: Reference ROS1 message.
|
||||||
msg: Converted ROS2 message.
|
msg: Converted ROS2 message.
|
||||||
|
|
||||||
Return:
|
|
||||||
True if messages are identical.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if hasattr(msg, 'get_fields_and_field_types'):
|
if hasattr(msg, 'get_fields_and_field_types'):
|
||||||
for name in msg.get_fields_and_field_types():
|
for name in msg.get_fields_and_field_types():
|
||||||
@ -107,7 +104,7 @@ def compare(ref: Any, msg: Any):
|
|||||||
assert ref == msg
|
assert ref == msg
|
||||||
|
|
||||||
|
|
||||||
def main_bag1_bag1(path1: Path, path2: Path):
|
def main_bag1_bag1(path1: Path, path2: Path) -> None:
|
||||||
"""Compare rosbag1 to rosbag1 message by message.
|
"""Compare rosbag1 to rosbag1 message by message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -132,7 +129,7 @@ def main_bag1_bag1(path1: Path, path2: Path):
|
|||||||
print('Bags are identical.') # noqa: T001
|
print('Bags are identical.') # noqa: T001
|
||||||
|
|
||||||
|
|
||||||
def main_bag1_bag2(path1: Path, path2: Path):
|
def main_bag1_bag2(path1: Path, path2: Path) -> None:
|
||||||
"""Compare rosbag1 to rosbag2 message by message.
|
"""Compare rosbag1 to rosbag2 message by message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user