diff --git a/setup.cfg b/setup.cfg index 89b22ffe..06a203f3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,7 @@ classifiers = Programming Language :: Python :: 3 :: Only Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering Typing :: Typed project_urls = @@ -109,7 +110,7 @@ avoid-escape = False docstring_convention = google docstring_style = google extend-exclude = venv*,.venv* -extend-select = +extend-select = # docstrings D204, D400, @@ -119,8 +120,10 @@ extend-select = ignore = # do not require annotation of `self` ANN101, - # allow line break before binary operator - W503, + # handled by B001 + E722, + # allow line break after binary operator + W504, max-line-length = 100 strictness = long suppress-none-returning = True @@ -134,10 +137,14 @@ multi_line_output = 3 explicit_package_bases = True mypy_path = $MYPY_CONFIG_FILE_DIR/src namespace_packages = True +strict = True -[mypy-ruamel] +[mypy-lz4.frame] ignore_missing_imports = True +[mypy-ruamel.yaml] +implicit_reexport = True + [pydocstyle] convention = google add-select = D204,D400,D401,D404,D413 @@ -146,9 +153,37 @@ add-select = D204,D400,D401,D404,D413 max-line-length = 100 [pylint.'MESSAGES CONTROL'] +enable = all disable = duplicate-code, ungrouped-imports, + # isort (pylint FAQ) + wrong-import-order, + # mccabe (pylint FAQ) + too-many-branches, + # fixme + fixme, + # pep8-naming (pylint FAQ, keep: invalid-name) + bad-classmethod-argument, + bad-mcs-classmethod-argument, + no-self-argument + # pycodestyle (pylint FAQ) + bad-indentation, + bare-except, + line-too-long, + missing-final-newline, + multiple-statements, + trailing-whitespace, + unnecessary-semicolon, + unneeded-not, + # pydocstyle (pylint FAQ) + missing-class-docstring, + missing-function-docstring, + missing-module-docstring, + # pyflakes (pylint FAQ) + undefined-variable, + unused-import, + unused-variable, [yapf] based_on_style = google @@ -159,15 +194,15 @@ indent_dictionary_value = false [tool:pytest] addopts = - -v - --flake8 - --mypy - --pylint - --yapf - --cov=src - --cov-branch - --cov-report=html - --cov-report=term - --no-cov-on-fail - --junitxml=report.xml + -v + --flake8 + --mypy + --pylint + --yapf + --cov=src + --cov-branch + --cov-report=html + --cov-report=term + --no-cov-on-fail + --junitxml=report.xml junit_family=xunit2 diff --git a/src/rosbags/convert/__main__.py b/src/rosbags/convert/__main__.py index 2ab89904..60390544 100644 --- a/src/rosbags/convert/__main__.py +++ b/src/rosbags/convert/__main__.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from typing import Callable -def pathtype(exists: bool = True) -> Callable: +def pathtype(exists: bool = True) -> Callable[[str], Path]: """Path argument for argparse. Args: diff --git a/src/rosbags/rosbag1/reader.py b/src/rosbags/rosbag1/reader.py index afe899e9..72f10021 100644 --- a/src/rosbags/rosbag1/reader.py +++ b/src/rosbags/rosbag1/reader.py @@ -10,14 +10,15 @@ import os import re import struct from bz2 import decompress as bz2_decompress +from collections import defaultdict from enum import Enum, IntEnum from functools import reduce from io import BytesIO from itertools import groupby from pathlib import Path -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Any, Dict, NamedTuple -from lz4.frame import decompress as lz4_decompress # type: ignore +from lz4.frame import decompress as lz4_decompress from rosbags.typesys.msg import normalize_msgtype @@ -59,7 +60,7 @@ class Connection(NamedTuple): md5sum: str callerid: Optional[str] latching: Optional[int] - indexes: list + indexes: list[IndexData] class ChunkInfo(NamedTuple): @@ -76,7 +77,7 @@ class Chunk(NamedTuple): datasize: int datapos: int - decompressor: Callable + decompressor: Callable[[bytes], bytes] class TopicInfo(NamedTuple): @@ -124,9 +125,9 @@ class IndexData(NamedTuple): return self.time != other[0] -deserialize_uint8 = struct.Struct(' int: @@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int: Deserialized value. """ - sec, nsec = struct.unpack(' int: @@ -214,7 +216,9 @@ class Header(dict): """ try: - return self[name].decode() + value = self[name] + assert isinstance(value, bytes) + return value.decode() except (KeyError, ValueError) as err: raise ReaderError(f'Could not read string field {name!r}.') from err @@ -237,7 +241,7 @@ class Header(dict): raise ReaderError(f'Could not read time field {name!r}.') from err @classmethod - def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header: + def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header: """Read header from file handle. Args: @@ -362,10 +366,10 @@ class Reader: self.connections: dict[int, Connection] = {} self.chunk_infos: list[ChunkInfo] = [] self.chunks: dict[int, Chunk] = {} - self.current_chunk = (-1, BytesIO()) + self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO()) self.topics: dict[str, TopicInfo] = {} - def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements + def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements """Open rosbag and read metadata.""" try: self.bio = self.path.open('rb') @@ -409,24 +413,25 @@ class Reader: raise ReaderError(f'Bag index looks damaged: {err.args}') from None self.chunks = {} + indexes: dict[int, list[list[IndexData]]] = defaultdict(list) for chunk_info in self.chunk_infos: self.bio.seek(chunk_info.pos) self.chunks[chunk_info.pos] = self.read_chunk() for _ in range(len(chunk_info.connection_counts)): cid, index = self.read_index_data(chunk_info.pos) - self.connections[cid].indexes.append(index) + indexes[cid].append(index) - for connection in self.connections.values(): - connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time)) + for cid, connection in self.connections.items(): + connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time)) assert connection.indexes self.topics = {} - for topic, connections in groupby( + for topic, group in groupby( sorted(self.connections.values(), key=lambda x: x.topic), key=lambda x: x.topic, ): - connections = list(connections) + connections = list(group) count = reduce( lambda x, y: x + y, ( @@ -446,7 +451,7 @@ class Reader: self.close() raise - def close(self): + def close(self) -> None: """Close rosbag.""" assert self.bio self.bio.close() @@ -614,8 +619,8 @@ class Reader: chunk_header = self.chunks[entry.chunk_pos] self.bio.seek(chunk_header.datapos) - chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize)) - self.current_chunk = (entry.chunk_pos, BytesIO(chunk)) + rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize)) + self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes)) chunk = self.current_chunk[1] chunk.seek(entry.offset) diff --git a/src/rosbags/rosbag1/writer.py b/src/rosbags/rosbag1/writer.py index 51a5e6bd..3f6d1719 100644 --- a/src/rosbags/rosbag1/writer.py +++ b/src/rosbags/rosbag1/writer.py @@ -11,9 +11,9 @@ from dataclasses import dataclass from enum import IntEnum, auto from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict -from lz4.frame import compress as lz4_compress # type: ignore +from lz4.frame import compress as lz4_compress from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef @@ -21,7 +21,7 @@ from .reader import Connection, RecordType if TYPE_CHECKING: from types import TracebackType - from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union + from typing import BinaryIO, Callable, Literal, Optional, Type, Union class WriterError(Exception): @@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes: return struct.pack(' None: """Set field to uint32 value. Args: @@ -70,7 +70,7 @@ class Header(dict): """ self[name] = serialize_uint32(value) - def set_uint64(self, name: str, value: int): + def set_uint64(self, name: str, value: int) -> None: """Set field to uint64 value. Args: @@ -80,7 +80,7 @@ class Header(dict): """ self[name] = serialize_uint64(value) - def set_string(self, name: str, value: str): + def set_string(self, name: str, value: str) -> None: """Set field to string value. Args: @@ -90,7 +90,7 @@ class Header(dict): """ self[name] = value.encode() - def set_time(self, name: str, value: int): + def set_time(self, name: str, value: int) -> None: """Set field to time value. Args: @@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes ] self.chunk_threshold = 1 * (1 << 20) - def set_compression(self, fmt: CompressionFormat): + def set_compression(self, fmt: CompressionFormat) -> None: """Enable compression on rosbag1. This function has to be called before opening. @@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compression_format = fmt.name.lower() - bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9) - lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16) + bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9) + lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore self.compressor = { 'bz2': bz2, 'lz4': lz4, }[self.compression_format] - def open(self): + def open(self) -> None: """Open rosbag1 for writing.""" try: self.bio = self.path.open('xb') except FileExistsError: raise WriterError(f'{self.path} exists already, not overwriting.') from None + assert self.bio self.bio.write(b'#ROSBAG V2.0\n') header = Header() header.set_uint64('index_pos', 0) @@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.connections[connection.cid] = connection return connection - def write(self, connection: Connection, timestamp: int, data: bytes): + def write(self, connection: Connection, timestamp: int, data: bytes) -> None: """Write message to rosbag1. Args: @@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.write_chunk(chunk) @staticmethod - def write_connection(connection: Connection, bio: BytesIO): + def write_connection(connection: Connection, bio: BinaryIO) -> None: """Write connection record.""" header = Header() header.set_uint32('conn', connection.cid) @@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes header.set_string('latching', str(connection.latching)) header.write(bio) - def write_chunk(self, chunk: WriteChunk): + def write_chunk(self, chunk: WriteChunk) -> None: """Write open chunk to file.""" assert self.bio @@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes chunk.data.close() self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list))) - def close(self): + def close(self) -> None: """Close rosbag1 after writing. Closes open chunks and writes index. """ + assert self.bio for chunk in self.chunks: if chunk.pos == -1: self.write_chunk(chunk) diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index e5045e88..75fd616e 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING import zstandard -from ruamel.yaml import YAML, YAMLError +from ruamel.yaml import YAML +from ruamel.yaml.error import YAMLError from .connection import Connection if TYPE_CHECKING: from types import TracebackType - from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union + from typing import Any, Generator, Iterable, Literal, Optional, Type, Union class ReaderError(Exception): @@ -25,7 +26,7 @@ class ReaderError(Exception): @contextmanager -def decompress(path: Path, do_decompress: bool): +def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]: """Transparent rosbag2 database decompression context. This context manager will yield a path to the decompressed file contents. @@ -119,12 +120,12 @@ class Reader: except KeyError as exc: raise ReaderError(f'A metadata key is missing {exc!r}.') from None - def open(self): + def open(self) -> None: """Open rosbag2.""" # Future storage formats will require file handles. self.bio = True - def close(self): + def close(self) -> None: """Close rosbag2.""" # Future storage formats will require file handles. assert self.bio @@ -133,12 +134,14 @@ class Reader: @property def duration(self) -> int: """Duration in nanoseconds between earliest and latest messages.""" - return self.metadata['duration']['nanoseconds'] + 1 + nsecs: int = self.metadata['duration']['nanoseconds'] + return nsecs + 1 @property def start_time(self) -> int: """Timestamp in nanoseconds of the earliest message.""" - return self.metadata['starting_time']['nanoseconds_since_epoch'] + nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch'] + return nsecs @property def end_time(self) -> int: @@ -148,7 +151,8 @@ class Reader: @property def message_count(self) -> int: """Total message count.""" - return self.metadata['message_count'] + count: int = self.metadata['message_count'] + return count @property def compression_format(self) -> Optional[str]: @@ -233,7 +237,7 @@ class Reader: raise ReaderError(f'Cannot open database {path} or database missing tables.') cur.execute('SELECT name,id FROM topics') - connmap: Dict[int, Connection] = { + connmap: dict[int, Connection] = { row[1]: next((x for x in self.connections.values() if x.topic == row[0]), None) # type: ignore for row in cur diff --git a/src/rosbags/rosbag2/writer.py b/src/rosbags/rosbag2/writer.py index 386aef23..a6b74daa 100644 --- a/src/rosbags/rosbag2/writer.py +++ b/src/rosbags/rosbag2/writer.py @@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compression_format = '' self.compressor: Optional[zstandard.ZstdCompressor] = None self.connections: dict[int, Connection] = {} - self.conn = None + self.conn: Optional[sqlite3.Connection] = None self.cursor: Optional[sqlite3.Cursor] = None - def set_compression(self, mode: CompressionMode, fmt: CompressionFormat): + def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None: """Enable compression on bag. This function has to be called before opening. @@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compression_format = fmt.name.lower() self.compressor = zstandard.ZstdCompressor() - def open(self): + def open(self) -> None: """Open rosbag2 for writing. Create base directory and open database connection. @@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta) return connection - def write(self, connection: Connection, timestamp: int, data: bytes): + def write(self, connection: Connection, timestamp: int, data: bytes) -> None: """Write message to rosbag2. Args: @@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes ) connection.count += 1 - def close(self): + def close(self) -> None: """Close rosbag2 after writing. Closes open database transactions and writes metadata.yaml. """ + assert self.cursor + assert self.conn self.cursor.close() self.cursor = None @@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.conn.close() if self.compression_mode == 'file': + assert self.compressor src = self.dbpath self.dbpath = src.with_suffix(f'.db3.{self.compression_format}') with src.open('rb') as infile, self.dbpath.open('wb') as outfile: diff --git a/src/rosbags/serde/cdr.py b/src/rosbags/serde/cdr.py index 5b777878..0abafa52 100644 --- a/src/rosbags/serde/cdr.py +++ b/src/rosbags/serde/cdr.py @@ -19,10 +19,10 @@ from .typing import Field from .utils import SIZEMAP, Valtype, align, align_after, compile_lines if TYPE_CHECKING: - from typing import Callable + from .typing import CDRDeser, CDRSer, CDRSerSize -def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]: +def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]: """Generate cdr size calculation function. Args: @@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]: return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore -def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable: +def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer: """Generate cdr serialization function. Args: @@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable: return compile_lines(lines).serialize_cdr # type: ignore -def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable: +def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser: """Generate cdr deserialization function. Args: diff --git a/src/rosbags/serde/messages.py b/src/rosbags/serde/messages.py index ed06449b..d83ce2c2 100644 --- a/src/rosbags/serde/messages.py +++ b/src/rosbags/serde/messages.py @@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef: generate_serialize_cdr(fields, 'be'), generate_deserialize_cdr(fields, 'le'), generate_deserialize_cdr(fields, 'be'), - generate_ros1_to_cdr(fields, typename, False), - generate_ros1_to_cdr(fields, typename, True), - generate_cdr_to_ros1(fields, typename, False), - generate_cdr_to_ros1(fields, typename, True), + generate_ros1_to_cdr(fields, typename, False), # type: ignore + generate_ros1_to_cdr(fields, typename, True), # type: ignore + generate_cdr_to_ros1(fields, typename, False), # type: ignore + generate_cdr_to_ros1(fields, typename, True), # type: ignore ) return MSGDEFCACHE[typename] diff --git a/src/rosbags/serde/ros1.py b/src/rosbags/serde/ros1.py index a12feef6..420177c3 100644 --- a/src/rosbags/serde/ros1.py +++ b/src/rosbags/serde/ros1.py @@ -18,10 +18,16 @@ from .typing import Field from .utils import SIZEMAP, Valtype, align, align_after, compile_lines if TYPE_CHECKING: - from typing import Callable # pylint: disable=ungrouped-imports + from typing import Union # pylint: disable=ungrouped-imports + + from .typing import Bitcvt, BitcvtSize -def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable: +def generate_ros1_to_cdr( + fields: list[Field], + typename: str, + copy: bool, +) -> Union[Bitcvt, BitcvtSize]: """Generate ROS1 to CDR conversion function. Args: @@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call aligned = anext lines.append(' return ipos, opos') - return getattr(compile_lines(lines), funcname) + return getattr(compile_lines(lines), funcname) # type: ignore -def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable: +def generate_cdr_to_ros1( + fields: list[Field], + typename: str, + copy: bool, +) -> Union[Bitcvt, BitcvtSize]: """Generate CDR to ROS1 conversion function. Args: @@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call aligned = anext lines.append(' return ipos, opos') - return getattr(compile_lines(lines), funcname) + return getattr(compile_lines(lines), funcname) # type: ignore diff --git a/src/rosbags/serde/typing.py b/src/rosbags/serde/typing.py index d4e8bd69..1079712e 100644 --- a/src/rosbags/serde/typing.py +++ b/src/rosbags/serde/typing.py @@ -7,7 +7,14 @@ from __future__ import annotations from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: - from typing import Any, Callable, List + from typing import Any, Callable, Tuple + + Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]] + BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]] + + CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]] + CDRSer = Callable[[bytes, int, type], int] + CDRSerSize = Callable[[int, type], int] class Descriptor(NamedTuple): @@ -28,15 +35,15 @@ class Msgdef(NamedTuple): """Metadata of a message.""" name: str - fields: List[Field] + fields: list[Field] cls: Any size_cdr: int - getsize_cdr: Callable - serialize_cdr_le: Callable - serialize_cdr_be: Callable - deserialize_cdr_le: Callable - deserialize_cdr_be: Callable - getsize_ros1_to_cdr: Callable - ros1_to_cdr: Callable - getsize_cdr_to_ros1: Callable - cdr_to_ros1: Callable + getsize_cdr: CDRSerSize + serialize_cdr_le: CDRSer + serialize_cdr_be: CDRSer + deserialize_cdr_le: CDRDeser + deserialize_cdr_be: CDRDeser + getsize_ros1_to_cdr: BitcvtSize + ros1_to_cdr: Bitcvt + getsize_cdr_to_ros1: BitcvtSize + cdr_to_ros1: Bitcvt diff --git a/src/rosbags/typesys/base.py b/src/rosbags/typesys/base.py index 02e0b3b9..cbed9279 100644 --- a/src/rosbags/typesys/base.py +++ b/src/rosbags/typesys/base.py @@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict: pos = rule.skip_ws(text, 0) npos, trees = rule.parse(text, pos) assert npos == len(text), f'Could not parse: {text!r}' - return visitor.visit(trees) + return visitor.visit(trees) # type: ignore except Exception as err: # pylint: disable=broad-except raise TypesysError(f'Could not parse: {text!r}') from err diff --git a/src/rosbags/typesys/idl.py b/src/rosbags/typesys/idl.py index 6c930c8f..f8bee88f 100644 --- a/src/rosbags/typesys/idl.py +++ b/src/rosbags/typesys/idl.py @@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods RULES = parse_grammar(GRAMMAR_IDL) - def __init__(self): + def __init__(self) -> None: """Initialize.""" super().__init__() - self.typedefs = {} + self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {} def visit_specification(self, children: Any) -> Typesdict: """Process start symbol, return only children of modules.""" diff --git a/src/rosbags/typesys/peg.py b/src/rosbags/typesys/peg.py index 3298b2f9..833cfa09 100644 --- a/src/rosbags/typesys/peg.py +++ b/src/rosbags/typesys/peg.py @@ -50,7 +50,7 @@ class Rule: } return data - def parse(self, text: str, pos: int): + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" raise NotImplementedError # pragma: no cover @@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods RULES: dict[str, Rule] = {} - def __init__(self): + def __init__(self) -> None: """Initialize.""" def visit(self, tree: Any) -> Any: diff --git a/src/rosbags/typesys/register.py b/src/rosbags/typesys/register.py index 2b289932..6aa3bf9e 100644 --- a/src/rosbags/typesys/register.py +++ b/src/rosbags/typesys/register.py @@ -13,12 +13,14 @@ from . import types from .base import Nodetype, TypesysError if TYPE_CHECKING: + from typing import Any, Optional, Union + from .base import Typesdict INTLIKE = re.compile('^u?(bool|int|float)') -def get_typehint(desc: tuple) -> str: +def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str: """Get python type hint for field. Args: @@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str: """ if desc[0] == Nodetype.BASE: - if match := INTLIKE.match(desc[1]): + if match := INTLIKE.match(desc[1]): # type: ignore return match.group(1) return 'str' if desc[0] == Nodetype.NAME: + assert isinstance(desc[1], str) return desc[1].replace('/', '__') sub = desc[1][0] if INTLIKE.match(sub[1]): typ = 'bool8' if sub[1] == 'bool' else sub[1] return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]' - return f'list[{get_typehint(sub)}]' + return f'list[{get_typehint(sub)}]' # type: ignore def generate_python_code(typs: Typesdict) -> str: @@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str: '', ] - def get_ftype(ftype: tuple) -> tuple: + def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]: if ftype[0] <= 2: return int(ftype[0]), ftype[1] return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1]) diff --git a/tests/cdr.py b/tests/cdr.py index c3cbb249..9db001cd 100644 --- a/tests/cdr.py +++ b/tests/cdr.py @@ -9,6 +9,7 @@ from struct import Struct, pack_into, unpack_from from typing import TYPE_CHECKING, Dict, List, Union, cast import numpy +from numpy.typing import NDArray from rosbags.serde.messages import SerdeError, get_msgdef from rosbags.serde.typing import Msgdef @@ -116,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des size = SIZEMAP[desc.args] pos = (pos + size - 1) & -size - ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) + ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'): ndarr = ndarr.byteswap() # no inplace on readonly array return ndarr, pos + num * SIZEMAP[desc.args] @@ -278,7 +279,7 @@ def serialize_array( size = SIZEMAP[desc.args] pos = (pos + size - 1) & -size size *= len(val) - val = cast(numpy.ndarray, val) + val = cast(NDArray[numpy.int_], val) if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'): val = val.byteswap() # no inplace on readonly array rawdata[pos:pos + size] = memoryview(val.tobytes()) diff --git a/tests/test_convert.py b/tests/test_convert.py index 057a6092..1b090223 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -15,7 +15,7 @@ from rosbags.rosbag1 import ReaderError from rosbags.rosbag2 import WriterError -def test_cliwrapper(tmp_path: Path): +def test_cliwrapper(tmp_path: Path) -> None: """Test cli wrapper.""" (tmp_path / 'subdir').mkdir() (tmp_path / 'ros1.bag').write_text('') @@ -62,7 +62,7 @@ def test_cliwrapper(tmp_path: Path): mock_print.assert_called_with('ERROR: exc') -def test_convert(tmp_path: Path): +def test_convert(tmp_path: Path) -> None: """Test conversion function.""" (tmp_path / 'subdir').mkdir() (tmp_path / 'foo.bag').write_text('') diff --git a/tests/test_parse.py b/tests/test_parse.py index 6112ef6f..edce9ab7 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -142,13 +142,13 @@ module test_msgs { """ -def test_parse_empty_msg(): +def test_parse_empty_msg() -> None: """Test msg parser with empty message.""" ret = get_types_from_msg('', 'std_msgs/msg/Empty') assert ret == {'std_msgs/msg/Empty': ([], [])} -def test_parse_bounds_msg(): +def test_parse_bounds_msg() -> None: """Test msg parser.""" ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo') assert ret == { @@ -168,7 +168,7 @@ def test_parse_bounds_msg(): } -def test_parse_defaults_msg(): +def test_parse_defaults_msg() -> None: """Test msg parser.""" ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo') assert ret == { @@ -188,7 +188,7 @@ def test_parse_defaults_msg(): } -def test_parse_msg(): +def test_parse_msg() -> None: """Test msg parser.""" with pytest.raises(TypesysError, match='Could not parse'): get_types_from_msg('invalid', 'test_msgs/msg/Foo') @@ -208,7 +208,7 @@ def test_parse_msg(): assert fields[6][1][0] == Nodetype.ARRAY -def test_parse_multi_msg(): +def test_parse_multi_msg() -> None: """Test multi msg parser.""" ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo') assert len(ret) == 3 @@ -223,7 +223,7 @@ def test_parse_multi_msg(): assert consts == [('static', 'uint32', 42)] -def test_parse_cstring_confusion(): +def test_parse_cstring_confusion() -> None: """Test if msg separator is confused with const string.""" ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo') assert len(ret) == 2 @@ -235,7 +235,7 @@ def test_parse_cstring_confusion(): assert fields[1][1][1] == 'string' -def test_parse_relative_siblings_msg(): +def test_parse_relative_siblings_msg() -> None: """Test relative siblings with msg parser.""" ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo') assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header' @@ -246,7 +246,7 @@ def test_parse_relative_siblings_msg(): assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other' -def test_parse_idl(): +def test_parse_idl() -> None: """Test idl parser.""" ret = get_types_from_idl(IDL_LANG) assert ret == {} @@ -267,21 +267,21 @@ def test_parse_idl(): assert fields[6][1][0] == Nodetype.ARRAY -def test_register_types(): +def test_register_types() -> None: """Test type registeration.""" assert 'foo' not in FIELDDEFS register_types({}) - register_types({'foo': [[], [('b', (1, 'bool'))]]}) + register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore assert 'foo' in FIELDDEFS - register_types({'std_msgs/msg/Header': [[], []]}) + register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2 with pytest.raises(TypesysError, match='different definition'): - register_types({'foo': [[], [('x', (1, 'bool'))]]}) + register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore -def test_generate_msgdef(): +def test_generate_msgdef() -> None: """Test message definition generator.""" res = generate_msgdef('std_msgs/msg/Header') assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed') diff --git a/tests/test_reader.py b/tests/test_reader.py index 2862954b..a941ead7 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -117,7 +117,7 @@ def bag(request: SubRequest, tmp_path: Path) -> Path: return tmp_path -def test_reader(bag: Path): +def test_reader(bag: Path) -> None: """Test reader and deserializer on simple bag.""" with Reader(bag) as reader: assert reader.duration == 43 @@ -151,7 +151,7 @@ def test_reader(bag: Path): next(gen) -def test_message_filters(bag: Path): +def test_message_filters(bag: Path) -> None: """Test reader filters messages.""" with Reader(bag) as reader: magn_connections = [x for x in reader.connections.values() if x.topic == '/magn'] @@ -188,14 +188,14 @@ def test_message_filters(bag: Path): next(gen) -def test_user_errors(bag: Path): +def test_user_errors(bag: Path) -> None: """Test user errors.""" reader = Reader(bag) with pytest.raises(ReaderError, match='Rosbag is not open'): next(reader.messages()) -def test_failure_cases(tmp_path: Path): +def test_failure_cases(tmp_path: Path) -> None: """Test bags with broken fs layout.""" with pytest.raises(ReaderError, match='not read metadata'): Reader(tmp_path) diff --git a/tests/test_reader1.py b/tests/test_reader1.py index 5acbcf5b..289027c8 100644 --- a/tests/test_reader1.py +++ b/tests/test_reader1.py @@ -2,8 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 """Reader tests.""" +from __future__ import annotations + from collections import defaultdict from struct import pack +from typing import TYPE_CHECKING from unittest.mock import patch import pytest @@ -11,8 +14,12 @@ import pytest from rosbags.rosbag1 import Reader, ReaderError from rosbags.rosbag1.reader import IndexData +if TYPE_CHECKING: + from pathlib import Path + from typing import Any, Sequence, Union -def ser(data): + +def ser(data: Union[dict[str, Any], bytes]) -> bytes: """Serialize record header.""" if isinstance(data, dict): fields = [] @@ -23,7 +30,7 @@ def ser(data): return pack(' dict[str, bytes]: """Create empty rosbag header.""" return { 'op': b'\x03', @@ -32,7 +39,11 @@ def create_default_header(): } -def create_connection(cid=1, topic=0, typ=0): +def create_connection( + cid: int = 1, + topic: int = 0, + typ: int = 0, +) -> tuple[dict[str, bytes], dict[str, bytes]]: """Create connection record.""" return { 'op': b'\x07', @@ -45,7 +56,11 @@ def create_connection(cid=1, topic=0, typ=0): } -def create_message(cid=1, time=0, msg=0): +def create_message( + cid: int = 1, + time: int = 0, + msg: int = 0, +) -> tuple[dict[str, Union[bytes, int]], bytes]: """Create message record.""" return { 'op': b'\x02', @@ -54,7 +69,12 @@ def create_message(cid=1, time=0, msg=0): }, f'MSGCONTENT{msg}'.encode() -def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements +def write_bag( # pylint: disable=too-many-locals,too-many-statements + + bag: Path, + header: dict[str, bytes], + chunks: Sequence[Any] = (), +) -> None: """Write bag file.""" magic = b'#ROSBAG V2.0\n' @@ -70,7 +90,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too- chunk_bytes = b'' start_time = 2**32 - 1 end_time = 0 - counts = defaultdict(int) + counts: dict[int, int] = defaultdict(int) index = {} offset = 0 @@ -95,8 +115,8 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too- 'count': 0, 'msgs': b'', } - index[conn]['count'] += 1 - index[conn]['msgs'] += pack(' None: """Test IndexData sort sorder.""" x42_1_0 = IndexData(42, 1, 0) x42_2_0 = IndexData(42, 2, 0) @@ -175,7 +195,7 @@ def test_indexdata(): assert not x42_1_0 > x43_3_0 -def test_reader(tmp_path): # pylint: disable=too-many-statements +def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements """Test reader and deserializer on simple bag.""" # empty bag bag = tmp_path / 'test.bag' @@ -268,7 +288,7 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements assert msgs[0][2] == b'MSGCONTENT5' -def test_user_errors(tmp_path): +def test_user_errors(tmp_path: Path) -> None: """Test user errors.""" bag = tmp_path / 'test.bag' write_bag(bag, create_default_header(), chunks=[[ @@ -281,7 +301,7 @@ def test_user_errors(tmp_path): next(reader.messages()) -def test_failure_cases(tmp_path): # pylint: disable=too-many-statements +def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements """Test failure cases.""" bag = tmp_path / 'test.bag' with pytest.raises(ReaderError, match='does not exist'): diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 09f1d99b..f087b48c 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: @pytest.mark.parametrize('mode', [*Writer.CompressionMode]) -def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path): +def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None: """Test full data roundtrip.""" class Foo: # pylint: disable=too-few-public-methods diff --git a/tests/test_roundtrip1.py b/tests/test_roundtrip1.py index a85288f8..5d677b05 100644 --- a/tests/test_roundtrip1.py +++ b/tests/test_roundtrip1.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: @pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4]) -def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]): +def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None: """Test full data roundtrip.""" class Foo: # pylint: disable=too-few-public-methods diff --git a/tests/test_serde.py b/tests/test_serde.py index a8e95edf..ef429eba 100644 --- a/tests/test_serde.py +++ b/tests/test_serde.py @@ -18,7 +18,7 @@ from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__ from .cdr import deserialize, serialize if TYPE_CHECKING: - from typing import Any, Tuple, Union + from typing import Any, Generator, Union MSG_POLY = ( ( @@ -169,7 +169,7 @@ test_msgs/msg/dynamic_s_64[] seq_msg_ds6 @pytest.fixture() -def _comparable(): +def _comparable() -> Generator[None, None, None]: """Make messages containing numpy arrays comparable. Notes: @@ -180,7 +180,7 @@ def _comparable(): def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool: lhs = self._mock_wraps # pylint: disable=protected-access rhs = getattr(other, '_mock_wraps', other) - return (lhs == rhs).all() + return (lhs == rhs).all() # type: ignore class CNDArray(MagicMock): """Mock ndarray.""" @@ -194,14 +194,14 @@ def _comparable(): return CNDArray(wraps=self._mock_wraps.byteswap(*args)) def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: - return CNDArray(wraps=frombuffer(*args, **kwargs)) + return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer): yield @pytest.mark.parametrize('message', MESSAGES) -def test_serde(message: Tuple[bytes, str, bool]): +def test_serde(message: tuple[bytes, str, bool]) -> None: """Test serialization deserialization roundtrip.""" rawdata, typ, is_little = message @@ -213,7 +213,7 @@ def test_serde(message: Tuple[bytes, str, bool]): @pytest.mark.usefixtures('_comparable') -def test_deserializer(): +def test_deserializer() -> None: """Test deserializer.""" msg = deserialize_cdr(*MSG_POLY[:2]) assert msg == deserialize(*MSG_POLY[:2]) @@ -233,7 +233,8 @@ def test_deserializer(): assert msg.header.frame_id == 'foo42' field = msg.magnetic_field assert (field.x, field.y, field.z) == (128., 128., 128.) - assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all() + diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) # type: ignore + assert (diag == [1., 1., 1.]).all() msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2]) assert msg_big == deserialize(*MSG_MAGN_BIG[:2]) @@ -241,7 +242,7 @@ def test_deserializer(): @pytest.mark.usefixtures('_comparable') -def test_serializer(): +def test_serializer() -> None: """Test serializer.""" class Foo: # pylint: disable=too-few-public-methods @@ -268,7 +269,7 @@ def test_serializer(): @pytest.mark.usefixtures('_comparable') -def test_serializer_errors(): +def test_serializer_errors() -> None: """Test seralizer with broken messages.""" class Foo: # pylint: disable=too-few-public-methods @@ -286,7 +287,7 @@ def test_serializer_errors(): @pytest.mark.usefixtures('_comparable') -def test_custom_type(): +def test_custom_type() -> None: """Test custom type.""" cname = 'test_msgs/msg/custom' register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64'))) @@ -362,7 +363,7 @@ def test_custom_type(): assert res == msg -def test_ros1_to_cdr(): +def test_ros1_to_cdr() -> None: """Test ROS1 to CDR conversion.""" register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') @@ -385,7 +386,7 @@ def test_ros1_to_cdr(): assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr -def test_cdr_to_ros1(): +def test_cdr_to_ros1() -> None: """Test CDR to ROS1 conversion.""" register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') diff --git a/tests/test_writer.py b/tests/test_writer.py index 5fac0e13..ebb2c354 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from pathlib import Path -def test_writer(tmp_path: Path): +def test_writer(tmp_path: Path) -> None: """Test Writer.""" path = (tmp_path / 'rosbag2') with Writer(path) as bag: @@ -60,7 +60,7 @@ def test_writer(tmp_path: Path): assert size > (path / 'compress_message.db3').stat().st_size -def test_failure_cases(tmp_path: Path): +def test_failure_cases(tmp_path: Path) -> None: """Test writer failure cases.""" with pytest.raises(WriterError, match='exists'): Writer(tmp_path) diff --git a/tests/test_writer1.py b/tests/test_writer1.py index 384ebadd..b3c7f1e8 100644 --- a/tests/test_writer1.py +++ b/tests/test_writer1.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from typing import Optional -def test_no_overwrite(tmp_path: Path): +def test_no_overwrite(tmp_path: Path) -> None: """Test writer does not touch existing files.""" path = tmp_path / 'test.bag' path.write_text('foo') @@ -30,7 +30,7 @@ def test_no_overwrite(tmp_path: Path): writer.open() -def test_empty(tmp_path: Path): +def test_empty(tmp_path: Path) -> None: """Test empty bag.""" path = tmp_path / 'test.bag' @@ -40,7 +40,7 @@ def test_empty(tmp_path: Path): assert len(data) == 13 + 4096 -def test_add_connection(tmp_path: Path): +def test_add_connection(tmp_path: Path) -> None: """Test adding of connections.""" path = tmp_path / 'test.bag' @@ -88,7 +88,7 @@ def test_add_connection(tmp_path: Path): assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2) -def test_write_errors(tmp_path: Path): +def test_write_errors(tmp_path: Path) -> None: """Test write errors.""" path = tmp_path / 'test.bag' @@ -101,7 +101,7 @@ def test_write_errors(tmp_path: Path): path.unlink() -def test_write_simple(tmp_path: Path): +def test_write_simple(tmp_path: Path) -> None: """Test writing of messages.""" path = tmp_path / 'test.bag' @@ -179,7 +179,7 @@ def test_write_simple(tmp_path: Path): path.unlink() -def test_compression_errors(tmp_path: Path): +def test_compression_errors(tmp_path: Path) -> None: """Test compression modes.""" path = tmp_path / 'test.bag' with Writer(path) as writer, \ @@ -188,7 +188,7 @@ def test_compression_errors(tmp_path: Path): @pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4]) -def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]): +def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None: """Test compression modes.""" path = tmp_path / 'test.bag' writer = Writer(path) diff --git a/tools/bench/bench.py b/tools/bench/bench.py index cb6d876a..b6330fcf 100644 --- a/tools/bench/bench.py +++ b/tools/bench/bench.py @@ -21,7 +21,7 @@ from rosbags.rosbag2 import Reader from rosbags.serde import deserialize_cdr if TYPE_CHECKING: - from typing import Any + from typing import Any, Generator class ReaderPy: # pylint: disable=too-few-public-methods @@ -35,7 +35,7 @@ class ReaderPy: # pylint: disable=too-few-public-methods self.reader.open(soptions, coptions) self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()} - def messages(self): + def messages(self) -> Generator[tuple[str, str, int, bytes], None, None]: """Expose rosbag2 like generator behavior.""" while self.reader.has_next(): topic, data, timestamp = self.reader.read_next() @@ -48,7 +48,7 @@ def deserialize_py(data: bytes, msgtype: str) -> Any: return deserialize_message(data, pytype) -def compare_msg(lite: Any, native: Any): +def compare_msg(lite: Any, native: Any) -> None: """Compare rosbag2 (lite) vs rosbag2_py (native) message content. Args: @@ -79,7 +79,7 @@ def compare_msg(lite: Any, native: Any): assert native_val == lite_val, f'{fieldname}: {native_val} != {lite_val}' -def compare(path: Path): +def compare(path: Path) -> None: """Compare raw and deserialized messages.""" with Reader(path) as reader: gens = (reader.messages(), ReaderPy(path).messages()) @@ -100,7 +100,7 @@ def compare(path: Path): assert len(list(gens[1])) == 0 -def read_deser_rosbag2_py(path: Path): +def read_deser_rosbag2_py(path: Path) -> None: """Read testbag with rosbag2_py.""" soptions = StorageOptions(str(path), 'sqlite3') coptions = ConverterOptions('', '') @@ -115,14 +115,14 @@ def read_deser_rosbag2_py(path: Path): deserialize_message(rawdata, pytype) -def read_deser_rosbag2(path: Path): +def read_deser_rosbag2(path: Path) -> None: """Read testbag with rosbag2lite.""" with Reader(path) as reader: for connection, _, data in reader.messages(): deserialize_cdr(data, connection.msgtype) -def main(): +def main() -> None: """Benchmark rosbag2 against rosbag2_py.""" path = Path(sys.argv[1]) try: diff --git a/tools/compare/compare.py b/tools/compare/compare.py index 9649ea48..bb9e34dc 100644 --- a/tools/compare/compare.py +++ b/tools/compare/compare.py @@ -25,7 +25,7 @@ rosgraph_msgs.msg.TopicStatistics = Mock() import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position if TYPE_CHECKING: - from typing import Any, List, Union + from typing import Any, Generator, List, Union from rosbag.bag import _Connection_Info @@ -39,7 +39,7 @@ class Reader: # pylint: disable=too-few-public-methods self.reader.open(StorageOptions(path, 'sqlite3'), ConverterOptions('', '')) self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()} - def messages(self): + def messages(self) -> Generator[tuple[str, int, bytes], None, None]: """Expose rosbag2 like generator behavior.""" while self.reader.has_next(): topic, data, timestamp = self.reader.read_next() @@ -47,7 +47,7 @@ class Reader: # pylint: disable=too-few-public-methods yield topic, timestamp, deserialize_message(data, pytype) -def fixup_ros1(conns: List[_Connection_Info]): +def fixup_ros1(conns: List[_Connection_Info]) -> None: """Monkeypatch ROS2 fieldnames onto ROS1 objects. Args: @@ -69,16 +69,13 @@ def fixup_ros1(conns: List[_Connection_Info]): cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010 -def compare(ref: Any, msg: Any): +def compare(ref: Any, msg: Any) -> None: """Compare message to its reference. Args: ref: Reference ROS1 message. msg: Converted ROS2 message. - Return: - True if messages are identical. - """ if hasattr(msg, 'get_fields_and_field_types'): for name in msg.get_fields_and_field_types(): @@ -107,7 +104,7 @@ def compare(ref: Any, msg: Any): assert ref == msg -def main_bag1_bag1(path1: Path, path2: Path): +def main_bag1_bag1(path1: Path, path2: Path) -> None: """Compare rosbag1 to rosbag1 message by message. Args: @@ -132,7 +129,7 @@ def main_bag1_bag1(path1: Path, path2: Path): print('Bags are identical.') # noqa: T001 -def main_bag1_bag2(path1: Path, path2: Path): +def main_bag1_bag2(path1: Path, path2: Path) -> None: """Compare rosbag1 to rosbag2 message by message. Args: