Type generics and missing return types

This commit is contained in:
Marko Durkovic 2021-11-25 14:26:17 +01:00
parent ac704bd890
commit 52480e2bad
26 changed files with 263 additions and 175 deletions

View File

@ -30,6 +30,7 @@ classifiers =
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: Scientific/Engineering
Typing :: Typed
project_urls =
@ -109,7 +110,7 @@ avoid-escape = False
docstring_convention = google
docstring_style = google
extend-exclude = venv*,.venv*
extend-select =
extend-select =
# docstrings
D204,
D400,
@ -119,8 +120,10 @@ extend-select =
ignore =
# do not require annotation of `self`
ANN101,
# allow line break before binary operator
W503,
# handled by B001
E722,
# allow line break after binary operator
W504,
max-line-length = 100
strictness = long
suppress-none-returning = True
@ -134,10 +137,14 @@ multi_line_output = 3
explicit_package_bases = True
mypy_path = $MYPY_CONFIG_FILE_DIR/src
namespace_packages = True
strict = True
[mypy-ruamel]
[mypy-lz4.frame]
ignore_missing_imports = True
[mypy-ruamel.yaml]
implicit_reexport = True
[pydocstyle]
convention = google
add-select = D204,D400,D401,D404,D413
@ -146,9 +153,37 @@ add-select = D204,D400,D401,D404,D413
max-line-length = 100
[pylint.'MESSAGES CONTROL']
enable = all
disable =
duplicate-code,
ungrouped-imports,
# isort (pylint FAQ)
wrong-import-order,
# mccabe (pylint FAQ)
too-many-branches,
# fixme
fixme,
# pep8-naming (pylint FAQ, keep: invalid-name)
bad-classmethod-argument,
bad-mcs-classmethod-argument,
no-self-argument
# pycodestyle (pylint FAQ)
bad-indentation,
bare-except,
line-too-long,
missing-final-newline,
multiple-statements,
trailing-whitespace,
unnecessary-semicolon,
unneeded-not,
# pydocstyle (pylint FAQ)
missing-class-docstring,
missing-function-docstring,
missing-module-docstring,
# pyflakes (pylint FAQ)
undefined-variable,
unused-import,
unused-variable,
[yapf]
based_on_style = google
@ -159,15 +194,15 @@ indent_dictionary_value = false
[tool:pytest]
addopts =
-v
--flake8
--mypy
--pylint
--yapf
--cov=src
--cov-branch
--cov-report=html
--cov-report=term
--no-cov-on-fail
--junitxml=report.xml
-v
--flake8
--mypy
--pylint
--yapf
--cov=src
--cov-branch
--cov-report=html
--cov-report=term
--no-cov-on-fail
--junitxml=report.xml
junit_family=xunit2

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
from typing import Callable
def pathtype(exists: bool = True) -> Callable:
def pathtype(exists: bool = True) -> Callable[[str], Path]:
"""Path argument for argparse.
Args:

View File

@ -10,14 +10,15 @@ import os
import re
import struct
from bz2 import decompress as bz2_decompress
from collections import defaultdict
from enum import Enum, IntEnum
from functools import reduce
from io import BytesIO
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Any, Dict, NamedTuple
from lz4.frame import decompress as lz4_decompress # type: ignore
from lz4.frame import decompress as lz4_decompress
from rosbags.typesys.msg import normalize_msgtype
@ -59,7 +60,7 @@ class Connection(NamedTuple):
md5sum: str
callerid: Optional[str]
latching: Optional[int]
indexes: list
indexes: list[IndexData]
class ChunkInfo(NamedTuple):
@ -76,7 +77,7 @@ class Chunk(NamedTuple):
datasize: int
datapos: int
decompressor: Callable
decompressor: Callable[[bytes], bytes]
class TopicInfo(NamedTuple):
@ -124,9 +125,9 @@ class IndexData(NamedTuple):
return self.time != other[0]
deserialize_uint8 = struct.Struct('<B').unpack
deserialize_uint32 = struct.Struct('<L').unpack
deserialize_uint64 = struct.Struct('<Q').unpack
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack # type: ignore
deserialize_uint64: Callable[[bytes], tuple[int]] = struct.Struct('<Q').unpack # type: ignore
def deserialize_time(val: bytes) -> int:
@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int:
Deserialized value.
"""
sec, nsec = struct.unpack('<LL', val)
unpacked: tuple[int, int] = struct.unpack('<LL', val) # type: ignore
sec, nsec = unpacked
return sec * 10**9 + nsec
class Header(dict):
class Header(Dict[str, Any]):
"""Record header."""
def get_uint8(self, name: str) -> int:
@ -214,7 +216,9 @@ class Header(dict):
"""
try:
return self[name].decode()
value = self[name]
assert isinstance(value, bytes)
return value.decode()
except (KeyError, ValueError) as err:
raise ReaderError(f'Could not read string field {name!r}.') from err
@ -237,7 +241,7 @@ class Header(dict):
raise ReaderError(f'Could not read time field {name!r}.') from err
@classmethod
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
"""Read header from file handle.
Args:
@ -362,10 +366,10 @@ class Reader:
self.connections: dict[int, Connection] = {}
self.chunk_infos: list[ChunkInfo] = []
self.chunks: dict[int, Chunk] = {}
self.current_chunk = (-1, BytesIO())
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
self.topics: dict[str, TopicInfo] = {}
def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
"""Open rosbag and read metadata."""
try:
self.bio = self.path.open('rb')
@ -409,24 +413,25 @@ class Reader:
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
self.chunks = {}
indexes: dict[int, list[list[IndexData]]] = defaultdict(list)
for chunk_info in self.chunk_infos:
self.bio.seek(chunk_info.pos)
self.chunks[chunk_info.pos] = self.read_chunk()
for _ in range(len(chunk_info.connection_counts)):
cid, index = self.read_index_data(chunk_info.pos)
self.connections[cid].indexes.append(index)
indexes[cid].append(index)
for connection in self.connections.values():
connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
for cid, connection in self.connections.items():
connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time))
assert connection.indexes
self.topics = {}
for topic, connections in groupby(
for topic, group in groupby(
sorted(self.connections.values(), key=lambda x: x.topic),
key=lambda x: x.topic,
):
connections = list(connections)
connections = list(group)
count = reduce(
lambda x, y: x + y,
(
@ -446,7 +451,7 @@ class Reader:
self.close()
raise
def close(self):
def close(self) -> None:
"""Close rosbag."""
assert self.bio
self.bio.close()
@ -614,8 +619,8 @@ class Reader:
chunk_header = self.chunks[entry.chunk_pos]
self.bio.seek(chunk_header.datapos)
chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes))
chunk = self.current_chunk[1]
chunk.seek(entry.offset)

View File

@ -11,9 +11,9 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict
from lz4.frame import compress as lz4_compress # type: ignore
from lz4.frame import compress as lz4_compress
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
@ -21,7 +21,7 @@ from .reader import Connection, RecordType
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
from typing import BinaryIO, Callable, Literal, Optional, Type, Union
class WriterError(Exception):
@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes:
return struct.pack('<LL', sec, nsec)
class Header(dict):
class Header(Dict[str, Any]):
"""Record header."""
def set_uint32(self, name: str, value: int):
def set_uint32(self, name: str, value: int) -> None:
"""Set field to uint32 value.
Args:
@ -70,7 +70,7 @@ class Header(dict):
"""
self[name] = serialize_uint32(value)
def set_uint64(self, name: str, value: int):
def set_uint64(self, name: str, value: int) -> None:
"""Set field to uint64 value.
Args:
@ -80,7 +80,7 @@ class Header(dict):
"""
self[name] = serialize_uint64(value)
def set_string(self, name: str, value: str):
def set_string(self, name: str, value: str) -> None:
"""Set field to string value.
Args:
@ -90,7 +90,7 @@ class Header(dict):
"""
self[name] = value.encode()
def set_time(self, name: str, value: int):
def set_time(self, name: str, value: int) -> None:
"""Set field to time value.
Args:
@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
]
self.chunk_threshold = 1 * (1 << 20)
def set_compression(self, fmt: CompressionFormat):
def set_compression(self, fmt: CompressionFormat) -> None:
"""Enable compression on rosbag1.
This function has to be called before opening.
@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = fmt.name.lower()
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore
self.compressor = {
'bz2': bz2,
'lz4': lz4,
}[self.compression_format]
def open(self):
def open(self) -> None:
"""Open rosbag1 for writing."""
try:
self.bio = self.path.open('xb')
except FileExistsError:
raise WriterError(f'{self.path} exists already, not overwriting.') from None
assert self.bio
self.bio.write(b'#ROSBAG V2.0\n')
header = Header()
header.set_uint64('index_pos', 0)
@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.connections[connection.cid] = connection
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
"""Write message to rosbag1.
Args:
@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.write_chunk(chunk)
@staticmethod
def write_connection(connection: Connection, bio: BytesIO):
def write_connection(connection: Connection, bio: BinaryIO) -> None:
"""Write connection record."""
header = Header()
header.set_uint32('conn', connection.cid)
@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
header.set_string('latching', str(connection.latching))
header.write(bio)
def write_chunk(self, chunk: WriteChunk):
def write_chunk(self, chunk: WriteChunk) -> None:
"""Write open chunk to file."""
assert self.bio
@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
chunk.data.close()
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
def close(self):
def close(self) -> None:
"""Close rosbag1 after writing.
Closes open chunks and writes index.
"""
assert self.bio
for chunk in self.chunks:
if chunk.pos == -1:
self.write_chunk(chunk)

View File

@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML, YAMLError
from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
class ReaderError(Exception):
@ -25,7 +26,7 @@ class ReaderError(Exception):
@contextmanager
def decompress(path: Path, do_decompress: bool):
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
"""Transparent rosbag2 database decompression context.
This context manager will yield a path to the decompressed file contents.
@ -119,12 +120,12 @@ class Reader:
except KeyError as exc:
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
def open(self):
def open(self) -> None:
"""Open rosbag2."""
# Future storage formats will require file handles.
self.bio = True
def close(self):
def close(self) -> None:
"""Close rosbag2."""
# Future storage formats will require file handles.
assert self.bio
@ -133,12 +134,14 @@ class Reader:
@property
def duration(self) -> int:
"""Duration in nanoseconds between earliest and latest messages."""
return self.metadata['duration']['nanoseconds'] + 1
nsecs: int = self.metadata['duration']['nanoseconds']
return nsecs + 1
@property
def start_time(self) -> int:
"""Timestamp in nanoseconds of the earliest message."""
return self.metadata['starting_time']['nanoseconds_since_epoch']
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
return nsecs
@property
def end_time(self) -> int:
@ -148,7 +151,8 @@ class Reader:
@property
def message_count(self) -> int:
"""Total message count."""
return self.metadata['message_count']
count: int = self.metadata['message_count']
return count
@property
def compression_format(self) -> Optional[str]:
@ -233,7 +237,7 @@ class Reader:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
cur.execute('SELECT name,id FROM topics')
connmap: Dict[int, Connection] = {
connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
None) # type: ignore
for row in cur

View File

@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None
self.connections: dict[int, Connection] = {}
self.conn = None
self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None:
"""Enable compression on bag.
This function has to be called before opening.
@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = fmt.name.lower()
self.compressor = zstandard.ZstdCompressor()
def open(self):
def open(self) -> None:
"""Open rosbag2 for writing.
Create base directory and open database connection.
@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
"""Write message to rosbag2.
Args:
@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
)
connection.count += 1
def close(self):
def close(self) -> None:
"""Close rosbag2 after writing.
Closes open database transactions and writes metadata.yaml.
"""
assert self.cursor
assert self.conn
self.cursor.close()
self.cursor = None
@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.conn.close()
if self.compression_mode == 'file':
assert self.compressor
src = self.dbpath
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:

View File

@ -19,10 +19,10 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
from typing import Callable
from .typing import CDRDeser, CDRSer, CDRSerSize
def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
"""Generate cdr size calculation function.
Args:
@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
"""Generate cdr serialization function.
Args:
@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
return compile_lines(lines).serialize_cdr # type: ignore
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
"""Generate cdr deserialization function.
Args:

View File

@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef:
generate_serialize_cdr(fields, 'be'),
generate_deserialize_cdr(fields, 'le'),
generate_deserialize_cdr(fields, 'be'),
generate_ros1_to_cdr(fields, typename, False),
generate_ros1_to_cdr(fields, typename, True),
generate_cdr_to_ros1(fields, typename, False),
generate_cdr_to_ros1(fields, typename, True),
generate_ros1_to_cdr(fields, typename, False), # type: ignore
generate_ros1_to_cdr(fields, typename, True), # type: ignore
generate_cdr_to_ros1(fields, typename, False), # type: ignore
generate_cdr_to_ros1(fields, typename, True), # type: ignore
)
return MSGDEFCACHE[typename]

