Type generics and missing return types

This commit is contained in:
Marko Durkovic
2021-11-25 14:26:17 +01:00
parent ac704bd890
commit 52480e2bad
26 changed files with 263 additions and 175 deletions
+1 -1
View File
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from typing import Callable
def pathtype(exists: bool = True) -> Callable:
def pathtype(exists: bool = True) -> Callable[[str], Path]:
"""Path argument for argparse.
Args:
+26 -21
View File
@@ -10,14 +10,15 @@ import os
import re
import struct
from bz2 import decompress as bz2_decompress
from collections import defaultdict
from enum import Enum, IntEnum
from functools import reduce
from io import BytesIO
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Any, Dict, NamedTuple
from lz4.frame import decompress as lz4_decompress # type: ignore
from lz4.frame import decompress as lz4_decompress
from rosbags.typesys.msg import normalize_msgtype
@@ -59,7 +60,7 @@ class Connection(NamedTuple):
md5sum: str
callerid: Optional[str]
latching: Optional[int]
indexes: list
indexes: list[IndexData]
class ChunkInfo(NamedTuple):
@@ -76,7 +77,7 @@ class Chunk(NamedTuple):
datasize: int
datapos: int
decompressor: Callable
decompressor: Callable[[bytes], bytes]
class TopicInfo(NamedTuple):
@@ -124,9 +125,9 @@ class IndexData(NamedTuple):
return self.time != other[0]
deserialize_uint8 = struct.Struct('<B').unpack
deserialize_uint32 = struct.Struct('<L').unpack
deserialize_uint64 = struct.Struct('<Q').unpack
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack # type: ignore
deserialize_uint64: Callable[[bytes], tuple[int]] = struct.Struct('<Q').unpack # type: ignore
def deserialize_time(val: bytes) -> int:
@@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int:
Deserialized value.
"""
sec, nsec = struct.unpack('<LL', val)
unpacked: tuple[int, int] = struct.unpack('<LL', val) # type: ignore
sec, nsec = unpacked
return sec * 10**9 + nsec
class Header(dict):
class Header(Dict[str, Any]):
"""Record header."""
def get_uint8(self, name: str) -> int:
@@ -214,7 +216,9 @@ class Header(dict):
"""
try:
return self[name].decode()
value = self[name]
assert isinstance(value, bytes)
return value.decode()
except (KeyError, ValueError) as err:
raise ReaderError(f'Could not read string field {name!r}.') from err
@@ -237,7 +241,7 @@ class Header(dict):
raise ReaderError(f'Could not read time field {name!r}.') from err
@classmethod
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
"""Read header from file handle.
Args:
@@ -362,10 +366,10 @@ class Reader:
self.connections: dict[int, Connection] = {}
self.chunk_infos: list[ChunkInfo] = []
self.chunks: dict[int, Chunk] = {}
self.current_chunk = (-1, BytesIO())
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
self.topics: dict[str, TopicInfo] = {}
def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
"""Open rosbag and read metadata."""
try:
self.bio = self.path.open('rb')
@@ -409,24 +413,25 @@ class Reader:
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
self.chunks = {}
indexes: dict[int, list[list[IndexData]]] = defaultdict(list)
for chunk_info in self.chunk_infos:
self.bio.seek(chunk_info.pos)
self.chunks[chunk_info.pos] = self.read_chunk()
for _ in range(len(chunk_info.connection_counts)):
cid, index = self.read_index_data(chunk_info.pos)
self.connections[cid].indexes.append(index)
indexes[cid].append(index)
for connection in self.connections.values():
connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
for cid, connection in self.connections.items():
connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time))
assert connection.indexes
self.topics = {}
for topic, connections in groupby(
for topic, group in groupby(
sorted(self.connections.values(), key=lambda x: x.topic),
key=lambda x: x.topic,
):
connections = list(connections)
connections = list(group)
count = reduce(
lambda x, y: x + y,
(
@@ -446,7 +451,7 @@ class Reader:
self.close()
raise
def close(self):
def close(self) -> None:
"""Close rosbag."""
assert self.bio
self.bio.close()
@@ -614,8 +619,8 @@ class Reader:
chunk_header = self.chunks[entry.chunk_pos]
self.bio.seek(chunk_header.datapos)
chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes))
chunk = self.current_chunk[1]
chunk.seek(entry.offset)
+18 -16
View File
@@ -11,9 +11,9 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict
from lz4.frame import compress as lz4_compress # type: ignore
from lz4.frame import compress as lz4_compress
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
@@ -21,7 +21,7 @@ from .reader import Connection, RecordType
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
from typing import BinaryIO, Callable, Literal, Optional, Type, Union
class WriterError(Exception):
@@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes:
return struct.pack('<LL', sec, nsec)
class Header(dict):
class Header(Dict[str, Any]):
"""Record header."""
def set_uint32(self, name: str, value: int):
def set_uint32(self, name: str, value: int) -> None:
"""Set field to uint32 value.
Args:
@@ -70,7 +70,7 @@ class Header(dict):
"""
self[name] = serialize_uint32(value)
def set_uint64(self, name: str, value: int):
def set_uint64(self, name: str, value: int) -> None:
"""Set field to uint64 value.
Args:
@@ -80,7 +80,7 @@ class Header(dict):
"""
self[name] = serialize_uint64(value)
def set_string(self, name: str, value: str):
def set_string(self, name: str, value: str) -> None:
"""Set field to string value.
Args:
@@ -90,7 +90,7 @@ class Header(dict):
"""
self[name] = value.encode()
def set_time(self, name: str, value: int):
def set_time(self, name: str, value: int) -> None:
"""Set field to time value.
Args:
@@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
]
self.chunk_threshold = 1 * (1 << 20)
def set_compression(self, fmt: CompressionFormat):
def set_compression(self, fmt: CompressionFormat) -> None:
"""Enable compression on rosbag1.
This function has to be called before opening.
@@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = fmt.name.lower()
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore
self.compressor = {
'bz2': bz2,
'lz4': lz4,
}[self.compression_format]
def open(self):
def open(self) -> None:
"""Open rosbag1 for writing."""
try:
self.bio = self.path.open('xb')
except FileExistsError:
raise WriterError(f'{self.path} exists already, not overwriting.') from None
assert self.bio
self.bio.write(b'#ROSBAG V2.0\n')
header = Header()
header.set_uint64('index_pos', 0)
@@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.connections[connection.cid] = connection
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
"""Write message to rosbag1.
Args:
@@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.write_chunk(chunk)
@staticmethod
def write_connection(connection: Connection, bio: BytesIO):
def write_connection(connection: Connection, bio: BinaryIO) -> None:
"""Write connection record."""
header = Header()
header.set_uint32('conn', connection.cid)
@@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
header.set_string('latching', str(connection.latching))
header.write(bio)
def write_chunk(self, chunk: WriteChunk):
def write_chunk(self, chunk: WriteChunk) -> None:
"""Write open chunk to file."""
assert self.bio
@@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
chunk.data.close()
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
def close(self):
def close(self) -> None:
"""Close rosbag1 after writing.
Closes open chunks and writes index.
"""
assert self.bio
for chunk in self.chunks:
if chunk.pos == -1:
self.write_chunk(chunk)
+13 -9
View File
@@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML, YAMLError
from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
class ReaderError(Exception):
@@ -25,7 +26,7 @@ class ReaderError(Exception):
@contextmanager
def decompress(path: Path, do_decompress: bool):
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
"""Transparent rosbag2 database decompression context.
This context manager will yield a path to the decompressed file contents.
@@ -119,12 +120,12 @@ class Reader:
except KeyError as exc:
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
def open(self):
def open(self) -> None:
"""Open rosbag2."""
# Future storage formats will require file handles.
self.bio = True
def close(self):
def close(self) -> None:
"""Close rosbag2."""
# Future storage formats will require file handles.
assert self.bio
@@ -133,12 +134,14 @@ class Reader:
@property
def duration(self) -> int:
"""Duration in nanoseconds between earliest and latest messages."""
return self.metadata['duration']['nanoseconds'] + 1
nsecs: int = self.metadata['duration']['nanoseconds']
return nsecs + 1
@property
def start_time(self) -> int:
"""Timestamp in nanoseconds of the earliest message."""
return self.metadata['starting_time']['nanoseconds_since_epoch']
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
return nsecs
@property
def end_time(self) -> int:
@@ -148,7 +151,8 @@ class Reader:
@property
def message_count(self) -> int:
"""Total message count."""
return self.metadata['message_count']
count: int = self.metadata['message_count']
return count
@property
def compression_format(self) -> Optional[str]:
@@ -233,7 +237,7 @@ class Reader:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
cur.execute('SELECT name,id FROM topics')
connmap: Dict[int, Connection] = {
connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
None) # type: ignore
for row in cur
+8 -5
View File
@@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None
self.connections: dict[int, Connection] = {}
self.conn = None
self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None:
"""Enable compression on bag.
This function has to be called before opening.
@@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_format = fmt.name.lower()
self.compressor = zstandard.ZstdCompressor()
def open(self):
def open(self) -> None:
"""Open rosbag2 for writing.
Create base directory and open database connection.
@@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
"""Write message to rosbag2.
Args:
@@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
)
connection.count += 1
def close(self):
def close(self) -> None:
"""Close rosbag2 after writing.
Closes open database transactions and writes metadata.yaml.
"""
assert self.cursor
assert self.conn
self.cursor.close()
self.cursor = None
@@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.conn.close()
if self.compression_mode == 'file':
assert self.compressor
src = self.dbpath
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:
+4 -4
View File
@@ -19,10 +19,10 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
from typing import Callable
from .typing import CDRDeser, CDRSer, CDRSerSize
def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
"""Generate cdr size calculation function.
Args:
@@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
"""Generate cdr serialization function.
Args:
@@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
return compile_lines(lines).serialize_cdr # type: ignore
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
"""Generate cdr deserialization function.
Args:
+4 -4
View File
@@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef:
generate_serialize_cdr(fields, 'be'),
generate_deserialize_cdr(fields, 'le'),
generate_deserialize_cdr(fields, 'be'),
generate_ros1_to_cdr(fields, typename, False),
generate_ros1_to_cdr(fields, typename, True),
generate_cdr_to_ros1(fields, typename, False),
generate_cdr_to_ros1(fields, typename, True),
generate_ros1_to_cdr(fields, typename, False), # type: ignore
generate_ros1_to_cdr(fields, typename, True), # type: ignore
generate_cdr_to_ros1(fields, typename, False), # type: ignore
generate_cdr_to_ros1(fields, typename, True), # type: ignore
)
return MSGDEFCACHE[typename]
+15 -5
View File
@@ -18,10 +18,16 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
from typing import Callable # pylint: disable=ungrouped-imports
from typing import Union # pylint: disable=ungrouped-imports
from .typing import Bitcvt, BitcvtSize
def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
def generate_ros1_to_cdr(
fields: list[Field],
typename: str,
copy: bool,
) -> Union[Bitcvt, BitcvtSize]:
"""Generate ROS1 to CDR conversion function.
Args:
@@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call
aligned = anext
lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname)
return getattr(compile_lines(lines), funcname) # type: ignore
def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
def generate_cdr_to_ros1(
fields: list[Field],
typename: str,
copy: bool,
) -> Union[Bitcvt, BitcvtSize]:
"""Generate CDR to ROS1 conversion function.
Args:
@@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call
aligned = anext
lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname)
return getattr(compile_lines(lines), funcname) # type: ignore
+18 -11
View File
@@ -7,7 +7,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Any, Callable, List
from typing import Any, Callable, Tuple
Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]]
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
CDRSer = Callable[[bytes, int, type], int]
CDRSerSize = Callable[[int, type], int]
class Descriptor(NamedTuple):
@@ -28,15 +35,15 @@ class Msgdef(NamedTuple):
"""Metadata of a message."""
name: str
fields: List[Field]
fields: list[Field]
cls: Any
size_cdr: int
getsize_cdr: Callable
serialize_cdr_le: Callable
serialize_cdr_be: Callable
deserialize_cdr_le: Callable
deserialize_cdr_be: Callable
getsize_ros1_to_cdr: Callable
ros1_to_cdr: Callable
getsize_cdr_to_ros1: Callable
cdr_to_ros1: Callable
getsize_cdr: CDRSerSize
serialize_cdr_le: CDRSer
serialize_cdr_be: CDRSer
deserialize_cdr_le: CDRDeser
deserialize_cdr_be: CDRDeser
getsize_ros1_to_cdr: BitcvtSize
ros1_to_cdr: Bitcvt
getsize_cdr_to_ros1: BitcvtSize
cdr_to_ros1: Bitcvt
+1 -1
View File
@@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
pos = rule.skip_ws(text, 0)
npos, trees = rule.parse(text, pos)
assert npos == len(text), f'Could not parse: {text!r}'
return visitor.visit(trees)
return visitor.visit(trees) # type: ignore
except Exception as err: # pylint: disable=broad-except
raise TypesysError(f'Could not parse: {text!r}') from err
+2 -2
View File
@@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
RULES = parse_grammar(GRAMMAR_IDL)
def __init__(self):
def __init__(self) -> None:
"""Initialize."""
super().__init__()
self.typedefs = {}
self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {}
def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol, return only children of modules."""
+2 -2
View File
@@ -50,7 +50,7 @@ class Rule:
}
return data
def parse(self, text: str, pos: int):
def parse(self, text: str, pos: int) -> tuple[int, Any]:
"""Apply rule at position."""
raise NotImplementedError # pragma: no cover
@@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods
RULES: dict[str, Rule] = {}
def __init__(self):
def __init__(self) -> None:
"""Initialize."""
def visit(self, tree: Any) -> Any:
+7 -4
View File
@@ -13,12 +13,14 @@ from . import types
from .base import Nodetype, TypesysError
if TYPE_CHECKING:
from typing import Any, Optional, Union
from .base import Typesdict
INTLIKE = re.compile('^u?(bool|int|float)')
def get_typehint(desc: tuple) -> str:
def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str:
"""Get python type hint for field.
Args:
@@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str:
"""
if desc[0] == Nodetype.BASE:
if match := INTLIKE.match(desc[1]):
if match := INTLIKE.match(desc[1]): # type: ignore
return match.group(1)
return 'str'
if desc[0] == Nodetype.NAME:
assert isinstance(desc[1], str)
return desc[1].replace('/', '__')
sub = desc[1][0]
if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
return f'list[{get_typehint(sub)}]'
return f'list[{get_typehint(sub)}]' # type: ignore
def generate_python_code(typs: Typesdict) -> str:
@@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str:
'',
]
def get_ftype(ftype: tuple) -> tuple:
def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]:
if ftype[0] <= 2:
return int(ftype[0]), ftype[1]
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])