Type generics and missing return types

Marko Durkovic 2021-11-25 14:26:17 +01:00
parent ac704bd890
commit 52480e2bad
26 changed files with 263 additions and 175 deletions
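The change is the same across all 26 files: bare generics such as Callable and dict gain explicit type parameters, procedures that return nothing are annotated -> None, and attributes that may be None are narrowed before use, so the package passes mypy with strict = True. A minimal sketch of that pattern (illustrative names only, not code from this commit):

# Illustrative only -- not code from this commit.
from __future__ import annotations

from typing import Any, Callable, Dict


class Header(Dict[str, Any]):
    """Record header keyed by field name."""

    def set_string(self, name: str, value: str) -> None:  # before: def set_string(self, name, value):
        """Set field to a string value."""
        self[name] = value.encode()


# A bare Callable tells the type checker nothing; spell out the signature.
Decompressor = Callable[[bytes], bytes]

identity: Decompressor = lambda data: data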

View File

@@ -30,6 +30,7 @@ classifiers =
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
Topic :: Scientific/Engineering
Typing :: Typed
project_urls =
@@ -119,8 +120,10 @@ extend-select =
ignore =
# do not require annotation of `self`
ANN101,
- # allow line break before binary operator
- W503,
+ # handled by B001
+ E722,
+ # allow line break after binary operator
+ W504,
max-line-length = 100
strictness = long
suppress-none-returning = True
@@ -134,10 +137,14 @@ multi_line_output = 3
explicit_package_bases = True
mypy_path = $MYPY_CONFIG_FILE_DIR/src
namespace_packages = True
+ strict = True
- [mypy-ruamel]
+ [mypy-lz4.frame]
ignore_missing_imports = True
+ [mypy-ruamel.yaml]
+ implicit_reexport = True
[pydocstyle]
convention = google
add-select = D204,D400,D401,D404,D413
@@ -146,9 +153,37 @@ add-select = D204,D400,D401,D404,D413
max-line-length = 100
[pylint.'MESSAGES CONTROL']
+ enable = all
disable =
duplicate-code,
ungrouped-imports,
+ # isort (pylint FAQ)
+ wrong-import-order,
+ # mccabe (pylint FAQ)
+ too-many-branches,
+ # fixme
+ fixme,
+ # pep8-naming (pylint FAQ, keep: invalid-name)
+ bad-classmethod-argument,
+ bad-mcs-classmethod-argument,
+ no-self-argument
+ # pycodestyle (pylint FAQ)
+ bad-indentation,
+ bare-except,
+ line-too-long,
+ missing-final-newline,
+ multiple-statements,
+ trailing-whitespace,
+ unnecessary-semicolon,
+ unneeded-not,
+ # pydocstyle (pylint FAQ)
+ missing-class-docstring,
+ missing-function-docstring,
+ missing-module-docstring,
+ # pyflakes (pylint FAQ)
+ undefined-variable,
+ unused-import,
+ unused-variable,
[yapf]
based_on_style = google

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from typing import Callable
- def pathtype(exists: bool = True) -> Callable:
+ def pathtype(exists: bool = True) -> Callable[[str], Path]:
"""Path argument for argparse.
Args:

View File

@@ -10,14 +10,15 @@ import os
import re
import struct
from bz2 import decompress as bz2_decompress
+ from collections import defaultdict
from enum import Enum, IntEnum
from functools import reduce
from io import BytesIO
from itertools import groupby
from pathlib import Path
- from typing import TYPE_CHECKING, NamedTuple
+ from typing import TYPE_CHECKING, Any, Dict, NamedTuple
- from lz4.frame import decompress as lz4_decompress # type: ignore
+ from lz4.frame import decompress as lz4_decompress
from rosbags.typesys.msg import normalize_msgtype
@@ -59,7 +60,7 @@ class Connection(NamedTuple):
md5sum: str
callerid: Optional[str]
latching: Optional[int]
- indexes: list
+ indexes: list[IndexData]
class ChunkInfo(NamedTuple):
@@ -76,7 +77,7 @@ class Chunk(NamedTuple):
datasize: int
datapos: int
- decompressor: Callable
+ decompressor: Callable[[bytes], bytes]
class TopicInfo(NamedTuple):
@@ -124,9 +125,9 @@ class IndexData(NamedTuple):
return self.time != other[0]
- deserialize_uint8 = struct.Struct('<B').unpack
- deserialize_uint32 = struct.Struct('<L').unpack
- deserialize_uint64 = struct.Struct('<Q').unpack
+ deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
+ deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack # type: ignore
+ deserialize_uint64: Callable[[bytes], tuple[int]] = struct.Struct('<Q').unpack # type: ignore
def deserialize_time(val: bytes) -> int:
@@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int:
Deserialized value.
"""
- sec, nsec = struct.unpack('<LL', val)
+ unpacked: tuple[int, int] = struct.unpack('<LL', val) # type: ignore
+ sec, nsec = unpacked
return sec * 10**9 + nsec
- class Header(dict):
+ class Header(Dict[str, Any]):
"""Record header."""
def get_uint8(self, name: str) -> int:
@@ -214,7 +216,9 @@ class Header(dict):
"""
try:
- return self[name].decode()
+ value = self[name]
+ assert isinstance(value, bytes)
+ return value.decode()
except (KeyError, ValueError) as err:
raise ReaderError(f'Could not read string field {name!r}.') from err
@@ -237,7 +241,7 @@ class Header(dict):
raise ReaderError(f'Could not read time field {name!r}.') from err
@classmethod
- def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
+ def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
"""Read header from file handle.
Args:
@@ -362,10 +366,10 @@ class Reader:
self.connections: dict[int, Connection] = {}
self.chunk_infos: list[ChunkInfo] = []
self.chunks: dict[int, Chunk] = {}
- self.current_chunk = (-1, BytesIO())
+ self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
self.topics: dict[str, TopicInfo] = {}
- def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
+ def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
"""Open rosbag and read metadata."""
try:
self.bio = self.path.open('rb')
@@ -409,24 +413,25 @@ class Reader:
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
self.chunks = {}
+ indexes: dict[int, list[list[IndexData]]] = defaultdict(list)
for chunk_info in self.chunk_infos:
self.bio.seek(chunk_info.pos)
self.chunks[chunk_info.pos] = self.read_chunk()
for _ in range(len(chunk_info.connection_counts)):
cid, index = self.read_index_data(chunk_info.pos)
- self.connections[cid].indexes.append(index)
+ indexes[cid].append(index)
- for connection in self.connections.values():
- connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
+ for cid, connection in self.connections.items():
+ connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time))
assert connection.indexes
self.topics = {}
- for topic, connections in groupby(
+ for topic, group in groupby(
sorted(self.connections.values(), key=lambda x: x.topic),
key=lambda x: x.topic,
):
- connections = list(connections)
+ connections = list(group)
count = reduce(
lambda x, y: x + y,
(
@@ -446,7 +451,7 @@ class Reader:
self.close()
raise
- def close(self):
+ def close(self) -> None:
"""Close rosbag."""
assert self.bio
self.bio.close()
@@ -614,8 +619,8 @@ class Reader:
chunk_header = self.chunks[entry.chunk_pos]
self.bio.seek(chunk_header.datapos)
- chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
- self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
+ rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
+ self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes))
chunk = self.current_chunk[1]
chunk.seek(entry.offset)
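The annotated struct helpers above follow a pattern worth spelling out: typeshed declares Struct.unpack as returning tuple[Any, ...], so the code states the narrower one-int tuple these fixed patterns actually yield and keeps a # type: ignore on the assignment, as in the hunk. A self-contained sketch (example values are mine, not the library's):

# Illustrative sketch of the deserialize_uint8/32/64 annotations above.
from __future__ import annotations

import struct
from typing import Callable

deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack  # type: ignore

value, = deserialize_uint32(b'\x2a\x00\x00\x00')  # little-endian 42
assert value == 42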

View File

@@ -11,9 +11,9 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from io import BytesIO
from pathlib import Path
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, Dict
- from lz4.frame import compress as lz4_compress # type: ignore
+ from lz4.frame import compress as lz4_compress
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
@@ -21,7 +21,7 @@ from .reader import Connection, RecordType
if TYPE_CHECKING:
from types import TracebackType
- from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
+ from typing import BinaryIO, Callable, Literal, Optional, Type, Union
class WriterError(Exception):
@@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes:
return struct.pack('<LL', sec, nsec)
- class Header(dict):
+ class Header(Dict[str, Any]):
"""Record header."""
- def set_uint32(self, name: str, value: int):
+ def set_uint32(self, name: str, value: int) -> None:
"""Set field to uint32 value.
Args:
@@ -70,7 +70,7 @@ class Header(dict):
"""
self[name] = serialize_uint32(value)
- def set_uint64(self, name: str, value: int):
+ def set_uint64(self, name: str, value: int) -> None:
"""Set field to uint64 value.
Args:
@@ -80,7 +80,7 @@ class Header(dict):
"""
self[name] = serialize_uint64(value)
- def set_string(self, name: str, value: str):
+ def set_string(self, name: str, value: str) -> None:
"""Set field to string value.
Args:
@@ -90,7 +90,7 @@ class Header(dict):
"""
self[name] = value.encode()
- def set_time(self, name: str, value: int):
+ def set_time(self, name: str, value: int) -> None:
"""Set field to time value.
Args:
@@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
]
self.chunk_threshold = 1 * (1 << 20)
- def set_compression(self, fmt: CompressionFormat):
+ def set_compression(self, fmt: CompressionFormat) -> None:
"""Enable compression on rosbag1.
This function has to be called before opening.
@@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = fmt.name.lower()
- bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
- lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
+ bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
+ lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore
self.compressor = {
'bz2': bz2,
'lz4': lz4,
}[self.compression_format]
- def open(self):
+ def open(self) -> None:
"""Open rosbag1 for writing."""
try:
self.bio = self.path.open('xb')
except FileExistsError:
raise WriterError(f'{self.path} exists already, not overwriting.') from None
+ assert self.bio
self.bio.write(b'#ROSBAG V2.0\n')
header = Header()
header.set_uint64('index_pos', 0)
@@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.connections[connection.cid] = connection
return connection
- def write(self, connection: Connection, timestamp: int, data: bytes):
+ def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
"""Write message to rosbag1.
Args:
@@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.write_chunk(chunk)
@staticmethod
- def write_connection(connection: Connection, bio: BytesIO):
+ def write_connection(connection: Connection, bio: BinaryIO) -> None:
"""Write connection record."""
header = Header()
header.set_uint32('conn', connection.cid)
@@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
header.set_string('latching', str(connection.latching))
header.write(bio)
- def write_chunk(self, chunk: WriteChunk):
+ def write_chunk(self, chunk: WriteChunk) -> None:
"""Write open chunk to file."""
assert self.bio
@@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
chunk.data.close()
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
- def close(self):
+ def close(self) -> None:
"""Close rosbag1 after writing.
Closes open chunks and writes index.
"""
+ assert self.bio
for chunk in self.chunks:
if chunk.pos == -1:
self.write_chunk(chunk)

View File

@@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
import zstandard
- from ruamel.yaml import YAML, YAMLError
+ from ruamel.yaml import YAML
+ from ruamel.yaml.error import YAMLError
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
- from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union
+ from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
class ReaderError(Exception):
@@ -25,7 +26,7 @@ class ReaderError(Exception):
@contextmanager
- def decompress(path: Path, do_decompress: bool):
+ def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
"""Transparent rosbag2 database decompression context.
This context manager will yield a path to the decompressed file contents.
@@ -119,12 +120,12 @@ class Reader:
except KeyError as exc:
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
- def open(self):
+ def open(self) -> None:
"""Open rosbag2."""
# Future storage formats will require file handles.
self.bio = True
- def close(self):
+ def close(self) -> None:
"""Close rosbag2."""
# Future storage formats will require file handles.
assert self.bio
@@ -133,12 +134,14 @@ class Reader:
@property
def duration(self) -> int:
"""Duration in nanoseconds between earliest and latest messages."""
- return self.metadata['duration']['nanoseconds'] + 1
+ nsecs: int = self.metadata['duration']['nanoseconds']
+ return nsecs + 1
@property
def start_time(self) -> int:
"""Timestamp in nanoseconds of the earliest message."""
- return self.metadata['starting_time']['nanoseconds_since_epoch']
+ nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
+ return nsecs
@property
def end_time(self) -> int:
@@ -148,7 +151,8 @@ class Reader:
@property
def message_count(self) -> int:
"""Total message count."""
- return self.metadata['message_count']
+ count: int = self.metadata['message_count']
+ return count
@property
def compression_format(self) -> Optional[str]:
@@ -233,7 +237,7 @@ class Reader:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
cur.execute('SELECT name,id FROM topics')
- connmap: Dict[int, Connection] = {
+ connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
None) # type: ignore
for row in cur
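The new decompress() signature above shows the usual way to annotate a @contextmanager function: the wrapped generator is typed Generator[<yield type>, None, None]. A tiny sketch of the same idiom with a placeholder body (not the commit's implementation):

# Illustrative sketch of annotating a @contextmanager generator.
from __future__ import annotations

from contextlib import contextmanager
from pathlib import Path
from typing import Generator


@contextmanager
def passthrough(path: Path) -> Generator[Path, None, None]:
    """Yield the given path unchanged."""
    yield path


with passthrough(Path('.')) as workdir:
    print(workdir)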

View File

@@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None
self.connections: dict[int, Connection] = {}
- self.conn = None
+ self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None
- def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
+ def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None:
"""Enable compression on bag.
This function has to be called before opening.
@@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = fmt.name.lower()
self.compressor = zstandard.ZstdCompressor()
- def open(self):
+ def open(self) -> None:
"""Open rosbag2 for writing.
Create base directory and open database connection.
@@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
- def write(self, connection: Connection, timestamp: int, data: bytes):
+ def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
"""Write message to rosbag2.
Args:
@@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
)
connection.count += 1
- def close(self):
+ def close(self) -> None:
"""Close rosbag2 after writing.
Closes open database transactions and writes metadata.yaml.
"""
+ assert self.cursor
+ assert self.conn
self.cursor.close()
self.cursor = None
@@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.conn.close()
if self.compression_mode == 'file':
+ assert self.compressor
src = self.dbpath
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:

View File

@@ -19,10 +19,10 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
- from typing import Callable
+ from .typing import CDRDeser, CDRSer, CDRSerSize
- def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
+ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
"""Generate cdr size calculation function.
Args:
@@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
- def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
+ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
"""Generate cdr serialization function.
Args:
@@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
return compile_lines(lines).serialize_cdr # type: ignore
- def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
+ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
"""Generate cdr deserialization function.
Args:

View File

@@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef:
generate_serialize_cdr(fields, 'be'),
generate_deserialize_cdr(fields, 'le'),
generate_deserialize_cdr(fields, 'be'),
- generate_ros1_to_cdr(fields, typename, False),
- generate_ros1_to_cdr(fields, typename, True),
- generate_cdr_to_ros1(fields, typename, False),
- generate_cdr_to_ros1(fields, typename, True),
+ generate_ros1_to_cdr(fields, typename, False), # type: ignore
+ generate_ros1_to_cdr(fields, typename, True), # type: ignore
+ generate_cdr_to_ros1(fields, typename, False), # type: ignore
+ generate_cdr_to_ros1(fields, typename, True), # type: ignore
)
return MSGDEFCACHE[typename]

View File

@@ -18,10 +18,16 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
- from typing import Callable # pylint: disable=ungrouped-imports
+ from typing import Union # pylint: disable=ungrouped-imports
+ from .typing import Bitcvt, BitcvtSize
- def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
+ def generate_ros1_to_cdr(
+ fields: list[Field],
+ typename: str,
+ copy: bool,
+ ) -> Union[Bitcvt, BitcvtSize]:
"""Generate ROS1 to CDR conversion function.
Args:
@@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call
aligned = anext
lines.append(' return ipos, opos')
- return getattr(compile_lines(lines), funcname)
+ return getattr(compile_lines(lines), funcname) # type: ignore
- def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
+ def generate_cdr_to_ros1(
+ fields: list[Field],
+ typename: str,
+ copy: bool,
+ ) -> Union[Bitcvt, BitcvtSize]:
"""Generate CDR to ROS1 conversion function.
Args:
@@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call
aligned = anext
lines.append(' return ipos, opos')
- return getattr(compile_lines(lines), funcname)
+ return getattr(compile_lines(lines), funcname) # type: ignore

View File

@@ -7,7 +7,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
- from typing import Any, Callable, List
+ from typing import Any, Callable, Tuple
+ Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]]
+ BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
+ CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
+ CDRSer = Callable[[bytes, int, type], int]
+ CDRSerSize = Callable[[int, type], int]
class Descriptor(NamedTuple):
@@ -28,15 +35,15 @@ class Msgdef(NamedTuple):
"""Metadata of a message."""
name: str
- fields: List[Field]
+ fields: list[Field]
cls: Any
size_cdr: int
- getsize_cdr: Callable
- serialize_cdr_le: Callable
- serialize_cdr_be: Callable
- deserialize_cdr_le: Callable
- deserialize_cdr_be: Callable
- getsize_ros1_to_cdr: Callable
- ros1_to_cdr: Callable
- getsize_cdr_to_ros1: Callable
- cdr_to_ros1: Callable
+ getsize_cdr: CDRSerSize
+ serialize_cdr_le: CDRSer
+ serialize_cdr_be: CDRSer
+ deserialize_cdr_le: CDRDeser
+ deserialize_cdr_be: CDRDeser
+ getsize_ros1_to_cdr: BitcvtSize
+ ros1_to_cdr: Bitcvt
+ getsize_cdr_to_ros1: BitcvtSize
+ cdr_to_ros1: Bitcvt
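Swapping the bare Callable fields of Msgdef for the aliases defined above gives every generated function a single checked signature. A short sketch of how one of these aliases constrains an implementation (hypothetical factory, not from the commit):

# Illustrative sketch: a Callable alias documents and checks a signature.
from __future__ import annotations

from typing import Callable

CDRSerSize = Callable[[int, type], int]


def make_fixed_getsize(size: int) -> CDRSerSize:
    """Return a size function that reports a constant message size."""

    def getsize_cdr(pos: int, typ: type) -> int:
        return pos + size

    return getsize_cdr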

View File

@@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
pos = rule.skip_ws(text, 0)
npos, trees = rule.parse(text, pos)
assert npos == len(text), f'Could not parse: {text!r}'
- return visitor.visit(trees)
+ return visitor.visit(trees) # type: ignore
except Exception as err: # pylint: disable=broad-except
raise TypesysError(f'Could not parse: {text!r}') from err

View File

@@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
RULES = parse_grammar(GRAMMAR_IDL)
- def __init__(self):
+ def __init__(self) -> None:
"""Initialize."""
super().__init__()
- self.typedefs = {}
+ self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {}
def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol, return only children of modules."""

View File

@@ -50,7 +50,7 @@ class Rule:
}
return data
- def parse(self, text: str, pos: int):
+ def parse(self, text: str, pos: int) -> tuple[int, Any]:
"""Apply rule at position."""
raise NotImplementedError # pragma: no cover
@@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods
RULES: dict[str, Rule] = {}
- def __init__(self):
+ def __init__(self) -> None:
"""Initialize."""
def visit(self, tree: Any) -> Any:

View File

@@ -13,12 +13,14 @@ from . import types
from .base import Nodetype, TypesysError
if TYPE_CHECKING:
+ from typing import Any, Optional, Union
from .base import Typesdict
INTLIKE = re.compile('^u?(bool|int|float)')
- def get_typehint(desc: tuple) -> str:
+ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str:
"""Get python type hint for field.
Args:
@@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str:
"""
if desc[0] == Nodetype.BASE:
- if match := INTLIKE.match(desc[1]):
+ if match := INTLIKE.match(desc[1]): # type: ignore
return match.group(1)
return 'str'
if desc[0] == Nodetype.NAME:
+ assert isinstance(desc[1], str)
return desc[1].replace('/', '__')
sub = desc[1][0]
if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
- return f'list[{get_typehint(sub)}]'
+ return f'list[{get_typehint(sub)}]' # type: ignore
def generate_python_code(typs: Typesdict) -> str:
@@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str:
'',
]
- def get_ftype(ftype: tuple) -> tuple:
+ def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]:
if ftype[0] <= 2:
return int(ftype[0]), ftype[1]
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])

View File

@@ -9,6 +9,7 @@ from struct import Struct, pack_into, unpack_from
from typing import TYPE_CHECKING, Dict, List, Union, cast
import numpy
+ from numpy.typing import NDArray
from rosbags.serde.messages import SerdeError, get_msgdef
from rosbags.serde.typing import Msgdef
@@ -116,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
- ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
+ ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
ndarr = ndarr.byteswap() # no inplace on readonly array
return ndarr, pos + num * SIZEMAP[desc.args]
@@ -278,7 +279,7 @@ def serialize_array(
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
size *= len(val)
- val = cast(numpy.ndarray, val)
+ val = cast(NDArray[numpy.int_], val)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
val = val.byteswap() # no inplace on readonly array
rawdata[pos:pos + size] = memoryview(val.tobytes())

View File

@@ -15,7 +15,7 @@ from rosbags.rosbag1 import ReaderError
from rosbags.rosbag2 import WriterError
- def test_cliwrapper(tmp_path: Path):
+ def test_cliwrapper(tmp_path: Path) -> None:
"""Test cli wrapper."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'ros1.bag').write_text('')
@@ -62,7 +62,7 @@ def test_cliwrapper(tmp_path: Path):
mock_print.assert_called_with('ERROR: exc')
- def test_convert(tmp_path: Path):
+ def test_convert(tmp_path: Path) -> None:
"""Test conversion function."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'foo.bag').write_text('')

View File

@@ -142,13 +142,13 @@ module test_msgs {
"""
- def test_parse_empty_msg():
+ def test_parse_empty_msg() -> None:
"""Test msg parser with empty message."""
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
assert ret == {'std_msgs/msg/Empty': ([], [])}
- def test_parse_bounds_msg():
+ def test_parse_bounds_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
assert ret == {
@@ -168,7 +168,7 @@ def test_parse_bounds_msg():
}
- def test_parse_defaults_msg():
+ def test_parse_defaults_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
assert ret == {
@@ -188,7 +188,7 @@ def test_parse_defaults_msg():
}
- def test_parse_msg():
+ def test_parse_msg() -> None:
"""Test msg parser."""
with pytest.raises(TypesysError, match='Could not parse'):
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
@@ -208,7 +208,7 @@ def test_parse_msg():
assert fields[6][1][0] == Nodetype.ARRAY
- def test_parse_multi_msg():
+ def test_parse_multi_msg() -> None:
"""Test multi msg parser."""
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 3
@@ -223,7 +223,7 @@ def test_parse_multi_msg():
assert consts == [('static', 'uint32', 42)]
- def test_parse_cstring_confusion():
+ def test_parse_cstring_confusion() -> None:
"""Test if msg separator is confused with const string."""
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 2
@@ -235,7 +235,7 @@ def test_parse_cstring_confusion():
assert fields[1][1][1] == 'string'
- def test_parse_relative_siblings_msg():
+ def test_parse_relative_siblings_msg() -> None:
"""Test relative siblings with msg parser."""
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
@@ -246,7 +246,7 @@ def test_parse_relative_siblings_msg():
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
- def test_parse_idl():
+ def test_parse_idl() -> None:
"""Test idl parser."""
ret = get_types_from_idl(IDL_LANG)
assert ret == {}
@@ -267,21 +267,21 @@ def test_parse_idl():
assert fields[6][1][0] == Nodetype.ARRAY
- def test_register_types():
+ def test_register_types() -> None:
"""Test type registeration."""
assert 'foo' not in FIELDDEFS
register_types({})
- register_types({'foo': [[], [('b', (1, 'bool'))]]})
+ register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
assert 'foo' in FIELDDEFS
- register_types({'std_msgs/msg/Header': [[], []]})
+ register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
with pytest.raises(TypesysError, match='different definition'):
- register_types({'foo': [[], [('x', (1, 'bool'))]]})
+ register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
- def test_generate_msgdef():
+ def test_generate_msgdef() -> None:
"""Test message definition generator."""
res = generate_msgdef('std_msgs/msg/Header')
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')

View File

@@ -117,7 +117,7 @@ def bag(request: SubRequest, tmp_path: Path) -> Path:
return tmp_path
- def test_reader(bag: Path):
+ def test_reader(bag: Path) -> None:
"""Test reader and deserializer on simple bag."""
with Reader(bag) as reader:
assert reader.duration == 43
@@ -151,7 +151,7 @@ def test_reader(bag: Path):
next(gen)
- def test_message_filters(bag: Path):
+ def test_message_filters(bag: Path) -> None:
"""Test reader filters messages."""
with Reader(bag) as reader:
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
@@ -188,14 +188,14 @@ def test_message_filters(bag: Path):
next(gen)
- def test_user_errors(bag: Path):
+ def test_user_errors(bag: Path) -> None:
"""Test user errors."""
reader = Reader(bag)
with pytest.raises(ReaderError, match='Rosbag is not open'):
next(reader.messages())
- def test_failure_cases(tmp_path: Path):
+ def test_failure_cases(tmp_path: Path) -> None:
"""Test bags with broken fs layout."""
with pytest.raises(ReaderError, match='not read metadata'):
Reader(tmp_path)

View File

@@ -2,8 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
+ from __future__ import annotations
from collections import defaultdict
from struct import pack
+ from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
@@ -11,8 +14,12 @@ import pytest
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag1.reader import IndexData
+ if TYPE_CHECKING:
+ from pathlib import Path
+ from typing import Any, Sequence, Union
- def ser(data):
+ def ser(data: Union[dict[str, Any], bytes]) -> bytes:
"""Serialize record header."""
if isinstance(data, dict):
fields = []
@@ -23,7 +30,7 @@ def ser(data):
return pack('<L', len(data)) + data
- def create_default_header():
+ def create_default_header() -> dict[str, bytes]:
"""Create empty rosbag header."""
return {
'op': b'\x03',
@@ -32,7 +39,11 @@ def create_default_header():
}
- def create_connection(cid=1, topic=0, typ=0):
+ def create_connection(
+ cid: int = 1,
+ topic: int = 0,
+ typ: int = 0,
+ ) -> tuple[dict[str, bytes], dict[str, bytes]]:
"""Create connection record."""
return {
'op': b'\x07',
@@ -45,7 +56,11 @@ def create_connection(cid=1, topic=0, typ=0):
}
- def create_message(cid=1, time=0, msg=0):
+ def create_message(
+ cid: int = 1,
+ time: int = 0,
+ msg: int = 0,
+ ) -> tuple[dict[str, Union[bytes, int]], bytes]:
"""Create message record."""
return {
'op': b'\x02',
@@ -54,7 +69,12 @@ def create_message(cid=1, time=0, msg=0):
}, f'MSGCONTENT{msg}'.encode()
- def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements
+ def write_bag( # pylint: disable=too-many-locals,too-many-statements
+ bag: Path,
+ header: dict[str, bytes],
+ chunks: Sequence[Any] = (),
+ ) -> None:
"""Write bag file."""
magic = b'#ROSBAG V2.0\n'
@@ -70,7 +90,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
chunk_bytes = b''
start_time = 2**32 - 1
end_time = 0
- counts = defaultdict(int)
+ counts: dict[int, int] = defaultdict(int)
index = {}
offset = 0
@@ -95,8 +115,8 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
'count': 0,
'msgs': b'',
}
- index[conn]['count'] += 1
- index[conn]['msgs'] += pack('<LLL', time, 0, offset)
+ index[conn]['count'] += 1 # type: ignore
+ index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
add = ser(head) + ser(data)
chunk_bytes += add
@@ -140,19 +160,19 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
if 'index_pos' not in header:
header['index_pos'] = pack('<Q', pos)
- header = ser(header)
- header += b'\x20' * (4096 - len(header))
+ header_bytes = ser(header)
+ header_bytes += b'\x20' * (4096 - len(header_bytes))
bag.write_bytes(b''.join([
magic,
- header,
+ header_bytes,
chunks_bytes,
connections,
chunkinfos,
]))
- def test_indexdata():
+ def test_indexdata() -> None:
"""Test IndexData sort sorder."""
x42_1_0 = IndexData(42, 1, 0)
x42_2_0 = IndexData(42, 2, 0)
@@ -175,7 +195,7 @@ def test_indexdata():
assert not x42_1_0 > x43_3_0
- def test_reader(tmp_path): # pylint: disable=too-many-statements
+ def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test reader and deserializer on simple bag."""
# empty bag
bag = tmp_path / 'test.bag'
@@ -268,7 +288,7 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert msgs[0][2] == b'MSGCONTENT5'
- def test_user_errors(tmp_path):
+ def test_user_errors(tmp_path: Path) -> None:
"""Test user errors."""
bag = tmp_path / 'test.bag'
write_bag(bag, create_default_header(), chunks=[[
@@ -281,7 +301,7 @@ def test_user_errors(tmp_path):
next(reader.messages())
- def test_failure_cases(tmp_path): # pylint: disable=too-many-statements
+ def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test failure cases."""
bag = tmp_path / 'test.bag'
with pytest.raises(ReaderError, match='does not exist'):

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
- def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
+ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
- def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
+ def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods

View File

@@ -18,7 +18,7 @@ from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__
from .cdr import deserialize, serialize
if TYPE_CHECKING:
- from typing import Any, Tuple, Union
+ from typing import Any, Generator, Union
MSG_POLY = (
(
@@ -169,7 +169,7 @@ test_msgs/msg/dynamic_s_64[] seq_msg_ds6
@pytest.fixture()
- def _comparable():
+ def _comparable() -> Generator[None, None, None]:
"""Make messages containing numpy arrays comparable.
Notes:
@@ -180,7 +180,7 @@ def _comparable():
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
lhs = self._mock_wraps # pylint: disable=protected-access
rhs = getattr(other, '_mock_wraps', other)
- return (lhs == rhs).all()
+ return (lhs == rhs).all() # type: ignore
class CNDArray(MagicMock):
"""Mock ndarray."""
@@ -194,14 +194,14 @@ def _comparable():
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
- return CNDArray(wraps=frombuffer(*args, **kwargs))
+ return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
yield
@pytest.mark.parametrize('message', MESSAGES)
- def test_serde(message: Tuple[bytes, str, bool]):
+ def test_serde(message: tuple[bytes, str, bool]) -> None:
"""Test serialization deserialization roundtrip."""
rawdata, typ, is_little = message
@@ -213,7 +213,7 @@ def test_serde(message: Tuple[bytes, str, bool]):
@pytest.mark.usefixtures('_comparable')
- def test_deserializer():
+ def test_deserializer() -> None:
"""Test deserializer."""
msg = deserialize_cdr(*MSG_POLY[:2])
assert msg == deserialize(*MSG_POLY[:2])
@@ -233,7 +233,8 @@ def test_deserializer():
assert msg.header.frame_id == 'foo42'
field = msg.magnetic_field
assert (field.x, field.y, field.z) == (128., 128., 128.)
- assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all()
+ diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) # type: ignore
+ assert (diag == [1., 1., 1.]).all()
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
@@ -241,7 +242,7 @@ def test_deserializer():
@pytest.mark.usefixtures('_comparable')
- def test_serializer():
+ def test_serializer() -> None:
"""Test serializer."""
class Foo: # pylint: disable=too-few-public-methods
@@ -268,7 +269,7 @@ def test_serializer():
@pytest.mark.usefixtures('_comparable')
- def test_serializer_errors():
+ def test_serializer_errors() -> None:
"""Test seralizer with broken messages."""
class Foo: # pylint: disable=too-few-public-methods
@@ -286,7 +287,7 @@ def test_serializer_errors():
@pytest.mark.usefixtures('_comparable')
- def test_custom_type():
+ def test_custom_type() -> None:
"""Test custom type."""
cname = 'test_msgs/msg/custom'
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
@@ -362,7 +363,7 @@ def test_custom_type():
assert res == msg
- def test_ros1_to_cdr():
+ def test_ros1_to_cdr() -> None:
"""Test ROS1 to CDR conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
@@ -385,7 +386,7 @@ def test_ros1_to_cdr():
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr
- def test_cdr_to_ros1():
+ def test_cdr_to_ros1() -> None:
"""Test CDR to ROS1 conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from pathlib import Path
- def test_writer(tmp_path: Path):
+ def test_writer(tmp_path: Path) -> None:
"""Test Writer."""
path = (tmp_path / 'rosbag2')
with Writer(path) as bag:
@@ -60,7 +60,7 @@ def test_writer(tmp_path: Path):
assert size > (path / 'compress_message.db3').stat().st_size
- def test_failure_cases(tmp_path: Path):
+ def test_failure_cases(tmp_path: Path) -> None:
"""Test writer failure cases."""
with pytest.raises(WriterError, match='exists'):
Writer(tmp_path)

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from typing import Optional
- def test_no_overwrite(tmp_path: Path):
+ def test_no_overwrite(tmp_path: Path) -> None:
"""Test writer does not touch existing files."""
path = tmp_path / 'test.bag'
path.write_text('foo')
@@ -30,7 +30,7 @@ def test_no_overwrite(tmp_path: Path):
writer.open()
- def test_empty(tmp_path: Path):
+ def test_empty(tmp_path: Path) -> None:
"""Test empty bag."""
path = tmp_path / 'test.bag'
@@ -40,7 +40,7 @@ def test_empty(tmp_path: Path):
assert len(data) == 13 + 4096
- def test_add_connection(tmp_path: Path):
+ def test_add_connection(tmp_path: Path) -> None:
"""Test adding of connections."""
path = tmp_path / 'test.bag'
@@ -88,7 +88,7 @@ def test_add_connection(tmp_path: Path):
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
- def test_write_errors(tmp_path: Path):
+ def test_write_errors(tmp_path: Path) -> None:
"""Test write errors."""
path = tmp_path / 'test.bag'
@@ -101,7 +101,7 @@ def test_write_errors(tmp_path: Path):
path.unlink()
- def test_write_simple(tmp_path: Path):
+ def test_write_simple(tmp_path: Path) -> None:
"""Test writing of messages."""
path = tmp_path / 'test.bag'
@@ -179,7 +179,7 @@ def test_write_simple(tmp_path: Path):
path.unlink()
- def test_compression_errors(tmp_path: Path):
+ def test_compression_errors(tmp_path: Path) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
with Writer(path) as writer, \
@@ -188,7 +188,7 @@ def test_compression_errors(tmp_path: Path):
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
- def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
+ def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
writer = Writer(path)

View File

@@ -21,7 +21,7 @@ from rosbags.rosbag2 import Reader
from rosbags.serde import deserialize_cdr
if TYPE_CHECKING:
- from typing import Any
+ from typing import Any, Generator
class ReaderPy: # pylint: disable=too-few-public-methods
@@ -35,7 +35,7 @@ class ReaderPy: # pylint: disable=too-few-public-methods
self.reader.open(soptions, coptions)
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
- def messages(self):
+ def messages(self) -> Generator[tuple[str, str, int, bytes], None, None]:
"""Expose rosbag2 like generator behavior."""
while self.reader.has_next():
topic, data, timestamp = self.reader.read_next()
@@ -48,7 +48,7 @@ def deserialize_py(data: bytes, msgtype: str) -> Any:
return deserialize_message(data, pytype)
- def compare_msg(lite: Any, native: Any):
+ def compare_msg(lite: Any, native: Any) -> None:
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
Args:
@@ -79,7 +79,7 @@ def compare_msg(lite: Any, native: Any):
assert native_val == lite_val, f'{fieldname}: {native_val} != {lite_val}'
- def compare(path: Path):
+ def compare(path: Path) -> None:
"""Compare raw and deserialized messages."""
with Reader(path) as reader:
gens = (reader.messages(), ReaderPy(path).messages())
@@ -100,7 +100,7 @@ def compare(path: Path):
assert len(list(gens[1])) == 0
- def read_deser_rosbag2_py(path: Path):
+ def read_deser_rosbag2_py(path: Path) -> None:
"""Read testbag with rosbag2_py."""
soptions = StorageOptions(str(path), 'sqlite3')
coptions = ConverterOptions('', '')
@@ -115,14 +115,14 @@ def read_deser_rosbag2_py(path: Path):
deserialize_message(rawdata, pytype)
- def read_deser_rosbag2(path: Path):
+ def read_deser_rosbag2(path: Path) -> None:
"""Read testbag with rosbag2lite."""
with Reader(path) as reader:
for connection, _, data in reader.messages():
deserialize_cdr(data, connection.msgtype)
- def main():
+ def main() -> None:
"""Benchmark rosbag2 against rosbag2_py."""
path = Path(sys.argv[1])
try:

View File

@@ -25,7 +25,7 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
if TYPE_CHECKING:
- from typing import Any, List, Union
+ from typing import Any, Generator, List, Union
from rosbag.bag import _Connection_Info
@@ -39,7 +39,7 @@ class Reader: # pylint: disable=too-few-public-methods
self.reader.open(StorageOptions(path, 'sqlite3'), ConverterOptions('', ''))
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
- def messages(self):
+ def messages(self) -> Generator[tuple[str, int, bytes], None, None]:
"""Expose rosbag2 like generator behavior."""
while self.reader.has_next():
topic, data, timestamp = self.reader.read_next()
@@ -47,7 +47,7 @@ class Reader: # pylint: disable=too-few-public-methods
yield topic, timestamp, deserialize_message(data, pytype)
- def fixup_ros1(conns: List[_Connection_Info]):
+ def fixup_ros1(conns: List[_Connection_Info]) -> None:
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
Args:
@@ -69,16 +69,13 @@ def fixup_ros1(conns: List[_Connection_Info]):
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
- def compare(ref: Any, msg: Any):
+ def compare(ref: Any, msg: Any) -> None:
"""Compare message to its reference.
Args:
ref: Reference ROS1 message.
msg: Converted ROS2 message.
- Return:
- True if messages are identical.
"""
if hasattr(msg, 'get_fields_and_field_types'):
for name in msg.get_fields_and_field_types():
@@ -107,7 +104,7 @@ def compare(ref: Any, msg: Any):
assert ref == msg
- def main_bag1_bag1(path1: Path, path2: Path):
+ def main_bag1_bag1(path1: Path, path2: Path) -> None:
"""Compare rosbag1 to rosbag1 message by message.
Args:
@@ -132,7 +129,7 @@ def main_bag1_bag1(path1: Path, path2: Path):
print('Bags are identical.') # noqa: T001
- def main_bag1_bag2(path1: Path, path2: Path):
+ def main_bag1_bag2(path1: Path, path2: Path) -> None:
"""Compare rosbag1 to rosbag2 message by message.
Args: