Type generics and missing return types
This commit is contained in:
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def pathtype(exists: bool = True) -> Callable:
|
||||
def pathtype(exists: bool = True) -> Callable[[str], Path]:
|
||||
"""Path argument for argparse.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -10,14 +10,15 @@ import os
|
||||
import re
|
||||
import struct
|
||||
from bz2 import decompress as bz2_decompress
|
||||
from collections import defaultdict
|
||||
from enum import Enum, IntEnum
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, NamedTuple
|
||||
|
||||
from lz4.frame import decompress as lz4_decompress # type: ignore
|
||||
from lz4.frame import decompress as lz4_decompress
|
||||
|
||||
from rosbags.typesys.msg import normalize_msgtype
|
||||
|
||||
@@ -59,7 +60,7 @@ class Connection(NamedTuple):
|
||||
md5sum: str
|
||||
callerid: Optional[str]
|
||||
latching: Optional[int]
|
||||
indexes: list
|
||||
indexes: list[IndexData]
|
||||
|
||||
|
||||
class ChunkInfo(NamedTuple):
|
||||
@@ -76,7 +77,7 @@ class Chunk(NamedTuple):
|
||||
|
||||
datasize: int
|
||||
datapos: int
|
||||
decompressor: Callable
|
||||
decompressor: Callable[[bytes], bytes]
|
||||
|
||||
|
||||
class TopicInfo(NamedTuple):
|
||||
@@ -124,9 +125,9 @@ class IndexData(NamedTuple):
|
||||
return self.time != other[0]
|
||||
|
||||
|
||||
deserialize_uint8 = struct.Struct('<B').unpack
|
||||
deserialize_uint32 = struct.Struct('<L').unpack
|
||||
deserialize_uint64 = struct.Struct('<Q').unpack
|
||||
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
|
||||
deserialize_uint32: Callable[[bytes], tuple[int]] = struct.Struct('<L').unpack # type: ignore
|
||||
deserialize_uint64: Callable[[bytes], tuple[int]] = struct.Struct('<Q').unpack # type: ignore
|
||||
|
||||
|
||||
def deserialize_time(val: bytes) -> int:
|
||||
@@ -139,11 +140,12 @@ def deserialize_time(val: bytes) -> int:
|
||||
Deserialized value.
|
||||
|
||||
"""
|
||||
sec, nsec = struct.unpack('<LL', val)
|
||||
unpacked: tuple[int, int] = struct.unpack('<LL', val) # type: ignore
|
||||
sec, nsec = unpacked
|
||||
return sec * 10**9 + nsec
|
||||
|
||||
|
||||
class Header(dict):
|
||||
class Header(Dict[str, Any]):
|
||||
"""Record header."""
|
||||
|
||||
def get_uint8(self, name: str) -> int:
|
||||
@@ -214,7 +216,9 @@ class Header(dict):
|
||||
|
||||
"""
|
||||
try:
|
||||
return self[name].decode()
|
||||
value = self[name]
|
||||
assert isinstance(value, bytes)
|
||||
return value.decode()
|
||||
except (KeyError, ValueError) as err:
|
||||
raise ReaderError(f'Could not read string field {name!r}.') from err
|
||||
|
||||
@@ -237,7 +241,7 @@ class Header(dict):
|
||||
raise ReaderError(f'Could not read time field {name!r}.') from err
|
||||
|
||||
@classmethod
|
||||
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
||||
def read(cls: Type[Header], src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
||||
"""Read header from file handle.
|
||||
|
||||
Args:
|
||||
@@ -362,10 +366,10 @@ class Reader:
|
||||
self.connections: dict[int, Connection] = {}
|
||||
self.chunk_infos: list[ChunkInfo] = []
|
||||
self.chunks: dict[int, Chunk] = {}
|
||||
self.current_chunk = (-1, BytesIO())
|
||||
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
|
||||
self.topics: dict[str, TopicInfo] = {}
|
||||
|
||||
def open(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||
"""Open rosbag and read metadata."""
|
||||
try:
|
||||
self.bio = self.path.open('rb')
|
||||
@@ -409,24 +413,25 @@ class Reader:
|
||||
raise ReaderError(f'Bag index looks damaged: {err.args}') from None
|
||||
|
||||
self.chunks = {}
|
||||
indexes: dict[int, list[list[IndexData]]] = defaultdict(list)
|
||||
for chunk_info in self.chunk_infos:
|
||||
self.bio.seek(chunk_info.pos)
|
||||
self.chunks[chunk_info.pos] = self.read_chunk()
|
||||
|
||||
for _ in range(len(chunk_info.connection_counts)):
|
||||
cid, index = self.read_index_data(chunk_info.pos)
|
||||
self.connections[cid].indexes.append(index)
|
||||
indexes[cid].append(index)
|
||||
|
||||
for connection in self.connections.values():
|
||||
connection.indexes[:] = list(heapq.merge(*connection.indexes, key=lambda x: x.time))
|
||||
for cid, connection in self.connections.items():
|
||||
connection.indexes.extend(heapq.merge(*indexes[cid], key=lambda x: x.time))
|
||||
assert connection.indexes
|
||||
|
||||
self.topics = {}
|
||||
for topic, connections in groupby(
|
||||
for topic, group in groupby(
|
||||
sorted(self.connections.values(), key=lambda x: x.topic),
|
||||
key=lambda x: x.topic,
|
||||
):
|
||||
connections = list(connections)
|
||||
connections = list(group)
|
||||
count = reduce(
|
||||
lambda x, y: x + y,
|
||||
(
|
||||
@@ -446,7 +451,7 @@ class Reader:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag."""
|
||||
assert self.bio
|
||||
self.bio.close()
|
||||
@@ -614,8 +619,8 @@ class Reader:
|
||||
|
||||
chunk_header = self.chunks[entry.chunk_pos]
|
||||
self.bio.seek(chunk_header.datapos)
|
||||
chunk = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
||||
self.current_chunk = (entry.chunk_pos, BytesIO(chunk))
|
||||
rawbytes = chunk_header.decompressor(read_bytes(self.bio, chunk_header.datasize))
|
||||
self.current_chunk = (entry.chunk_pos, BytesIO(rawbytes))
|
||||
|
||||
chunk = self.current_chunk[1]
|
||||
chunk.seek(entry.offset)
|
||||
|
||||
@@ -11,9 +11,9 @@ from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from lz4.frame import compress as lz4_compress # type: ignore
|
||||
from lz4.frame import compress as lz4_compress
|
||||
|
||||
from rosbags.typesys.msg import denormalize_msgtype, generate_msgdef
|
||||
|
||||
@@ -21,7 +21,7 @@ from .reader import Connection, RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, BinaryIO, Callable, Literal, Optional, Type, Union
|
||||
from typing import BinaryIO, Callable, Literal, Optional, Type, Union
|
||||
|
||||
|
||||
class WriterError(Exception):
|
||||
@@ -57,10 +57,10 @@ def serialize_time(val: int) -> bytes:
|
||||
return struct.pack('<LL', sec, nsec)
|
||||
|
||||
|
||||
class Header(dict):
|
||||
class Header(Dict[str, Any]):
|
||||
"""Record header."""
|
||||
|
||||
def set_uint32(self, name: str, value: int):
|
||||
def set_uint32(self, name: str, value: int) -> None:
|
||||
"""Set field to uint32 value.
|
||||
|
||||
Args:
|
||||
@@ -70,7 +70,7 @@ class Header(dict):
|
||||
"""
|
||||
self[name] = serialize_uint32(value)
|
||||
|
||||
def set_uint64(self, name: str, value: int):
|
||||
def set_uint64(self, name: str, value: int) -> None:
|
||||
"""Set field to uint64 value.
|
||||
|
||||
Args:
|
||||
@@ -80,7 +80,7 @@ class Header(dict):
|
||||
"""
|
||||
self[name] = serialize_uint64(value)
|
||||
|
||||
def set_string(self, name: str, value: str):
|
||||
def set_string(self, name: str, value: str) -> None:
|
||||
"""Set field to string value.
|
||||
|
||||
Args:
|
||||
@@ -90,7 +90,7 @@ class Header(dict):
|
||||
"""
|
||||
self[name] = value.encode()
|
||||
|
||||
def set_time(self, name: str, value: int):
|
||||
def set_time(self, name: str, value: int) -> None:
|
||||
"""Set field to time value.
|
||||
|
||||
Args:
|
||||
@@ -163,7 +163,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
]
|
||||
self.chunk_threshold = 1 * (1 << 20)
|
||||
|
||||
def set_compression(self, fmt: CompressionFormat):
|
||||
def set_compression(self, fmt: CompressionFormat) -> None:
|
||||
"""Enable compression on rosbag1.
|
||||
|
||||
This function has to be called before opening.
|
||||
@@ -180,20 +180,21 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
self.compression_format = fmt.name.lower()
|
||||
|
||||
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, compresslevel=9)
|
||||
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, compression_level=16)
|
||||
bz2: Callable[[bytes], bytes] = lambda x: bz2_compress(x, 9)
|
||||
lz4: Callable[[bytes], bytes] = lambda x: lz4_compress(x, 16) # type: ignore
|
||||
self.compressor = {
|
||||
'bz2': bz2,
|
||||
'lz4': lz4,
|
||||
}[self.compression_format]
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open rosbag1 for writing."""
|
||||
try:
|
||||
self.bio = self.path.open('xb')
|
||||
except FileExistsError:
|
||||
raise WriterError(f'{self.path} exists already, not overwriting.') from None
|
||||
|
||||
assert self.bio
|
||||
self.bio.write(b'#ROSBAG V2.0\n')
|
||||
header = Header()
|
||||
header.set_uint64('index_pos', 0)
|
||||
@@ -263,7 +264,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.connections[connection.cid] = connection
|
||||
return connection
|
||||
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
|
||||
"""Write message to rosbag1.
|
||||
|
||||
Args:
|
||||
@@ -301,7 +302,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.write_chunk(chunk)
|
||||
|
||||
@staticmethod
|
||||
def write_connection(connection: Connection, bio: BytesIO):
|
||||
def write_connection(connection: Connection, bio: BinaryIO) -> None:
|
||||
"""Write connection record."""
|
||||
header = Header()
|
||||
header.set_uint32('conn', connection.cid)
|
||||
@@ -319,7 +320,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
header.set_string('latching', str(connection.latching))
|
||||
header.write(bio)
|
||||
|
||||
def write_chunk(self, chunk: WriteChunk):
|
||||
def write_chunk(self, chunk: WriteChunk) -> None:
|
||||
"""Write open chunk to file."""
|
||||
assert self.bio
|
||||
|
||||
@@ -347,12 +348,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
chunk.data.close()
|
||||
self.chunks.append(WriteChunk(BytesIO(), -1, 2**64, 0, defaultdict(list)))
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag1 after writing.
|
||||
|
||||
Closes open chunks and writes index.
|
||||
|
||||
"""
|
||||
assert self.bio
|
||||
for chunk in self.chunks:
|
||||
if chunk.pos == -1:
|
||||
self.write_chunk(chunk)
|
||||
|
||||
@@ -11,13 +11,14 @@ from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import zstandard
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.error import YAMLError
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Generator, Iterable, Literal, Optional, Type, Union
|
||||
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
@@ -25,7 +26,7 @@ class ReaderError(Exception):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def decompress(path: Path, do_decompress: bool):
|
||||
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
|
||||
"""Transparent rosbag2 database decompression context.
|
||||
|
||||
This context manager will yield a path to the decompressed file contents.
|
||||
@@ -119,12 +120,12 @@ class Reader:
|
||||
except KeyError as exc:
|
||||
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2."""
|
||||
# Future storage formats will require file handles.
|
||||
self.bio = True
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2."""
|
||||
# Future storage formats will require file handles.
|
||||
assert self.bio
|
||||
@@ -133,12 +134,14 @@ class Reader:
|
||||
@property
|
||||
def duration(self) -> int:
|
||||
"""Duration in nanoseconds between earliest and latest messages."""
|
||||
return self.metadata['duration']['nanoseconds'] + 1
|
||||
nsecs: int = self.metadata['duration']['nanoseconds']
|
||||
return nsecs + 1
|
||||
|
||||
@property
|
||||
def start_time(self) -> int:
|
||||
"""Timestamp in nanoseconds of the earliest message."""
|
||||
return self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||
return nsecs
|
||||
|
||||
@property
|
||||
def end_time(self) -> int:
|
||||
@@ -148,7 +151,8 @@ class Reader:
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""Total message count."""
|
||||
return self.metadata['message_count']
|
||||
count: int = self.metadata['message_count']
|
||||
return count
|
||||
|
||||
@property
|
||||
def compression_format(self) -> Optional[str]:
|
||||
@@ -233,7 +237,7 @@ class Reader:
|
||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||
|
||||
cur.execute('SELECT name,id FROM topics')
|
||||
connmap: Dict[int, Connection] = {
|
||||
connmap: dict[int, Connection] = {
|
||||
row[1]: next((x for x in self.connections.values() if x.topic == row[0]),
|
||||
None) # type: ignore
|
||||
for row in cur
|
||||
|
||||
@@ -80,10 +80,10 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compression_format = ''
|
||||
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
||||
self.connections: dict[int, Connection] = {}
|
||||
self.conn = None
|
||||
self.conn: Optional[sqlite3.Connection] = None
|
||||
self.cursor: Optional[sqlite3.Cursor] = None
|
||||
|
||||
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
|
||||
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat) -> None:
|
||||
"""Enable compression on bag.
|
||||
|
||||
This function has to be called before opening.
|
||||
@@ -104,7 +104,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compression_format = fmt.name.lower()
|
||||
self.compressor = zstandard.ZstdCompressor()
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2 for writing.
|
||||
|
||||
Create base directory and open database connection.
|
||||
@@ -164,7 +164,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||
return connection
|
||||
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes) -> None:
|
||||
"""Write message to rosbag2.
|
||||
|
||||
Args:
|
||||
@@ -191,12 +191,14 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
)
|
||||
connection.count += 1
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2 after writing.
|
||||
|
||||
Closes open database transactions and writes metadata.yaml.
|
||||
|
||||
"""
|
||||
assert self.cursor
|
||||
assert self.conn
|
||||
self.cursor.close()
|
||||
self.cursor = None
|
||||
|
||||
@@ -209,6 +211,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.conn.close()
|
||||
|
||||
if self.compression_mode == 'file':
|
||||
assert self.compressor
|
||||
src = self.dbpath
|
||||
self.dbpath = src.with_suffix(f'.db3.{self.compression_format}')
|
||||
with src.open('rb') as infile, self.dbpath.open('wb') as outfile:
|
||||
|
||||
@@ -19,10 +19,10 @@ from .typing import Field
|
||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable
|
||||
from .typing import CDRDeser, CDRSer, CDRSerSize
|
||||
|
||||
|
||||
def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
|
||||
def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
|
||||
"""Generate cdr size calculation function.
|
||||
|
||||
Args:
|
||||
@@ -157,7 +157,7 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
|
||||
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
|
||||
|
||||
|
||||
def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
||||
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
|
||||
"""Generate cdr serialization function.
|
||||
|
||||
Args:
|
||||
@@ -296,7 +296,7 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
||||
return compile_lines(lines).serialize_cdr # type: ignore
|
||||
|
||||
|
||||
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
|
||||
def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||
"""Generate cdr deserialization function.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -65,9 +65,9 @@ def get_msgdef(typename: str) -> Msgdef:
|
||||
generate_serialize_cdr(fields, 'be'),
|
||||
generate_deserialize_cdr(fields, 'le'),
|
||||
generate_deserialize_cdr(fields, 'be'),
|
||||
generate_ros1_to_cdr(fields, typename, False),
|
||||
generate_ros1_to_cdr(fields, typename, True),
|
||||
generate_cdr_to_ros1(fields, typename, False),
|
||||
generate_cdr_to_ros1(fields, typename, True),
|
||||
generate_ros1_to_cdr(fields, typename, False), # type: ignore
|
||||
generate_ros1_to_cdr(fields, typename, True), # type: ignore
|
||||
generate_cdr_to_ros1(fields, typename, False), # type: ignore
|
||||
generate_cdr_to_ros1(fields, typename, True), # type: ignore
|
||||
)
|
||||
return MSGDEFCACHE[typename]
|
||||
|
||||
@@ -18,10 +18,16 @@ from .typing import Field
|
||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable # pylint: disable=ungrouped-imports
|
||||
from typing import Union # pylint: disable=ungrouped-imports
|
||||
|
||||
from .typing import Bitcvt, BitcvtSize
|
||||
|
||||
|
||||
def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
|
||||
def generate_ros1_to_cdr(
|
||||
fields: list[Field],
|
||||
typename: str,
|
||||
copy: bool,
|
||||
) -> Union[Bitcvt, BitcvtSize]:
|
||||
"""Generate ROS1 to CDR conversion function.
|
||||
|
||||
Args:
|
||||
@@ -169,10 +175,14 @@ def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Call
|
||||
aligned = anext
|
||||
|
||||
lines.append(' return ipos, opos')
|
||||
return getattr(compile_lines(lines), funcname)
|
||||
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||
|
||||
|
||||
def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
|
||||
def generate_cdr_to_ros1(
|
||||
fields: list[Field],
|
||||
typename: str,
|
||||
copy: bool,
|
||||
) -> Union[Bitcvt, BitcvtSize]:
|
||||
"""Generate CDR to ROS1 conversion function.
|
||||
|
||||
Args:
|
||||
@@ -318,4 +328,4 @@ def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Call
|
||||
aligned = anext
|
||||
|
||||
lines.append(' return ipos, opos')
|
||||
return getattr(compile_lines(lines), funcname)
|
||||
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||
|
||||
+18
-11
@@ -7,7 +7,14 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
Bitcvt = Callable[[bytes, int, bytes, int], Tuple[int, int]]
|
||||
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
|
||||
|
||||
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
|
||||
CDRSer = Callable[[bytes, int, type], int]
|
||||
CDRSerSize = Callable[[int, type], int]
|
||||
|
||||
|
||||
class Descriptor(NamedTuple):
|
||||
@@ -28,15 +35,15 @@ class Msgdef(NamedTuple):
|
||||
"""Metadata of a message."""
|
||||
|
||||
name: str
|
||||
fields: List[Field]
|
||||
fields: list[Field]
|
||||
cls: Any
|
||||
size_cdr: int
|
||||
getsize_cdr: Callable
|
||||
serialize_cdr_le: Callable
|
||||
serialize_cdr_be: Callable
|
||||
deserialize_cdr_le: Callable
|
||||
deserialize_cdr_be: Callable
|
||||
getsize_ros1_to_cdr: Callable
|
||||
ros1_to_cdr: Callable
|
||||
getsize_cdr_to_ros1: Callable
|
||||
cdr_to_ros1: Callable
|
||||
getsize_cdr: CDRSerSize
|
||||
serialize_cdr_le: CDRSer
|
||||
serialize_cdr_be: CDRSer
|
||||
deserialize_cdr_le: CDRDeser
|
||||
deserialize_cdr_be: CDRDeser
|
||||
getsize_ros1_to_cdr: BitcvtSize
|
||||
ros1_to_cdr: Bitcvt
|
||||
getsize_cdr_to_ros1: BitcvtSize
|
||||
cdr_to_ros1: Bitcvt
|
||||
|
||||
@@ -67,6 +67,6 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
|
||||
pos = rule.skip_ws(text, 0)
|
||||
npos, trees = rule.parse(text, pos)
|
||||
assert npos == len(text), f'Could not parse: {text!r}'
|
||||
return visitor.visit(trees)
|
||||
return visitor.visit(trees) # type: ignore
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
raise TypesysError(f'Could not parse: {text!r}') from err
|
||||
|
||||
@@ -253,10 +253,10 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
|
||||
RULES = parse_grammar(GRAMMAR_IDL)
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__()
|
||||
self.typedefs = {}
|
||||
self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {}
|
||||
|
||||
def visit_specification(self, children: Any) -> Typesdict:
|
||||
"""Process start symbol, return only children of modules."""
|
||||
|
||||
@@ -50,7 +50,7 @@ class Rule:
|
||||
}
|
||||
return data
|
||||
|
||||
def parse(self, text: str, pos: int):
|
||||
def parse(self, text: str, pos: int) -> tuple[int, Any]:
|
||||
"""Apply rule at position."""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@@ -192,7 +192,7 @@ class Visitor: # pylint: disable=too-few-public-methods
|
||||
|
||||
RULES: dict[str, Rule] = {}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
|
||||
def visit(self, tree: Any) -> Any:
|
||||
|
||||
@@ -13,12 +13,14 @@ from . import types
|
||||
from .base import Nodetype, TypesysError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from .base import Typesdict
|
||||
|
||||
INTLIKE = re.compile('^u?(bool|int|float)')
|
||||
|
||||
|
||||
def get_typehint(desc: tuple) -> str:
|
||||
def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int]]]]) -> str:
|
||||
"""Get python type hint for field.
|
||||
|
||||
Args:
|
||||
@@ -29,18 +31,19 @@ def get_typehint(desc: tuple) -> str:
|
||||
|
||||
"""
|
||||
if desc[0] == Nodetype.BASE:
|
||||
if match := INTLIKE.match(desc[1]):
|
||||
if match := INTLIKE.match(desc[1]): # type: ignore
|
||||
return match.group(1)
|
||||
return 'str'
|
||||
|
||||
if desc[0] == Nodetype.NAME:
|
||||
assert isinstance(desc[1], str)
|
||||
return desc[1].replace('/', '__')
|
||||
|
||||
sub = desc[1][0]
|
||||
if INTLIKE.match(sub[1]):
|
||||
typ = 'bool8' if sub[1] == 'bool' else sub[1]
|
||||
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
|
||||
return f'list[{get_typehint(sub)}]'
|
||||
return f'list[{get_typehint(sub)}]' # type: ignore
|
||||
|
||||
|
||||
def generate_python_code(typs: Typesdict) -> str:
|
||||
@@ -99,7 +102,7 @@ def generate_python_code(typs: Typesdict) -> str:
|
||||
'',
|
||||
]
|
||||
|
||||
def get_ftype(ftype: tuple) -> tuple:
|
||||
def get_ftype(ftype: tuple[int, Any]) -> tuple[int, Any]:
|
||||
if ftype[0] <= 2:
|
||||
return int(ftype[0]), ftype[1]
|
||||
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
|
||||
|
||||
Reference in New Issue
Block a user