View File

@ -18,10 +18,16 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
from typing import Callable # pylint: disable=ungrouped-imports
from typing import Union # pylint: disable=ungrouped-imports
from .typing import Bitcvt, BitcvtSize
def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
def generate_ros1_to_cdr(
fields: list[Field],
typename: str,
copy: bool,
) -> Union[Bitcvt, BitcvtSize]:
"""Generate ROS1 to CDR conversion function.
Args:
@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call
aligned = anext
lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname)
return getattr(compile_lines(lines), funcname) # type: ignore
def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
def generate_cdr_to_ros1(
fields: list[Field],
typename: str,
copy: bool,
) -> Union[Bitcvt, BitcvtSize]:
"""Generate CDR to ROS1 conversion function.
Args:
@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call
aligned = anext
lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname)
return getattr(compile_lines(lines), funcname) # type: ignore

View File

@ -7,7 +7,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Any, Callable, List
from typing import Any, Callable, Tuple
Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]]
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
CDRSer = Callable[[bytes, int, type], int]
CDRSerSize = Callable[[int, type], int]
class Descriptor(NamedTuple):
@ -28,15 +35,15 @@ class Msgdef(NamedTuple):
"""Metadata of a message."""
name: str
fields: List[Field]
fields: list[Field]
cls: Any
size_cdr: int
getsize_cdr: Callable
serialize_cdr_le: Callable
serialize_cdr_be: Callable
deserialize_cdr_le: Callable
deserialize_cdr_be: Callable
getsize_ros1_to_cdr: Callable
ros1_to_cdr: Callable
getsize_cdr_to_ros1: Callable
cdr_to_ros1: Callable
getsize_cdr: CDRSerSize
serialize_cdr_le: CDRSer
serialize_cdr_be: CDRSer
deserialize_cdr_le: CDRDeser
deserialize_cdr_be: CDRDeser
getsize_ros1_to_cdr: BitcvtSize
ros1_to_cdr: Bitcvt
getsize_cdr_to_ros1: BitcvtSize
cdr_to_ros1: Bitcvt

