Type generics and missing return types
This commit is contained in:
parent
ac704bd890
commit
52480e2bad
65
setup.cfg
65
setup.cfg
@ -30,6 +30,7 @@ classifiers =
|
||||
Programming Language :: Python :: 3 :: Only
|
||||
Programming Language :: Python :: 3.8
|
||||
Programming Language :: Python :: 3.9
|
||||
Programming Language :: Python :: 3.10
|
||||
Topic :: Scientific/Engineering
|
||||
Typing :: Typed
|
||||
project_urls =
|
||||
@ -109,7 +110,7 @@ avoid-escape = False
|
||||
docstring_convention = google
|
||||
docstring_style = google
|
||||
extend-exclude = venv*,.venv*
|
||||
extend-select =
|
||||
extend-select =
|
||||
# docstrings
|
||||
D204,
|
||||
D400,
|
||||
@ -119,8 +120,10 @@ extend-select =
|
||||
ignore =
|
||||
# do not require annotation of `self`
|
||||
ANN101,
|
||||
# allow line break before binary operator
|
||||
W503,
|
||||
# handled by B001
|
||||
E722,
|
||||
# allow line break after binary operator
|
||||
W504,
|
||||
max-line-length = 100
|
||||
strictness = long
|
||||
suppress-none-returning = True
|
||||
@ -134,10 +137,14 @@ multi_line_output = 3
|
||||
explicit_package_bases = True
|
||||
mypy_path = $MYPY_CONFIG_FILE_DIR/src
|
||||
namespace_packages = True
|
||||
strict = True
|
||||
|
||||
[mypy-ruamel]
|
||||
[mypy-lz4.frame]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-ruamel.yaml]
|
||||
implicit_reexport = True
|
||||
|
||||
[pydocstyle]
|
||||
convention = google
|
||||
add-select = D204,D400,D401,D404,D413
|
||||
@ -146,9 +153,37 @@ add-select = D204,D400,D401,D404,D413
|
||||
max-line-length = 100
|
||||
|
||||
[pylint.'MESSAGES CONTROL']
|
||||
enable = all
|
||||
disable =
|
||||
duplicate-code,
|
||||
ungrouped-imports,
|
||||
# isort (pylint FAQ)
|
||||
wrong-import-order,
|
||||
# mccabe (pylint FAQ)
|
||||
too-many-branches,
|
||||
# fixme
|
||||
fixme,
|
||||
# pep8-naming (pylint FAQ, keep: invalid-name)
|
||||
bad-classmethod-argument,
|
||||
bad-mcs-classmethod-argument,
|
||||
no-self-argument
|
||||
# pycodestyle (pylint FAQ)
|
||||
bad-indentation,
|
||||
bare-except,
|
||||
line-too-long,
|
||||
missing-final-newline,
|
||||
multiple-statements,
|
||||
trailing-whitespace,
|
||||
unnecessary-semicolon,
|
||||
unneeded-not,
|
||||
# pydocstyle (pylint FAQ)
|
||||
missing-class-docstring,
|
||||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
# pyflakes (pylint FAQ)
|
||||
undefined-variable,
|
||||
unused-import,
|
||||
unused-variable,
|
||||
|
||||
[yapf]
|
||||
based_on_style = google
|
||||
@ -159,15 +194,15 @@ indent_dictionary_value = false
|
||||
|
||||
[tool:pytest]
|
||||
addopts =
|
||||
-v
|
||||
--flake8
|
||||
--mypy
|
||||
--pylint
|
||||
--yapf
|
||||
--cov=src
|
||||
--cov-branch
|
||||
--cov-report=html
|
||||
--cov-report=term
|
||||
--no-cov-on-fail
|
||||
--junitxml=report.xml
|
||||
-v
|
||||
--flake8
|
||||
--mypy
|
||||
--pylint
|
||||
--yapf
|
||||
--cov=src
|
||||
--cov-branch
|
||||
--cov-report=html
|
||||
--cov-report=term
|
||||
--no-cov-on-fail
|
||||
--junitxml=report.xml
|
||||
junit_family=xunit2
|
||||
|
||||
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def pathtype(exists: bool = True) -> Callable:
|
||||
def pathtype(exists: bool = True) -> Callable[[str], Path]:
|
||||
"""Path argument for argparse.
|
||||
|
||||
Args:
|
||||
|
||||
@ -10,14 +10,15 @@ import os
|
||||
import re
|
||||
import struct
|
||||
from bz2 import decompress as bz2_decompress
|
||||
from collections import defaultdict
|
||||
from enum import Enum, IntEnum
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, NamedTuple
|
||||
|
||||
from lz4.frame import decompress as lz4_decompress # type: ignore
|
||||
from lz4.frame import decompress as lz4_decompress
|
||||
|
||||
from rosbags.typesys.msg import normalize_msgtype
|
||||
|
||||
@ -59,7 +60,7 @@ class Connection(NamedTuple):
|
||||
md5sum: str
|
||||
callerid: Optional[str]
|
||||
latching: Optional[int]
|
||||
indexes: list
|
||||
indexes: list[IndexData]
|
||||
|
||||
|
||||
class ChunkInfo(NamedTuple):
|
||||
@ -76,7 +77,7 @@ class Chunk(NamedTuple):
|
||||
|
||||
datasize: int
|
||||
datapos: int
|
||||
decompressor: Callable
|
||||
decompressor: Callable[[bytes], bytes]
|
||||
|
||||
|
||||
class TopicInfo(NamedTuple):
|
||||
@ -124,9 +125,9 @@ class IndexData(NamedTuple):
|
||||
return self.time != other[0]
|
||||
|
||||
|
||||
deserialize_uint8 = struct.Struct('<B').unpack
|
||||
deserialize_uint32 = struct.Struct('<L').unpack
|
||||
deserialize_uint64 = struct.Struct('<Q').unpack
|
||||
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
|
||||
deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack # type: ignore
|
||||
deserialize_uint64: Callable[[bytes], tuple[int]] = struct.Struct('<Q').unpack # type: ignore
|
||||
|
||||
|
||||
def deserialize_time(val: bytes) -> int:
|
||||
@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int:
|
||||
Deserialized value.
|
||||
|
||||
"""
|
||||
sec, nsec = struct.unpack('<LL', val)
|
||||
unpacked: tuple[int, int] = struct.unpack('<LL', val) # type: ignore
|
||||
sec, nsec = unpacked
|
||||
return sec * 10**9 + nsec
|
||||
|
||||
|
||||
class Header(dict):
|
||||
class Header(Dict[str, Any]):
|
||||
"""Record header."""
|
||||
|
||||
def get_uint8(self, name: str) -> int:
|
||||
@ -214,7 +216,9 @@ class Header(dict):
|
||||
|
||||
"""
|
||||
try:
|
||||
return self[name].decode()
|
||||
value = self[name]
|
||||
assert isinstance(value, bytes)
|
||||
return value.decode()
|
||||
except (KeyError, ValueError) as err:
|
||||
raise ReaderError(f'Could not read string field {name!r}.') from err
|
||||
|
||||
@ -237,7 +241,7 @@ class Header(dict):
|
||||
raise ReaderError(f'Could not read time field {name!r}.') from err
|
||||
|
||||
@classmethod
|
||||
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
||||
def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
||||
"""Read header from file handle.
|
||||
|
||||
Args:
|
||||
@ -362,10 +366,10 @@ class Reader:
|
||||
self.connections: dict[int, Connection] = {}
|
||||
self.chunk_infos: list[ChunkInfo] = []
|
||||
self.chunks: dict[int, Chunk] = {}
|
||||
self.current_chunk = (-1, BytesIO())
|
||||
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
|
||||
self.topics: dict[str, TopicInfo] = {}
|
||||
|
||||
def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||
"""Open rosbag and read metadata."""
|
||||
try:
|
||||
self.bio = self.path.open('rb')
|
||||
@ -409,24 +413,25 @@ class Reader:
|
||||
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
|
||||
|
||||
self.chunks = {}
|
||||
indexes: dict[int, list[list[IndexData]]] = defaultdict(list)
|
||||
for chunk_info in self.chunk_infos:
|
||||
self.bio.seek(chunk_info.pos)
|
||||
self.chunks[chunk_info.pos] = self.read_chunk()
|
||||
|
||||
for _ in range(len(chunk_info.connection_counts)):
|
||||
cid, index = self.read_index_data(chunk_info.pos)
|
||||
self.connections[cid].indexes.append(index)
|
||||
indexes[cid].append(index)
|
||||
|
||||
for connection in self.connections.values():
|
||||
connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
|
||||
for cid, connection in self.connections.items():
|
||||
connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time))
|
||||
assert connection.indexes
|
||||
|
||||
self.topics = {}
|
||||
for topic, connections in groupby(
|
||||
for topic, group in groupby(
|
||||
sorted(self.connections.values(), key=lambda x: x.topic),
|
||||
key=lambda x: x.topic,
|
||||
):
|
||||
connections = list(connections)
|
||||
connections = list(group)
|
||||
count = reduce(
|
||||
lambda x, y: x + y,
|
||||
(
|
||||
@ -446,7 +451,7 @@ class Reader:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag."""
|
||||
assert self.bio
|
||||
self.bio.close()
|
||||
@ -614,8 +619,8 @@ class Reader:
|
||||
|
||||
chunk_header = self.chunks[entry.chunk_pos]
|
||||
self.bio.seek(chunk_header.datapos)
|
||||
chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
||||
self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
|
||||
rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
||||
self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes))
|
||||
|
||||
chunk = self.current_chunk[1]
|
||||
chunk.seek(entry.offset)
|
||||
|
||||
@ -11,9 +11,9 @@ from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from lz4.frame import compress as lz4_compress # type: ignore
|
||||
from lz4.frame import compress as lz4_compress
|
||||
|
||||
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
||||
|
||||
@ -21,7 +21,7 @@ from .reader import Connection, RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
|
||||
from typing import BinaryIO, Callable, Literal, Optional, Type, Union
|
||||
|
||||
|
||||
class WriterError(Exception):
|
||||
@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes:
|
||||
return struct.pack('<LL', sec, nsec)
|
||||
|
||||
|
||||
class Header(dict):
|
||||
class Header(Dict[str, Any]):
|
||||
"""Record header."""
|
||||
|
||||
def set_uint32(self, name: str, value: int):
|
||||
def set_uint32(self, name: str, value: int) -> None:
|
||||
"""Set field to uint32 value.
|
||||
|
||||
Args:
|
||||
@ -70,7 +70,7 @@ class Header(dict):
|
||||
"""
|
||||
self[name] = serialize_uint32(value)
|
||||
|
||||
def set_uint64(self, name: str, value: int):
|
||||
def set_uint64(self, name: str, value: int) -> None:
|
||||
"""Set field to uint64 value.
|
||||
|
||||
Args:
|
||||
@ -80,7 +80,7 @@ class Header(dict):
|
||||
"""
|
||||
self[name] = serialize_uint64(value)
|
||||
|
||||
def set_string(self, name: str, value: str):
|
||||
def set_string(self, name: str, value: str) -> None:
|
||||
"""Set field to string value.
|
||||
|
||||
Args:
|
||||
@ -90,7 +90,7 @@ class Header(dict):
|
||||
"""
|
||||
self[name] = value.encode()
|
||||
|
||||
def set_time(self, name: str, value: int):
|
||||
def set_time(self, name: str, value: int) -> None:
|
||||
"""Set field to time value.
|
||||
|
||||
Args:
|
||||
@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
]
|
||||
self.chunk_threshold = 1 * (1 << 20)
|
||||
|
||||
def set_compression(self, fmt: CompressionFormat):
|
||||
def set_compression(self, fmt: CompressionFormat) -> None:
|
||||
"""Enable compression on rosbag1.
|
||||
|
||||
This function has to be called before opening.
|
||||
@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
self.compression_format = fmt.name.lower()
|
||||
|
||||
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
|
||||
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
|
||||
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
|
||||
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore
|
||||
self.compressor = {
|
||||
'bz2': bz2,
|
||||
'lz4': lz4,
|
||||
}[self.compression_format]
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open rosbag1 for writing."""
|
||||
try:
|
||||
self.bio = self.path.open('xb')
|
||||
except FileExistsError:
|
||||
raise WriterError(f'{self.path} exists already, not overwriting.') from None
|
||||
|
||||
assert self.bio
|
||||
self.bio.write(b'#ROSBAG V2.0\n')
|
||||
header = Header()
|
||||
header.set_uint64('index_pos', 0)
|
||||
@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.connections[connection.cid] = connection
|
||||
return connection
|
||||
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
|
||||
"""Write message to rosbag1.
|
||||
|
||||
Args:
|
||||
@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.write_chunk(chunk)
|
||||
|
||||
@staticmethod
|
||||
def write_connection(connection: Connection, bio: BytesIO):
|
||||
def write_connection(connection: Connection, bio: BinaryIO) -> None:
|
||||
"""Write connection record."""
|
||||
header = Header()
|
||||
header.set_uint32('conn', connection.cid)
|
||||
@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
header.set_string('latching', str(connection.latching))
|
||||
header.write(bio)
|
||||
|
||||
def write_chunk(self, chunk: WriteChunk):
|
||||
def write_chunk(self, chunk: WriteChunk) -> None:
|
||||
"""Write open chunk to file."""
|
||||
assert self.bio
|
||||
|
||||
@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
chunk.data.close()
|
||||
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag1 after writing.
|
||||
|
||||
Closes open chunks and writes index.
|
||||
|
||||
"""
|
||||
assert self.bio
|
||||
for chunk in self.chunks:
|
||||
if chunk.pos == -1:
|
||||
self.write_chunk(chunk)
|
||||
|
||||
@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import zstandard
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.error import YAMLError
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union
|
||||
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
@ -25,7 +26,7 @@ class ReaderError(Exception):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def decompress(path: Path, do_decompress: bool):
|
||||
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
|
||||
"""Transparent rosbag2 database decompression context.
|
||||
|
||||
This context manager will yield a path to the decompressed file contents.
|
||||
@ -119,12 +120,12 @@ class Reader:
|
||||
except KeyError as exc:
|
||||
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2."""
|
||||
# Future storage formats will require file handles.
|
||||
self.bio = True
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2."""
|
||||
# Future storage formats will require file handles.
|
||||
assert self.bio
|
||||
@ -133,12 +134,14 @@ class Reader:
|
||||
@property
|
||||
def duration(self) -> int:
|
||||
"""Duration in nanoseconds between earliest and latest messages."""
|
||||
return self.metadata['duration']['nanoseconds'] + 1
|
||||
nsecs: int = self.metadata['duration']['nanoseconds']
|
||||
return nsecs + 1
|
||||
|
||||
@property
|
||||
def start_time(self) -> int:
|
||||
"""Timestamp in nanoseconds of the earliest message."""
|
||||
return self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||
return nsecs
|
||||
|
||||
@property
|
||||
def end_time(self) -> int:
|
||||
@ -148,7 +151,8 @@ class Reader:
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""Total message count."""
|
||||
return self.metadata['message_count']
|
||||
count: int = self.metadata['message_count']
|
||||
return count
|
||||
|
||||
@property
|
||||
def compression_format(self) -> Optional[str]:
|
||||
@ -233,7 +237,7 @@ class Reader:
|
||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||
|
||||
cur.execute('SELECT name,id FROM topics')
|
||||
connmap: Dict[int, Connection] = {
|
||||
connmap: dict[int, Connection] = {
|
||||
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
|
||||
None) # type: ignore
|
||||
for row in cur
|
||||
|
||||
@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compression_format = ''
|
||||
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
||||
self.connections: dict[int, Connection] = {}
|
||||
self.conn = None
|
||||
self.conn: Optional[sqlite3.Connection] = None
|
||||
self.cursor: Optional[sqlite3.Cursor] = None
|
||||
|
||||
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
|
||||
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None:
|
||||
"""Enable compression on bag.
|
||||
|
||||
This function has to be called before opening.
|
||||
@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compression_format = fmt.name.lower()
|
||||
self.compressor = zstandard.ZstdCompressor()
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2 for writing.
|
||||
|
||||
Create base directory and open database connection.
|
||||
@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||
return connection
|
||||
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
|
||||
"""Write message to rosbag2.
|
||||
|
||||
Args:
|
||||
@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
)
|
||||
connection.count += 1
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2 after writing.
|
||||
|
||||
Closes open database transactions and writes metadata.yaml.
|
||||
|
||||
"""
|
||||
assert self.cursor
|
||||
assert self.conn
|
||||
self.cursor.close()
|
||||
self.cursor = None
|
||||
|
||||
@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.conn.close()
|
||||
|
||||
if self.compression_mode == 'file':
|
||||
assert self.compressor
|
||||
src = self.dbpath
|
||||
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
|
||||
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:
|
||||
|
||||
@ -19,10 +19,10 @@ from .typing import Field
|
||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable
|
||||
from .typing import CDRDeser, CDRSer, CDRSerSize
|
||||
|
||||
|
||||
def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
|
||||
def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
|
||||
"""Generate cdr size calculation function.
|
||||
|
||||
Args:
|
||||
@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
|
||||
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
|
||||
|
||||
|
||||
def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
||||
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
|
||||
"""Generate cdr serialization function.
|
||||
|
||||
Args:
|
||||
@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
||||
return compile_lines(lines).serialize_cdr # type: ignore
|
||||
|
||||
|
||||
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
||||
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||
"""Generate cdr deserialization function.
|
||||
|
||||
Args:
|
||||
|
||||
@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef:
|
||||
generate_serialize_cdr(fields, 'be'),
|
||||
generate_deserialize_cdr(fields, 'le'),
|
||||
generate_deserialize_cdr(fields, 'be'),
|
||||
generate_ros1_to_cdr(fields, typename, False),
|
||||
generate_ros1_to_cdr(fields, typename, True),
|
||||
generate_cdr_to_ros1(fields, typename, False),
|
||||
generate_cdr_to_ros1(fields, typename, True),
|
||||
generate_ros1_to_cdr(fields, typename, False), # type: ignore
|
||||
generate_ros1_to_cdr(fields, typename, True), # type: ignore
|
||||
generate_cdr_to_ros1(fields, typename, False), # type: ignore
|
||||
generate_cdr_to_ros1(fields, typename, True), # type: ignore
|
||||
)
|
||||
return MSGDEFCACHE[typename]
|
||||
|
||||
@ -18,10 +18,16 @@ from .typing import Field
|
||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable # pylint: disable=ungrouped-imports
|
||||
from typing import Union # pylint: disable=ungrouped-imports
|
||||
|
||||
from .typing import Bitcvt, BitcvtSize
|
||||
|
||||
|
||||
def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
|
||||
def generate_ros1_to_cdr(
|
||||
fields: list[Field],
|
||||
typename: str,
|
||||
copy: bool,
|
||||
) -> Union[Bitcvt, BitcvtSize]:
|
||||
"""Generate ROS1 to CDR conversion function.
|
||||
|
||||
Args:
|
||||
@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call
|
||||
aligned = anext
|
||||
|
||||
lines.append(' return ipos, opos')
|
||||
return getattr(compile_lines(lines), funcname)
|
||||
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||
|
||||
|
||||
def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
|
||||
def generate_cdr_to_ros1(
|
||||
fields: list[Field],
|
||||
typename: str,
|
||||
copy: bool,
|
||||
) -> Union[Bitcvt, BitcvtSize]:
|
||||
"""Generate CDR to ROS1 conversion function.
|
||||
|
||||
Args:
|
||||
@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call
|
||||
aligned = anext
|
||||
|
||||
lines.append(' return ipos, opos')
|
||||
return getattr(compile_lines(lines), funcname)
|
||||
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||
|
||||
@ -7,7 +7,14 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]]
|
||||
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
|
||||
|
||||
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
|
||||
CDRSer = Callable[[bytes, int, type], int]
|
||||
CDRSerSize = Callable[[int, type], int]
|
||||
|
||||
|
||||
class Descriptor(NamedTuple):
|
||||
@ -28,15 +35,15 @@ class Msgdef(NamedTuple):
|
||||
"""Metadata of a message."""
|
||||
|
||||
name: str
|
||||
fields: List[Field]
|
||||
fields: list[Field]
|
||||
cls: Any
|
||||
size_cdr: int
|
||||
getsize_cdr: Callable
|
||||
serialize_cdr_le: Callable
|
||||
serialize_cdr_be: Callable
|
||||
deserialize_cdr_le: Callable
|
||||
deserialize_cdr_be: Callable
|
||||
getsize_ros1_to_cdr: Callable
|
||||
ros1_to_cdr: Callable
|
||||
getsize_cdr_to_ros1: Callable
|
||||
cdr_to_ros1: Callable
|
||||
getsize_cdr: CDRSerSize
|
||||
serialize_cdr_le: CDRSer
|
||||
serialize_cdr_be: CDRSer
|
||||
deserialize_cdr_le: CDRDeser
|
||||
deserialize_cdr_be: CDRDeser
|
||||
getsize_ros1_to_cdr: BitcvtSize
|
||||
ros1_to_cdr: Bitcvt
|
||||
getsize_cdr_to_ros1: BitcvtSize
|
||||
cdr_to_ros1: Bitcvt
|
||||
|
||||
@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
|
||||
pos = rule.skip_ws(text, 0)
|
||||
npos, trees = rule.parse(text, pos)
|
||||
assert npos == len(text), f'Could not parse: {text!r}'
|
||||
return visitor.visit(trees)
|
||||
return visitor.visit(trees) # type: ignore
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
raise TypesysError(f'Could not parse: {text!r}') from err
|
||||
|
||||
@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
|
||||
RULES = parse_grammar(GRAMMAR_IDL)
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__()
|
||||
self.typedefs = {}
|
||||
self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {}
|
||||
|
||||
def visit_specification(self, children: Any) -> Typesdict:
|
||||
"""Process start symbol, return only children of modules."""
|
||||
|
||||
@ -50,7 +50,7 @@ class Rule:
|
||||
}
|
||||
return data
|
||||
|
||||
def parse(self, text: str, pos: int):
|
||||
def parse(self, text: str, pos: int) -> tuple[int, Any]:
|
||||
"""Apply rule at position."""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods
|
||||
|
||||
RULES: dict[str, Rule] = {}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
|
||||
def visit(self, tree: Any) -> Any:
|
||||
|
||||
@ -13,12 +13,14 @@ from . import types
|
||||
from .base import Nodetype, TypesysError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from .base import Typesdict
|
||||
|
||||
INTLIKE = re.compile('^u?(bool|int|float)')
|
||||
|
||||
|
||||
def get_typehint(desc: tuple) -> str:
|
||||
def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str:
|
||||
"""Get python type hint for field.
|
||||
|
||||
Args:
|
||||
@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str:
|
||||
|
||||
"""
|
||||
if desc[0] == Nodetype.BASE:
|
||||
if match := INTLIKE.match(desc[1]):
|
||||
if match := INTLIKE.match(desc[1]): # type: ignore
|
||||
return match.group(1)
|
||||
return 'str'
|
||||
|
||||
if desc[0] == Nodetype.NAME:
|
||||
assert isinstance(desc[1], str)
|
||||
return desc[1].replace('/', '__')
|
||||
|
||||
sub = desc[1][0]
|
||||
if INTLIKE.match(sub[1]):
|
||||
typ = 'bool8' if sub[1] == 'bool' else sub[1]
|
||||
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
|
||||
return f'list[{get_typehint(sub)}]'
|
||||
return f'list[{get_typehint(sub)}]' # type: ignore
|
||||
|
||||
|
||||
def generate_python_code(typs: Typesdict) -> str:
|
||||
@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str:
|
||||
'',
|
||||
]
|
||||
|
||||
def get_ftype(ftype: tuple) -> tuple:
|
||||
def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]:
|
||||
if ftype[0] <= 2:
|
||||
return int(ftype[0]), ftype[1]
|
||||
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
|
||||
|
||||
@ -9,6 +9,7 @@ from struct import Struct, pack_into, unpack_from
|
||||
from typing import TYPE_CHECKING, Dict, List, Union, cast
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from rosbags.serde.messages import SerdeError, get_msgdef
|
||||
from rosbags.serde.typing import Msgdef
|
||||
@ -116,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
|
||||
|
||||
size = SIZEMAP[desc.args]
|
||||
pos = (pos + size - 1) & -size
|
||||
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
|
||||
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
|
||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||
ndarr = ndarr.byteswap() # no inplace on readonly array
|
||||
return ndarr, pos + num * SIZEMAP[desc.args]
|
||||
@ -278,7 +279,7 @@ def serialize_array(
|
||||
size = SIZEMAP[desc.args]
|
||||
pos = (pos + size - 1) & -size
|
||||
size *= len(val)
|
||||
val = cast(numpy.ndarray, val)
|
||||
val = cast(NDArray[numpy.int_], val)
|
||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||
val = val.byteswap() # no inplace on readonly array
|
||||
rawdata[pos:pos + size] = memoryview(val.tobytes())
|
||||
|
||||
@ -15,7 +15,7 @@ from rosbags.rosbag1 import ReaderError
|
||||
from rosbags.rosbag2 import WriterError
|
||||
|
||||
|
||||
def test_cliwrapper(tmp_path: Path):
|
||||
def test_cliwrapper(tmp_path: Path) -> None:
|
||||
"""Test cli wrapper."""
|
||||
(tmp_path / 'subdir').mkdir()
|
||||
(tmp_path / 'ros1.bag').write_text('')
|
||||
@ -62,7 +62,7 @@ def test_cliwrapper(tmp_path: Path):
|
||||
mock_print.assert_called_with('ERROR: exc')
|
||||
|
||||
|
||||
def test_convert(tmp_path: Path):
|
||||
def test_convert(tmp_path: Path) -> None:
|
||||
"""Test conversion function."""
|
||||
(tmp_path / 'subdir').mkdir()
|
||||
(tmp_path / 'foo.bag').write_text('')
|
||||
|
||||
@ -142,13 +142,13 @@ module test_msgs {
|
||||
"""
|
||||
|
||||
|
||||
def test_parse_empty_msg():
|
||||
def test_parse_empty_msg() -> None:
|
||||
"""Test msg parser with empty message."""
|
||||
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
|
||||
assert ret == {'std_msgs/msg/Empty': ([], [])}
|
||||
|
||||
|
||||
def test_parse_bounds_msg():
|
||||
def test_parse_bounds_msg() -> None:
|
||||
"""Test msg parser."""
|
||||
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
|
||||
assert ret == {
|
||||
@ -168,7 +168,7 @@ def test_parse_bounds_msg():
|
||||
}
|
||||
|
||||
|
||||
def test_parse_defaults_msg():
|
||||
def test_parse_defaults_msg() -> None:
|
||||
"""Test msg parser."""
|
||||
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
|
||||
assert ret == {
|
||||
@ -188,7 +188,7 @@ def test_parse_defaults_msg():
|
||||
}
|
||||
|
||||
|
||||
def test_parse_msg():
|
||||
def test_parse_msg() -> None:
|
||||
"""Test msg parser."""
|
||||
with pytest.raises(TypesysError, match='Could not parse'):
|
||||
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
|
||||
@ -208,7 +208,7 @@ def test_parse_msg():
|
||||
assert fields[6][1][0] == Nodetype.ARRAY
|
||||
|
||||
|
||||
def test_parse_multi_msg():
|
||||
def test_parse_multi_msg() -> None:
|
||||
"""Test multi msg parser."""
|
||||
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
|
||||
assert len(ret) == 3
|
||||
@ -223,7 +223,7 @@ def test_parse_multi_msg():
|
||||
assert consts == [('static', 'uint32', 42)]
|
||||
|
||||
|
||||
def test_parse_cstring_confusion():
|
||||
def test_parse_cstring_confusion() -> None:
|
||||
"""Test if msg separator is confused with const string."""
|
||||
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
|
||||
assert len(ret) == 2
|
||||
@ -235,7 +235,7 @@ def test_parse_cstring_confusion():
|
||||
assert fields[1][1][1] == 'string'
|
||||
|
||||
|
||||
def test_parse_relative_siblings_msg():
|
||||
def test_parse_relative_siblings_msg() -> None:
|
||||
"""Test relative siblings with msg parser."""
|
||||
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
|
||||
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
||||
@ -246,7 +246,7 @@ def test_parse_relative_siblings_msg():
|
||||
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
|
||||
|
||||
|
||||
def test_parse_idl():
|
||||
def test_parse_idl() -> None:
|
||||
"""Test idl parser."""
|
||||
ret = get_types_from_idl(IDL_LANG)
|
||||
assert ret == {}
|
||||
@ -267,21 +267,21 @@ def test_parse_idl():
|
||||
assert fields[6][1][0] == Nodetype.ARRAY
|
||||
|
||||
|
||||
def test_register_types():
|
||||
def test_register_types() -> None:
|
||||
"""Test type registeration."""
|
||||
assert 'foo' not in FIELDDEFS
|
||||
register_types({})
|
||||
register_types({'foo': [[], [('b', (1, 'bool'))]]})
|
||||
register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
|
||||
assert 'foo' in FIELDDEFS
|
||||
|
||||
register_types({'std_msgs/msg/Header': [[], []]})
|
||||
register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
|
||||
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
|
||||
|
||||
with pytest.raises(TypesysError, match='different definition'):
|
||||
register_types({'foo': [[], [('x', (1, 'bool'))]]})
|
||||
register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
|
||||
|
||||
|
||||
def test_generate_msgdef():
|
||||
def test_generate_msgdef() -> None:
|
||||
"""Test message definition generator."""
|
||||
res = generate_msgdef('std_msgs/msg/Header')
|
||||
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')
|
||||
|
||||
@ -117,7 +117,7 @@ def bag(request: SubRequest, tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_reader(bag: Path):
|
||||
def test_reader(bag: Path) -> None:
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
with Reader(bag) as reader:
|
||||
assert reader.duration == 43
|
||||
@ -151,7 +151,7 @@ def test_reader(bag: Path):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_message_filters(bag: Path):
|
||||
def test_message_filters(bag: Path) -> None:
|
||||
"""Test reader filters messages."""
|
||||
with Reader(bag) as reader:
|
||||
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
|
||||
@ -188,14 +188,14 @@ def test_message_filters(bag: Path):
|
||||
next(gen)
|
||||
|
||||
|
||||
def test_user_errors(bag: Path):
|
||||
def test_user_errors(bag: Path) -> None:
|
||||
"""Test user errors."""
|
||||
reader = Reader(bag)
|
||||
with pytest.raises(ReaderError, match='Rosbag is not open'):
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path):
|
||||
def test_failure_cases(tmp_path: Path) -> None:
|
||||
"""Test bags with broken fs layout."""
|
||||
with pytest.raises(ReaderError, match='not read metadata'):
|
||||
Reader(tmp_path)
|
||||
|
||||
@ -2,8 +2,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Reader tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from struct import pack
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -11,8 +14,12 @@ import pytest
|
||||
from rosbags.rosbag1 import Reader, ReaderError
|
||||
from rosbags.rosbag1.reader import IndexData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
def ser(data):
|
||||
|
||||
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
|
||||
"""Serialize record header."""
|
||||
if isinstance(data, dict):
|
||||
fields = []
|
||||
@ -23,7 +30,7 @@ def ser(data):
|
||||
return pack('<L', len(data)) + data
|
||||
|
||||
|
||||
def create_default_header():
|
||||
def create_default_header() -> dict[str, bytes]:
|
||||
"""Create empty rosbag header."""
|
||||
return {
|
||||
'op': b'\x03',
|
||||
@ -32,7 +39,11 @@ def create_default_header():
|
||||
}
|
||||
|
||||
|
||||
def create_connection(cid=1, topic=0, typ=0):
|
||||
def create_connection(
|
||||
cid: int = 1,
|
||||
topic: int = 0,
|
||||
typ: int = 0,
|
||||
) -> tuple[dict[str, bytes], dict[str, bytes]]:
|
||||
"""Create connection record."""
|
||||
return {
|
||||
'op': b'\x07',
|
||||
@ -45,7 +56,11 @@ def create_connection(cid=1, topic=0, typ=0):
|
||||
}
|
||||
|
||||
|
||||
def create_message(cid=1, time=0, msg=0):
|
||||
def create_message(
|
||||
cid: int = 1,
|
||||
time: int = 0,
|
||||
msg: int = 0,
|
||||
) -> tuple[dict[str, Union[bytes, int]], bytes]:
|
||||
"""Create message record."""
|
||||
return {
|
||||
'op': b'\x02',
|
||||
@ -54,7 +69,12 @@ def create_message(cid=1, time=0, msg=0):
|
||||
}, f'MSGCONTENT{msg}'.encode()
|
||||
|
||||
|
||||
def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements
|
||||
def write_bag( # pylint: disable=too-many-locals,too-many-statements
|
||||
|
||||
bag: Path,
|
||||
header: dict[str, bytes],
|
||||
chunks: Sequence[Any] = (),
|
||||
) -> None:
|
||||
"""Write bag file."""
|
||||
magic = b'#ROSBAG V2.0\n'
|
||||
|
||||
@ -70,7 +90,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
||||
chunk_bytes = b''
|
||||
start_time = 2**32 - 1
|
||||
end_time = 0
|
||||
counts = defaultdict(int)
|
||||
counts: dict[int, int] = defaultdict(int)
|
||||
index = {}
|
||||
offset = 0
|
||||
|
||||
@ -95,8 +115,8 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
||||
'count': 0,
|
||||
'msgs': b'',
|
||||
}
|
||||
index[conn]['count'] += 1
|
||||
index[conn]['msgs'] += pack('<LLL', time, 0, offset)
|
||||
index[conn]['count'] += 1 # type: ignore
|
||||
index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
|
||||
|
||||
add = ser(head) + ser(data)
|
||||
chunk_bytes += add
|
||||
@ -140,19 +160,19 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
|
||||
if 'index_pos' not in header:
|
||||
header['index_pos'] = pack('<Q', pos)
|
||||
|
||||
header = ser(header)
|
||||
header += b'\x20' * (4096 - len(header))
|
||||
header_bytes = ser(header)
|
||||
header_bytes += b'\x20' * (4096 - len(header_bytes))
|
||||
|
||||
bag.write_bytes(b''.join([
|
||||
magic,
|
||||
header,
|
||||
header_bytes,
|
||||
chunks_bytes,
|
||||
connections,
|
||||
chunkinfos,
|
||||
]))
|
||||
|
||||
|
||||
def test_indexdata():
|
||||
def test_indexdata() -> None:
|
||||
"""Test IndexData sort sorder."""
|
||||
x42_1_0 = IndexData(42, 1, 0)
|
||||
x42_2_0 = IndexData(42, 2, 0)
|
||||
@ -175,7 +195,7 @@ def test_indexdata():
|
||||
assert not x42_1_0 > x43_3_0
|
||||
|
||||
|
||||
def test_reader(tmp_path): # pylint: disable=too-many-statements
|
||||
def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
|
||||
"""Test reader and deserializer on simple bag."""
|
||||
# empty bag
|
||||
bag = tmp_path / 'test.bag'
|
||||
@ -268,7 +288,7 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
|
||||
assert msgs[0][2] == b'MSGCONTENT5'
|
||||
|
||||
|
||||
def test_user_errors(tmp_path):
|
||||
def test_user_errors(tmp_path: Path) -> None:
|
||||
"""Test user errors."""
|
||||
bag = tmp_path / 'test.bag'
|
||||
write_bag(bag, create_default_header(), chunks=[[
|
||||
@ -281,7 +301,7 @@ def test_user_errors(tmp_path):
|
||||
next(reader.messages())
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path): # pylint: disable=too-many-statements
|
||||
def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
|
||||
"""Test failure cases."""
|
||||
bag = tmp_path / 'test.bag'
|
||||
with pytest.raises(ReaderError, match='does not exist'):
|
||||
|
||||
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
|
||||
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
|
||||
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
|
||||
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
|
||||
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
|
||||
"""Test full data roundtrip."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
|
||||
@ -18,7 +18,7 @@ from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__
|
||||
from .cdr import deserialize, serialize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Tuple, Union
|
||||
from typing import Any, Generator, Union
|
||||
|
||||
MSG_POLY = (
|
||||
(
|
||||
@ -169,7 +169,7 @@ test_msgs/msg/dynamic_s_64[] seq_msg_ds6
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _comparable():
|
||||
def _comparable() -> Generator[None, None, None]:
|
||||
"""Make messages containing numpy arrays comparable.
|
||||
|
||||
Notes:
|
||||
@ -180,7 +180,7 @@ def _comparable():
|
||||
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
|
||||
lhs = self._mock_wraps # pylint: disable=protected-access
|
||||
rhs = getattr(other, '_mock_wraps', other)
|
||||
return (lhs == rhs).all()
|
||||
return (lhs == rhs).all() # type: ignore
|
||||
|
||||
class CNDArray(MagicMock):
|
||||
"""Mock ndarray."""
|
||||
@ -194,14 +194,14 @@ def _comparable():
|
||||
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
|
||||
|
||||
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
|
||||
return CNDArray(wraps=frombuffer(*args, **kwargs))
|
||||
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
|
||||
|
||||
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize('message', MESSAGES)
|
||||
def test_serde(message: Tuple[bytes, str, bool]):
|
||||
def test_serde(message: tuple[bytes, str, bool]) -> None:
|
||||
"""Test serialization deserialization roundtrip."""
|
||||
rawdata, typ, is_little = message
|
||||
|
||||
@ -213,7 +213,7 @@ def test_serde(message: Tuple[bytes, str, bool]):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_deserializer():
|
||||
def test_deserializer() -> None:
|
||||
"""Test deserializer."""
|
||||
msg = deserialize_cdr(*MSG_POLY[:2])
|
||||
assert msg == deserialize(*MSG_POLY[:2])
|
||||
@ -233,7 +233,8 @@ def test_deserializer():
|
||||
assert msg.header.frame_id == 'foo42'
|
||||
field = msg.magnetic_field
|
||||
assert (field.x, field.y, field.z) == (128., 128., 128.)
|
||||
assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all()
|
||||
diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) # type: ignore
|
||||
assert (diag == [1., 1., 1.]).all()
|
||||
|
||||
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
|
||||
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
|
||||
@ -241,7 +242,7 @@ def test_deserializer():
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_serializer():
|
||||
def test_serializer() -> None:
|
||||
"""Test serializer."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
@ -268,7 +269,7 @@ def test_serializer():
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_serializer_errors():
|
||||
def test_serializer_errors() -> None:
|
||||
"""Test seralizer with broken messages."""
|
||||
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
@ -286,7 +287,7 @@ def test_serializer_errors():
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_custom_type():
|
||||
def test_custom_type() -> None:
|
||||
"""Test custom type."""
|
||||
cname = 'test_msgs/msg/custom'
|
||||
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
|
||||
@ -362,7 +363,7 @@ def test_custom_type():
|
||||
assert res == msg
|
||||
|
||||
|
||||
def test_ros1_to_cdr():
|
||||
def test_ros1_to_cdr() -> None:
|
||||
"""Test ROS1 to CDR conversion."""
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
@ -385,7 +386,7 @@ def test_ros1_to_cdr():
|
||||
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr
|
||||
|
||||
|
||||
def test_cdr_to_ros1():
|
||||
def test_cdr_to_ros1() -> None:
|
||||
"""Test CDR to ROS1 conversion."""
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
|
||||
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_writer(tmp_path: Path):
|
||||
def test_writer(tmp_path: Path) -> None:
|
||||
"""Test Writer."""
|
||||
path = (tmp_path / 'rosbag2')
|
||||
with Writer(path) as bag:
|
||||
@ -60,7 +60,7 @@ def test_writer(tmp_path: Path):
|
||||
assert size > (path / 'compress_message.db3').stat().st_size
|
||||
|
||||
|
||||
def test_failure_cases(tmp_path: Path):
|
||||
def test_failure_cases(tmp_path: Path) -> None:
|
||||
"""Test writer failure cases."""
|
||||
with pytest.raises(WriterError, match='exists'):
|
||||
Writer(tmp_path)
|
||||
|
||||
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def test_no_overwrite(tmp_path: Path):
|
||||
def test_no_overwrite(tmp_path: Path) -> None:
|
||||
"""Test writer does not touch existing files."""
|
||||
path = tmp_path / 'test.bag'
|
||||
path.write_text('foo')
|
||||
@ -30,7 +30,7 @@ def test_no_overwrite(tmp_path: Path):
|
||||
writer.open()
|
||||
|
||||
|
||||
def test_empty(tmp_path: Path):
|
||||
def test_empty(tmp_path: Path) -> None:
|
||||
"""Test empty bag."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
@ -40,7 +40,7 @@ def test_empty(tmp_path: Path):
|
||||
assert len(data) == 13 + 4096
|
||||
|
||||
|
||||
def test_add_connection(tmp_path: Path):
|
||||
def test_add_connection(tmp_path: Path) -> None:
|
||||
"""Test adding of connections."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
@ -88,7 +88,7 @@ def test_add_connection(tmp_path: Path):
|
||||
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
|
||||
|
||||
|
||||
def test_write_errors(tmp_path: Path):
|
||||
def test_write_errors(tmp_path: Path) -> None:
|
||||
"""Test write errors."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
@ -101,7 +101,7 @@ def test_write_errors(tmp_path: Path):
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_write_simple(tmp_path: Path):
|
||||
def test_write_simple(tmp_path: Path) -> None:
|
||||
"""Test writing of messages."""
|
||||
path = tmp_path / 'test.bag'
|
||||
|
||||
@ -179,7 +179,7 @@ def test_write_simple(tmp_path: Path):
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_compression_errors(tmp_path: Path):
|
||||
def test_compression_errors(tmp_path: Path) -> None:
|
||||
"""Test compression modes."""
|
||||
path = tmp_path / 'test.bag'
|
||||
with Writer(path) as writer, \
|
||||
@ -188,7 +188,7 @@ def test_compression_errors(tmp_path: Path):
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
|
||||
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
|
||||
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
|
||||
"""Test compression modes."""
|
||||
path = tmp_path / 'test.bag'
|
||||
writer = Writer(path)
|
||||
|
||||
@ -21,7 +21,7 @@ from rosbags.rosbag2 import Reader
|
||||
from rosbags.serde import deserialize_cdr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from typing import Any, Generator
|
||||
|
||||
|
||||
class ReaderPy: # pylint: disable=too-few-public-methods
|
||||
@ -35,7 +35,7 @@ class ReaderPy: # pylint: disable=too-few-public-methods
|
||||
self.reader.open(soptions, coptions)
|
||||
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
|
||||
|
||||
def messages(self):
|
||||
def messages(self) -> Generator[tuple[str, str, int, bytes], None, None]:
|
||||
"""Expose rosbag2 like generator behavior."""
|
||||
while self.reader.has_next():
|
||||
topic, data, timestamp = self.reader.read_next()
|
||||
@ -48,7 +48,7 @@ def deserialize_py(data: bytes, msgtype: str) -> Any:
|
||||
return deserialize_message(data, pytype)
|
||||
|
||||
|
||||
def compare_msg(lite: Any, native: Any):
|
||||
def compare_msg(lite: Any, native: Any) -> None:
|
||||
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
|
||||
|
||||
Args:
|
||||
@ -79,7 +79,7 @@ def compare_msg(lite: Any, native: Any):
|
||||
assert native_val == lite_val, f'{fieldname}: {native_val} != {lite_val}'
|
||||
|
||||
|
||||
def compare(path: Path):
|
||||
def compare(path: Path) -> None:
|
||||
"""Compare raw and deserialized messages."""
|
||||
with Reader(path) as reader:
|
||||
gens = (reader.messages(), ReaderPy(path).messages())
|
||||
@ -100,7 +100,7 @@ def compare(path: Path):
|
||||
assert len(list(gens[1])) == 0
|
||||
|
||||
|
||||
def read_deser_rosbag2_py(path: Path):
|
||||
def read_deser_rosbag2_py(path: Path) -> None:
|
||||
"""Read testbag with rosbag2_py."""
|
||||
soptions = StorageOptions(str(path), 'sqlite3')
|
||||
coptions = ConverterOptions('', '')
|
||||
@ -115,14 +115,14 @@ def read_deser_rosbag2_py(path: Path):
|
||||
deserialize_message(rawdata, pytype)
|
||||
|
||||
|
||||
def read_deser_rosbag2(path: Path):
|
||||
def read_deser_rosbag2(path: Path) -> None:
|
||||
"""Read testbag with rosbag2lite."""
|
||||
with Reader(path) as reader:
|
||||
for connection, _, data in reader.messages():
|
||||
deserialize_cdr(data, connection.msgtype)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""Benchmark rosbag2 against rosbag2_py."""
|
||||
path = Path(sys.argv[1])
|
||||
try:
|
||||
|
||||
@ -25,7 +25,7 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
|
||||
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, Generator, List, Union
|
||||
|
||||
from rosbag.bag import _Connection_Info
|
||||
|
||||
@ -39,7 +39,7 @@ class Reader: # pylint: disable=too-few-public-methods
|
||||
self.reader.open(StorageOptions(path, 'sqlite3'), ConverterOptions('', ''))
|
||||
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
|
||||
|
||||
def messages(self):
|
||||
def messages(self) -> Generator[tuple[str, int, bytes], None, None]:
|
||||
"""Expose rosbag2 like generator behavior."""
|
||||
while self.reader.has_next():
|
||||
topic, data, timestamp = self.reader.read_next()
|
||||
@ -47,7 +47,7 @@ class Reader: # pylint: disable=too-few-public-methods
|
||||
yield topic, timestamp, deserialize_message(data, pytype)
|
||||
|
||||
|
||||
def fixup_ros1(conns: List[_Connection_Info]):
|
||||
def fixup_ros1(conns: List[_Connection_Info]) -> None:
|
||||
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
|
||||
|
||||
Args:
|
||||
@ -69,16 +69,13 @@ def fixup_ros1(conns: List[_Connection_Info]):
|
||||
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
|
||||
|
||||
|
||||
def compare(ref: Any, msg: Any):
|
||||
def compare(ref: Any, msg: Any) -> None:
|
||||
"""Compare message to its reference.
|
||||
|
||||
Args:
|
||||
ref: Reference ROS1 message.
|
||||
msg: Converted ROS2 message.
|
||||
|
||||
Return:
|
||||
True if messages are identical.
|
||||
|
||||
"""
|
||||
if hasattr(msg, 'get_fields_and_field_types'):
|
||||
for name in msg.get_fields_and_field_types():
|
||||
@ -107,7 +104,7 @@ def compare(ref: Any, msg: Any):
|
||||
assert ref == msg
|
||||
|
||||
|
||||
def main_bag1_bag1(path1: Path, path2: Path):
|
||||
def main_bag1_bag1(path1: Path, path2: Path) -> None:
|
||||
"""Compare rosbag1 to rosbag1 message by message.
|
||||
|
||||
Args:
|
||||
@ -132,7 +129,7 @@ def main_bag1_bag1(path1: Path, path2: Path):
|
||||
print('Bags are identical.') # noqa: T001
|
||||
|
||||
|
||||
def main_bag1_bag2(path1: Path, path2: Path):
|
||||
def main_bag1_bag2(path1: Path, path2: Path) -> None:
|
||||
"""Compare rosbag1 to rosbag2 message by message.
|
||||
|
||||
Args:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user