View File

@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
pos = rule.skip_ws(text, 0)
npos, trees = rule.parse(text, pos)
assert npos == len(text), f'Could not parse: {text!r}'
return visitor.visit(trees)
return visitor.visit(trees) # type: ignore
except Exception as err: # pylint: disable=broad-except
raise TypesysError(f'Could not parse: {text!r}') from err

View File

@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
RULES = parse_grammar(GRAMMAR_IDL)
def __init__(self):
def __init__(self) -> None:
"""Initialize."""
super().__init__()
self.typedefs = {}
self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {}
def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol, return only children of modules."""

View File

@ -50,7 +50,7 @@ class Rule:
}
return data
def parse(self, text: str, pos: int):
def parse(self, text: str, pos: int) -> tuple[int, Any]:
"""Apply rule at position."""
raise NotImplementedError # pragma: no cover
@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods
RULES: dict[str, Rule] = {}
def __init__(self):
def __init__(self) -> None:
"""Initialize."""
def visit(self, tree: Any) -> Any:

View File

@ -13,12 +13,14 @@ from . import types
from .base import Nodetype, TypesysError
if TYPE_CHECKING:
from typing import Any, Optional, Union
from .base import Typesdict
INTLIKE = re.compile('^u?(bool|int|float)')
def get_typehint(desc: tuple) -> str:
def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str:
"""Get python type hint for field.
Args:
@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str:
"""
if desc[0] == Nodetype.BASE:
if match := INTLIKE.match(desc[1]):
if match := INTLIKE.match(desc[1]): # type: ignore
return match.group(1)
return 'str'
if desc[0] == Nodetype.NAME:
assert isinstance(desc[1], str)
return desc[1].replace('/', '__')
sub = desc[1][0]
if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
return f'list[{get_typehint(sub)}]'
return f'list[{get_typehint(sub)}]' # type: ignore
def generate_python_code(typs: Typesdict) -> str:
@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str:
'',
]
def get_ftype(ftype: tuple) -> tuple:
def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]:
if ftype[0] <= 2:
return int(ftype[0]), ftype[1]
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])

View File

@ -9,6 +9,7 @@ from struct import Struct, pack_into, unpack_from
from typing import TYPE_CHECKING, Dict, List, Union, cast
import numpy
from numpy.typing import NDArray
from rosbags.serde.messages import SerdeError, get_msgdef
from rosbags.serde.typing import Msgdef
@ -116,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
ndarr = ndarr.byteswap() # no inplace on readonly array
return ndarr, pos + num * SIZEMAP[desc.args]
@ -278,7 +279,7 @@ def serialize_array(
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
size *= len(val)
val = cast(numpy.ndarray, val)
val = cast(NDArray[numpy.int_], val)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
val = val.byteswap() # no inplace on readonly array
rawdata[pos:pos + size] = memoryview(val.tobytes())

View File

@ -15,7 +15,7 @@ from rosbags.rosbag1 import ReaderError
from rosbags.rosbag2 import WriterError
def test_cliwrapper(tmp_path: Path):
def test_cliwrapper(tmp_path: Path) -> None:
"""Test cli wrapper."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'ros1.bag').write_text('')
@ -62,7 +62,7 @@ def test_cliwrapper(tmp_path: Path):
mock_print.assert_called_with('ERROR: exc')
def test_convert(tmp_path: Path):
def test_convert(tmp_path: Path) -> None:
"""Test conversion function."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'foo.bag').write_text('')

View File

@ -142,13 +142,13 @@ module test_msgs {
"""
def test_parse_empty_msg():
def test_parse_empty_msg() -> None:
"""Test msg parser with empty message."""
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
assert ret == {'std_msgs/msg/Empty': ([], [])}
def test_parse_bounds_msg():
def test_parse_bounds_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
assert ret == {
@ -168,7 +168,7 @@ def test_parse_bounds_msg():
}
def test_parse_defaults_msg():
def test_parse_defaults_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
assert ret == {
@ -188,7 +188,7 @@ def test_parse_defaults_msg():
}
def test_parse_msg():
def test_parse_msg() -> None:
"""Test msg parser."""
with pytest.raises(TypesysError, match='Could not parse'):
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
@ -208,7 +208,7 @@ def test_parse_msg():
assert fields[6][1][0] == Nodetype.ARRAY
def test_parse_multi_msg():
def test_parse_multi_msg() -> None:
"""Test multi msg parser."""
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 3
@ -223,7 +223,7 @@ def test_parse_multi_msg():
assert consts == [('static', 'uint32', 42)]
def test_parse_cstring_confusion():
def test_parse_cstring_confusion() -> None:
"""Test if msg separator is confused with const string."""
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 2
@ -235,7 +235,7 @@ def test_parse_cstring_confusion():
assert fields[1][1][1] == 'string'
def test_parse_relative_siblings_msg():
def test_parse_relative_siblings_msg() -> None:
"""Test relative siblings with msg parser."""
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
@ -246,7 +246,7 @@ def test_parse_relative_siblings_msg():
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
def test_parse_idl():
def test_parse_idl() -> None:
"""Test idl parser."""
ret = get_types_from_idl(IDL_LANG)
assert ret == {}
@ -267,21 +267,21 @@ def test_parse_idl():
assert fields[6][1][0] == Nodetype.ARRAY
def test_register_types():
def test_register_types() -> None:
"""Test type registeration."""
assert 'foo' not in FIELDDEFS
register_types({})
register_types({'foo': [[], [('b', (1, 'bool'))]]})
register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
assert 'foo' in FIELDDEFS
register_types({'std_msgs/msg/Header': [[], []]})
register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
with pytest.raises(TypesysError, match='different definition'):
register_types({'foo': [[], [('x', (1, 'bool'))]]})
register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
def test_generate_msgdef():
def test_generate_msgdef() -> None:
"""Test message definition generator."""
res = generate_msgdef('std_msgs/msg/Header')
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')

View File

@ -117,7 +117,7 @@ def bag(request: SubRequest, tmp_path: Path) -> Path:
return tmp_path
def test_reader(bag: Path):
def test_reader(bag: Path) -> None:
"""Test reader and deserializer on simple bag."""
with Reader(bag) as reader:
assert reader.duration == 43
@ -151,7 +151,7 @@ def test_reader(bag: Path):
next(gen)
def test_message_filters(bag: Path):
def test_message_filters(bag: Path) -> None:
"""Test reader filters messages."""
with Reader(bag) as reader:
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
@ -188,14 +188,14 @@ def test_message_filters(bag: Path):
next(gen)
def test_user_errors(bag: Path):
def test_user_errors(bag: Path) -> None:
"""Test user errors."""
reader = Reader(bag)
with pytest.raises(ReaderError, match='Rosbag is not open'):
next(reader.messages())
def test_failure_cases(tmp_path: Path):
def test_failure_cases(tmp_path: Path) -> None:
"""Test bags with broken fs layout."""
with pytest.raises(ReaderError, match='not read metadata'):
Reader(tmp_path)

View File

@ -2,8 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
from __future__ import annotations
from collections import defaultdict
from struct import pack
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
@ -11,8 +14,12 @@ import pytest
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag1.reader import IndexData
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Sequence, Union
def ser(data):
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
"""Serialize record header."""
if isinstance(data, dict):
fields = []
@ -23,7 +30,7 @@ def ser(data):
return pack('<L', len(data)) + data
def create_default_header():
def create_default_header() -> dict[str, bytes]:
"""Create empty rosbag header."""
return {
'op': b'\x03',
@ -32,7 +39,11 @@ def create_default_header():
}
def create_connection(cid=1, topic=0, typ=0):
def create_connection(
cid: int = 1,
topic: int = 0,
typ: int = 0,
) -> tuple[dict[str, bytes], dict[str, bytes]]:
"""Create connection record."""
return {
'op': b'\x07',
@ -45,7 +56,11 @@ def create_connection(cid=1, topic=0, typ=0):
}
def create_message(cid=1, time=0, msg=0):
def create_message(
cid: int = 1,
time: int = 0,
msg: int = 0,
) -> tuple[dict[str, Union[bytes, int]], bytes]:
"""Create message record."""
return {
'op': b'\x02',
@ -54,7 +69,12 @@ def create_message(cid=1, time=0, msg=0):
}, f'MSGCONTENT{msg}'.encode()
def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements
def write_bag( # pylint: disable=too-many-locals,too-many-statements
bag: Path,
header: dict[str, bytes],
chunks: Sequence[Any] = (),
) -> None:
"""Write bag file."""
magic = b'#ROSBAG V2.0\n'
@ -70,7 +90,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
chunk_bytes = b''
start_time = 2**32 - 1
end_time = 0
counts = defaultdict(int)
counts: dict[int, int] = defaultdict(int)
index = {}
offset = 0
@ -95,8 +115,8 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
'count': 0,
'msgs': b'',
}
index[conn]['count'] += 1
index[conn]['msgs'] += pack('<LLL', time, 0, offset)
index[conn]['count'] += 1 # type: ignore
index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
add = ser(head) + ser(data)
chunk_bytes += add
@ -140,19 +160,19 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
if 'index_pos' not in header:
header['index_pos'] = pack('<Q', pos)
header = ser(header)
header += b'\x20' * (4096 - len(header))
header_bytes = ser(header)
header_bytes += b'\x20' * (4096 - len(header_bytes))
bag.write_bytes(b''.join([
magic,
header,
header_bytes,
chunks_bytes,
connections,
chunkinfos,
]))
def test_indexdata():
def test_indexdata() -> None:
"""Test IndexData sort sorder."""
x42_1_0 = IndexData(42, 1, 0)
x42_2_0 = IndexData(42, 2, 0)
@ -175,7 +195,7 @@ def test_indexdata():
assert not x42_1_0 > x43_3_0
def test_reader(tmp_path): # pylint: disable=too-many-statements
def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test reader and deserializer on simple bag."""
# empty bag
bag = tmp_path / 'test.bag'
@ -268,7 +288,7 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert msgs[0][2] == b'MSGCONTENT5'
def test_user_errors(tmp_path):
def test_user_errors(tmp_path: Path) -> None:
"""Test user errors."""
bag = tmp_path / 'test.bag'
write_bag(bag, create_default_header(), chunks=[[
@ -281,7 +301,7 @@ def test_user_errors(tmp_path):
next(reader.messages())
def test_failure_cases(tmp_path): # pylint: disable=too-many-statements
def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test failure cases."""
bag = tmp_path / 'test.bag'
with pytest.raises(ReaderError, match='does not exist'):

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods

View File

@ -18,7 +18,7 @@ from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__
from .cdr import deserialize, serialize
if TYPE_CHECKING:
from typing import Any, Tuple, Union
from typing import Any, Generator, Union
MSG_POLY = (
(
@ -169,7 +169,7 @@ test_msgs/msg/dynamic_s_64[] seq_msg_ds6
@pytest.fixture()
def _comparable():
def _comparable() -> Generator[None, None, None]:
"""Make messages containing numpy arrays comparable.
Notes:
@ -180,7 +180,7 @@ def _comparable():
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
lhs = self._mock_wraps # pylint: disable=protected-access
rhs = getattr(other, '_mock_wraps', other)
return (lhs == rhs).all()
return (lhs == rhs).all() # type: ignore
class CNDArray(MagicMock):
"""Mock ndarray."""
@ -194,14 +194,14 @@ def _comparable():
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
return CNDArray(wraps=frombuffer(*args, **kwargs))
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
yield
@pytest.mark.parametrize('message', MESSAGES)
def test_serde(message: Tuple[bytes, str, bool]):
def test_serde(message: tuple[bytes, str, bool]) -> None:
"""Test serialization deserialization roundtrip."""
rawdata, typ, is_little = message
@ -213,7 +213,7 @@ def test_serde(message: Tuple[bytes, str, bool]):
@pytest.mark.usefixtures('_comparable')
def test_deserializer():
def test_deserializer() -> None:
"""Test deserializer."""
msg = deserialize_cdr(*MSG_POLY[:2])
assert msg == deserialize(*MSG_POLY[:2])
@ -233,7 +233,8 @@ def test_deserializer():
assert msg.header.frame_id == 'foo42'
field = msg.magnetic_field
assert (field.x, field.y, field.z) == (128., 128., 128.)
assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all()
diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) # type: ignore
assert (diag == [1., 1., 1.]).all()
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
@ -241,7 +242,7 @@ def test_deserializer():
@pytest.mark.usefixtures('_comparable')
def test_serializer():
def test_serializer() -> None:
"""Test serializer."""
class Foo: # pylint: disable=too-few-public-methods
@ -268,7 +269,7 @@ def test_serializer():
@pytest.mark.usefixtures('_comparable')
def test_serializer_errors():
def test_serializer_errors() -> None:
"""Test seralizer with broken messages."""
class Foo: # pylint: disable=too-few-public-methods
@ -286,7 +287,7 @@ def test_serializer_errors():
@pytest.mark.usefixtures('_comparable')
def test_custom_type():
def test_custom_type() -> None:
"""Test custom type."""
cname = 'test_msgs/msg/custom'
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
@ -362,7 +363,7 @@ def test_custom_type():
assert res == msg
def test_ros1_to_cdr():
def test_ros1_to_cdr() -> None:
"""Test ROS1 to CDR conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
@ -385,7 +386,7 @@ def test_ros1_to_cdr():
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr
def test_cdr_to_ros1():
def test_cdr_to_ros1() -> None:
"""Test CDR to ROS1 conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
from pathlib import Path
def test_writer(tmp_path: Path):
def test_writer(tmp_path: Path) -> None:
"""Test Writer."""
path = (tmp_path / 'rosbag2')
with Writer(path) as bag:
@ -60,7 +60,7 @@ def test_writer(tmp_path: Path):
assert size > (path / 'compress_message.db3').stat().st_size
def test_failure_cases(tmp_path: Path):
def test_failure_cases(tmp_path: Path) -> None:
"""Test writer failure cases."""
with pytest.raises(WriterError, match='exists'):
Writer(tmp_path)

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
from typing import Optional
def test_no_overwrite(tmp_path: Path):
def test_no_overwrite(tmp_path: Path) -> None:
"""Test writer does not touch existing files."""
path = tmp_path / 'test.bag'
path.write_text('foo')
@ -30,7 +30,7 @@ def test_no_overwrite(tmp_path: Path):
writer.open()
def test_empty(tmp_path: Path):
def test_empty(tmp_path: Path) -> None:
"""Test empty bag."""
path = tmp_path / 'test.bag'
@ -40,7 +40,7 @@ def test_empty(tmp_path: Path):
assert len(data) == 13 + 4096
def test_add_connection(tmp_path: Path):
def test_add_connection(tmp_path: Path) -> None:
"""Test adding of connections."""
path = tmp_path / 'test.bag'
@ -88,7 +88,7 @@ def test_add_connection(tmp_path: Path):
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
def test_write_errors(tmp_path: Path):
def test_write_errors(tmp_path: Path) -> None:
"""Test write errors."""
path = tmp_path / 'test.bag'
@ -101,7 +101,7 @@ def test_write_errors(tmp_path: Path):
path.unlink()
def test_write_simple(tmp_path: Path):
def test_write_simple(tmp_path: Path) -> None:
"""Test writing of messages."""
path = tmp_path / 'test.bag'
@ -179,7 +179,7 @@ def test_write_simple(tmp_path: Path):
path.unlink()
def test_compression_errors(tmp_path: Path):
def test_compression_errors(tmp_path: Path) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
with Writer(path) as writer, \
@ -188,7 +188,7 @@ def test_compression_errors(tmp_path: Path):
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
writer = Writer(path)

View File

@ -21,7 +21,7 @@ from rosbags.rosbag2 import Reader
from rosbags.serde import deserialize_cdr
if TYPE_CHECKING:
from typing import Any
from typing import Any, Generator
class ReaderPy: # pylint: disable=too-few-public-methods
@ -35,7 +35,7 @@ class ReaderPy: # pylint: disable=too-few-public-methods
self.reader.open(soptions, coptions)
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
def messages(self):
def messages(self) -> Generator[tuple[str, str, int, bytes], None, None]:
"""Expose rosbag2 like generator behavior."""
while self.reader.has_next():
topic, data, timestamp = self.reader.read_next()
@ -48,7 +48,7 @@ def deserialize_py(data: bytes, msgtype: str) -> Any:
return deserialize_message(data, pytype)
def compare_msg(lite: Any, native: Any):
def compare_msg(lite: Any, native: Any) -> None:
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
Args:
@ -79,7 +79,7 @@ def compare_msg(lite: Any, native: Any):
assert native_val == lite_val, f'{fieldname}: {native_val} != {lite_val}'
def compare(path: Path):
def compare(path: Path) -> None:
"""Compare raw and deserialized messages."""
with Reader(path) as reader:
gens = (reader.messages(), ReaderPy(path).messages())
@ -100,7 +100,7 @@ def compare(path: Path):
assert len(list(gens[1])) == 0
def read_deser_rosbag2_py(path: Path):
def read_deser_rosbag2_py(path: Path) -> None:
"""Read testbag with rosbag2_py."""
soptions = StorageOptions(str(path), 'sqlite3')
coptions = ConverterOptions('', '')
@ -115,14 +115,14 @@ def read_deser_rosbag2_py(path: Path):
deserialize_message(rawdata, pytype)
def read_deser_rosbag2(path: Path):
def read_deser_rosbag2(path: Path) -> None:
"""Read testbag with rosbag2lite."""
with Reader(path) as reader:
for connection, _, data in reader.messages():
deserialize_cdr(data, connection.msgtype)
def main():
def main() -> None:
"""Benchmark rosbag2 against rosbag2_py."""
path = Path(sys.argv[1])
try:

View File

@ -25,7 +25,7 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
if TYPE_CHECKING:
from typing import Any, List, Union
from typing import Any, Generator, List, Union
from rosbag.bag import _Connection_Info
@ -39,7 +39,7 @@ class Reader: # pylint: disable=too-few-public-methods
self.reader.open(StorageOptions(path, 'sqlite3'), ConverterOptions('', ''))
self.typemap = {x.name: x.type for x in self.reader.get_all_topics_and_types()}
def messages(self):
def messages(self) -> Generator[tuple[str, int, bytes], None, None]:
"""Expose rosbag2 like generator behavior."""
while self.reader.has_next():
topic, data, timestamp = self.reader.read_next()
@ -47,7 +47,7 @@ class Reader: # pylint: disable=too-few-public-methods
yield topic, timestamp, deserialize_message(data, pytype)
def fixup_ros1(conns: List[_Connection_Info]):
def fixup_ros1(conns: List[_Connection_Info]) -> None:
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
Args:
@ -69,16 +69,13 @@ def fixup_ros1(conns: List[_Connection_Info]):
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
def compare(ref: Any, msg: Any):
def compare(ref: Any, msg: Any) -> None:
"""Compare message to its reference.
Args:
ref: Reference ROS1 message.
msg: Converted ROS2 message.
Return:
True if messages are identical.
"""
if hasattr(msg, 'get_fields_and_field_types'):
for name in msg.get_fields_and_field_types():
@ -107,7 +104,7 @@ def compare(ref: Any, msg: Any):
assert ref == msg
def main_bag1_bag1(path1: Path, path2: Path):
def main_bag1_bag1(path1: Path, path2: Path) -> None:
"""Compare rosbag1 to rosbag1 message by message.
Args:
@ -132,7 +129,7 @@ def main_bag1_bag1(path1: Path, path2: Path):
print('Bags are identical.') # noqa: T001
def main_bag1_bag2(path1: Path, path2: Path):
def main_bag1_bag2(path1: Path, path2: Path) -> None:
"""Compare rosbag1 to rosbag2 message by message.
Args: