diff --git a/docs/examples/save_images_rosbag1.py b/docs/examples/save_images_rosbag1.py index 60acc8e2..3390407c 100644 --- a/docs/examples/save_images_rosbag1.py +++ b/docs/examples/save_images_rosbag1.py @@ -33,7 +33,7 @@ def save_images() -> None: frame_id=FRAMEID, ), format='jpeg', # could also be 'png' - data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore + data=numpy.fromfile(path, dtype=numpy.uint8), ) writer.write( diff --git a/docs/examples/save_images_rosbag2.py b/docs/examples/save_images_rosbag2.py index 673d4da5..f8a97959 100644 --- a/docs/examples/save_images_rosbag2.py +++ b/docs/examples/save_images_rosbag2.py @@ -33,7 +33,7 @@ def save_images() -> None: frame_id=FRAMEID, ), format='jpeg', # could also be 'png' - data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore + data=numpy.fromfile(path, dtype=numpy.uint8), ) writer.write( diff --git a/docs/examples/use_with_native.py b/docs/examples/use_with_native.py index f484bfe3..93715b80 100644 --- a/docs/examples/use_with_native.py +++ b/docs/examples/use_with_native.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: NATIVE_CLASSES: dict[str, Any] = {} -def to_native(msg: Any) -> Any: +def to_native(msg: Any) -> Any: # noqa: ANN401 """Convert rosbags message to native message. Args: diff --git a/src/rosbags/convert/converter.py b/src/rosbags/convert/converter.py index c20f4735..5b6614c5 100644 --- a/src/rosbags/convert/converter.py +++ b/src/rosbags/convert/converter.py @@ -171,5 +171,5 @@ def convert(src: Path, dst: Optional[Path]) -> None: raise ConverterError(f'Reading source bag: {err}') from err except (WriterError1, WriterError2) as err: raise ConverterError(f'Writing destination bag: {err}') from err - except Exception as err: # pylint: disable=broad-except + except Exception as err: raise ConverterError(f'Converting rosbag: {err!r}') from err diff --git a/src/rosbags/rosbag1/reader.py b/src/rosbags/rosbag1/reader.py index 83ad78a8..14742118 100644 --- a/src/rosbags/rosbag1/reader.py +++ b/src/rosbags/rosbag1/reader.py @@ -106,9 +106,9 @@ class IndexData(NamedTuple): def __eq__(self, other: object) -> bool: """Compare by time only.""" - if not isinstance(other, IndexData): # pragma: no cover - return NotImplemented - return self.time == other[0] + if isinstance(other, IndexData): + return self.time == other[0] + return NotImplemented # pragma: no cover def __ge__(self, other: tuple[int, ...]) -> bool: """Compare by time only.""" @@ -120,9 +120,9 @@ class IndexData(NamedTuple): def __ne__(self, other: object) -> bool: """Compare by time only.""" - if not isinstance(other, IndexData): # pragma: no cover - return NotImplemented - return self.time != other[0] + if isinstance(other, IndexData): + return self.time != other[0] + return NotImplemented # pragma: no cover deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct(' None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements + def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals """Open rosbag and read metadata.""" try: self.bio = self.path.open('rb') @@ -394,13 +394,11 @@ class Reader: conn_count = header.get_uint32('conn_count') chunk_count = header.get_uint32('chunk_count') try: - encryptor = header.get_string('encryptor') - if encryptor: - raise ValueError - except ValueError: - raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None + encryptor: Optional[str] = header.get_string('encryptor') except ReaderError: - pass + encryptor = None + if encryptor: + raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None if index_pos == 0: raise ReaderError('Bag is not indexed, reindex before reading.') diff --git a/src/rosbags/rosbag1/writer.py b/src/rosbags/rosbag1/writer.py index 9126a378..cbe4bc35 100644 --- a/src/rosbags/rosbag1/writer.py +++ b/src/rosbags/rosbag1/writer.py @@ -31,6 +31,7 @@ class WriterError(Exception): @dataclass class WriteChunk: """In progress chunk.""" + data: BytesIO pos: int start: int @@ -126,7 +127,7 @@ class Header(Dict[str, Any]): return size + 4 -class Writer: # pylint: disable=too-many-instance-attributes +class Writer: """Rosbag1 writer. This class implements writing of rosbag1 files in version 2.0. It should be @@ -212,7 +213,7 @@ class Writer: # pylint: disable=too-many-instance-attributes md5sum: Optional[str] = None, callerid: Optional[str] = None, latching: Optional[int] = None, - **_kw: Any, + **_kw: Any, # noqa: ANN401 ) -> Connection: """Add a connection. diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index ac719ec5..b3145aa3 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -18,7 +18,44 @@ from .connection import Connection if TYPE_CHECKING: from types import TracebackType - from typing import Any, Generator, Iterable, Literal, Optional, Type, Union + from typing import Any, Generator, Iterable, Literal, Optional, Type, TypedDict, Union + + class StartingTime(TypedDict): + """Bag starting time.""" + + nanoseconds_since_epoch: int + + class Duration(TypedDict): + """Bag starting time.""" + + nanoseconds: int + + class TopicMetadata(TypedDict): + """Topic metadata.""" + + name: str + type: str + serialization_format: str + offered_qos_profiles: str + + class TopicWithMessageCount(TypedDict): + """Topic with message count.""" + + message_count: int + topic_metadata: TopicMetadata + + class Metadata(TypedDict): + """Rosbag2 metadata file.""" + + version: int + storage_identifier: str + relative_file_paths: list[str] + starting_time: StartingTime + duration: Duration + message_count: int + compression_format: str + compression_mode: str + topics_with_message_count: list[TopicWithMessageCount] class ReaderError(Exception): @@ -72,13 +109,14 @@ class Reader: Raises: ReaderError: Bag not readable or bag metadata. + """ path = Path(path) - self.path = Path + yamlpath = path / 'metadata.yaml' + self.path = path self.bio = False try: yaml = YAML(typ='safe') - yamlpath = path / 'metadata.yaml' dct = yaml.load(yamlpath.read_text()) except OSError as err: raise ReaderError(f'Could not read metadata at {yamlpath}: {err}.') from None @@ -86,7 +124,7 @@ class Reader: raise ReaderError(f'Could not load YAML from {yamlpath}: {exc}') from None try: - self.metadata = dct['rosbag2_bagfile_information'] + self.metadata: Metadata = dct['rosbag2_bagfile_information'] if (ver := self.metadata['version']) > 4: raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.') if storageid := self.metadata['storage_identifier'] != 'sqlite3': @@ -95,8 +133,7 @@ class Reader: ) self.paths = [path / Path(x).name for x in self.metadata['relative_file_paths']] - missing = [x for x in self.paths if not x.exists()] - if missing: + if missing := [x for x in self.paths if not x.exists()]: raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}') self.connections = { @@ -110,7 +147,7 @@ class Reader: ) for idx, x in enumerate(self.metadata['topics_with_message_count']) } noncdr = { - y for x in self.connections.values() if (y := x.serialization_format) != 'cdr' + fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr' } if noncdr: raise ReaderError(f'Serialization format {noncdr!r} is not supported.') @@ -140,8 +177,7 @@ class Reader: @property def start_time(self) -> int: """Timestamp in nanoseconds of the earliest message.""" - nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch'] - return nsecs + return self.metadata['starting_time']['nanoseconds_since_epoch'] @property def end_time(self) -> int: @@ -151,8 +187,7 @@ class Reader: @property def message_count(self) -> int: """Total message count.""" - count: int = self.metadata['message_count'] - return count + return self.metadata['message_count'] @property def compression_format(self) -> Optional[str]: diff --git a/src/rosbags/rosbag2/writer.py b/src/rosbags/rosbag2/writer.py index 915b6360..027647ac 100644 --- a/src/rosbags/rosbag2/writer.py +++ b/src/rosbags/rosbag2/writer.py @@ -18,6 +18,8 @@ if TYPE_CHECKING: from types import TracebackType from typing import Any, Literal, Optional, Type, Union + from .reader import Metadata + class WriterError(Exception): """Writer Error.""" @@ -125,7 +127,7 @@ class Writer: # pylint: disable=too-many-instance-attributes msgtype: str, serialization_format: str = 'cdr', offered_qos_profiles: str = '', - **_kw: Any, + **_kw: Any, # noqa: ANN401 ) -> Connection: """Add a connection. @@ -218,7 +220,7 @@ class Writer: # pylint: disable=too-many-instance-attributes self.compressor.copy_stream(infile, outfile) src.unlink() - metadata = { + metadata: dict[str, Metadata] = { 'rosbag2_bagfile_information': { 'version': 4, 'storage_identifier': 'sqlite3', diff --git a/src/rosbags/serde/cdr.py b/src/rosbags/serde/cdr.py index 7d7587d3..8dc87c96 100644 --- a/src/rosbags/serde/cdr.py +++ b/src/rosbags/serde/cdr.py @@ -86,22 +86,22 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]: else: assert subdesc.valtype == Valtype.MESSAGE - anext = align(subdesc) + anext_before = align(subdesc) anext_after = align_after(subdesc) if subdesc.args.size_cdr: for _ in range(length): - if anext > anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') - size = (size + anext - 1) & -anext + if anext_before > anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + size = (size + anext_before - 1) & -anext_before lines.append(f' pos += {subdesc.args.size_cdr}') size += subdesc.args.size_cdr else: lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') lines.append(f' val = message.{fieldname}') for idx in range(length): - if anext > anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(f' pos = func(pos, val[{idx}])') is_stat = False aligned = align_after(subdesc) @@ -117,45 +117,45 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]: lines.append(' pos += 4 + len(val.encode()) + 1') aligned = 1 else: - anext = align(subdesc) - if aligned < anext: + anext_before = align(subdesc) + if aligned < anext_before: lines.append(f' if len(message.{fieldname}):') - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') - aligned = anext + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + aligned = anext_before lines.append(f' pos += len(message.{fieldname}) * {SIZEMAP[subdesc.args]}') else: assert subdesc.valtype == Valtype.MESSAGE - anext = align(subdesc) + anext_before = align(subdesc) anext_after = align_after(subdesc) lines.append(f' val = message.{fieldname}') if subdesc.args.size_cdr: - if aligned < anext <= anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if aligned < anext_before <= anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' for _ in val:') - if anext > anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(f' pos += {subdesc.args.size_cdr}') else: lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') - if aligned < anext <= anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if aligned < anext_before <= anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' for item in val:') - if anext > anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' pos = func(pos, item)') aligned = align_after(subdesc) aligned = min([aligned, 4]) is_stat = False - if fnext and aligned < (anext := align(fnext.descriptor)): - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') - aligned = anext + if fnext and aligned < (anext_before := align(fnext.descriptor)): + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + aligned = anext_before is_stat = False lines.append(' return pos') - return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore + return compile_lines(lines).getsize_cdr, is_stat * size def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer: @@ -240,14 +240,14 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer: else: assert subdesc.valtype == Valtype.MESSAGE - anext = align(subdesc) + anext_before = align(subdesc) anext_after = align_after(subdesc) lines.append( f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', ) for idx in range(length): - if anext > anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(f' pos = func(rawdata, pos, val[{idx}])') aligned = align_after(subdesc) else: @@ -272,28 +272,28 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer: lines.append(f' size = len(val) * {SIZEMAP[subdesc.args]}') if (endianess == 'le') != (sys.byteorder == 'little'): lines.append(' val = val.byteswap()') - if aligned < (anext := align(subdesc)): + if aligned < (anext_before := align(subdesc)): lines.append(' if size:') - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' rawdata[pos:pos + size] = val.view(numpy.uint8)') lines.append(' pos += size') - aligned = anext + aligned = anext_before if subdesc.valtype == Valtype.MESSAGE: - anext = align(subdesc) + anext_before = align(subdesc) lines.append( f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', ) lines.append(' for item in val:') - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(' pos = func(rawdata, pos, item)') aligned = align_after(subdesc) aligned = min([4, aligned]) - if fnext and aligned < (anext := align(fnext.descriptor)): - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') - aligned = anext + if fnext and aligned < (anext_before := align(fnext.descriptor)): + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + aligned = anext_before lines.append(' return pos') return compile_lines(lines).serialize_cdr # type: ignore @@ -384,13 +384,13 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser: lines.append(f' pos += {size}') else: assert subdesc.valtype == Valtype.MESSAGE - anext = align(subdesc) + anext_before = align(subdesc) anext_after = align_after(subdesc) lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') lines.append(' value = []') for _ in range(length): - if anext > anext_after: - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') lines.append(' value.append(obj)') lines.append(' values.append(value)') @@ -418,9 +418,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser: aligned = 1 else: lines.append(f' length = size * {SIZEMAP[subdesc.args]}') - if aligned < (anext := align(subdesc)): + if aligned < (anext_before := align(subdesc)): lines.append(' if size:') - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append( f' val = numpy.frombuffer(rawdata, ' f'dtype=numpy.{subdesc.args}, count=size, offset=pos)', @@ -429,14 +429,14 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser: lines.append(' val = val.byteswap()') lines.append(' values.append(val)') lines.append(' pos += length') - aligned = anext + aligned = anext_before if subdesc.valtype == Valtype.MESSAGE: - anext = align(subdesc) + anext_before = align(subdesc) lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') lines.append(' value = []') lines.append(' for _ in range(size):') - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') lines.append(' value.append(obj)') lines.append(' values.append(value)') @@ -444,9 +444,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser: aligned = min([4, aligned]) - if fnext and aligned < (anext := align(fnext.descriptor)): - lines.append(f' pos = (pos + {anext} - 1) & -{anext}') - aligned = anext + if fnext and aligned < (anext_before := align(fnext.descriptor)): + lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}') + aligned = anext_before lines.append(' return cls(*values), pos') return compile_lines(lines).deserialize_cdr # type: ignore diff --git a/src/rosbags/serde/messages.py b/src/rosbags/serde/messages.py index df9c8e0c..bad20c4e 100644 --- a/src/rosbags/serde/messages.py +++ b/src/rosbags/serde/messages.py @@ -14,7 +14,7 @@ from .typing import Descriptor, Field, Msgdef from .utils import Valtype if TYPE_CHECKING: - from typing import Any + from rosbags.typesys.base import Fielddesc MSGDEFCACHE: dict[str, Msgdef] = {} @@ -38,14 +38,18 @@ def get_msgdef(typename: str) -> Msgdef: if typename not in MSGDEFCACHE: entries = types.FIELDDEFS[typename][1] - def fixup(entry: Any) -> Descriptor: + def fixup(entry: Fielddesc) -> Descriptor: if entry[0] == Valtype.BASE: + assert isinstance(entry[1], str) return Descriptor(Valtype.BASE, entry[1]) if entry[0] == Valtype.MESSAGE: + assert isinstance(entry[1], str) return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1])) if entry[0] == Valtype.ARRAY: + assert not isinstance(entry[1][0], str) return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1])) if entry[0] == Valtype.SEQUENCE: + assert not isinstance(entry[1][0], str) return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1])) raise SerdeError( # pragma: no cover f'Unknown field type {entry[0]!r} encountered.', diff --git a/src/rosbags/serde/ros1.py b/src/rosbags/serde/ros1.py index d51a83e0..402afb8c 100644 --- a/src/rosbags/serde/ros1.py +++ b/src/rosbags/serde/ros1.py @@ -18,7 +18,7 @@ from .typing import Field from .utils import SIZEMAP, Valtype, align, align_after, compile_lines if TYPE_CHECKING: - from typing import Union # pylint: disable=ungrouped-imports + from typing import Union from .typing import Bitcvt, BitcvtSize @@ -114,13 +114,13 @@ def generate_ros1_to_cdr( aligned = SIZEMAP[subdesc.args] if subdesc.valtype == Valtype.MESSAGE: - anext = align(subdesc) + anext_before = align(subdesc) anext_after = align_after(subdesc) lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') for _ in range(length): - if anext > anext_after: - lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}') lines.append(' ipos, opos = func(input, ipos, output, opos)') aligned = anext_after else: @@ -150,30 +150,30 @@ def generate_ros1_to_cdr( lines.append(' opos += length') aligned = 1 else: - if aligned < (anext := align(subdesc)): + if aligned < (anext_before := align(subdesc)): lines.append(' if size:') - lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}') lines.append(f' length = size * {SIZEMAP[subdesc.args]}') if copy: lines.append(' output[opos:opos + length] = input[ipos:ipos + length]') lines.append(' ipos += length') lines.append(' opos += length') - aligned = anext + aligned = anext_before else: assert subdesc.valtype == Valtype.MESSAGE - anext = align(subdesc) + anext_before = align(subdesc) lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(' for _ in range(size):') - lines.append(f' opos = (opos + {anext} - 1) & -{anext}') + lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}') lines.append(' ipos, opos = func(input, ipos, output, opos)') aligned = align_after(subdesc) aligned = min([aligned, 4]) - if fnext and aligned < (anext := align(fnext.descriptor)): - lines.append(f' opos = (opos + {anext} - 1) & -{anext}') - aligned = anext + if fnext and aligned < (anext_before := align(fnext.descriptor)): + lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}') + aligned = anext_before lines.append(' return ipos, opos') return getattr(compile_lines(lines), funcname) # type: ignore @@ -270,13 +270,13 @@ def generate_cdr_to_ros1( aligned = SIZEMAP[subdesc.args] if subdesc.valtype == Valtype.MESSAGE: - anext = align(subdesc) + anext_before = align(subdesc) anext_after = align_after(subdesc) lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') for _ in range(length): - if anext > anext_after: - lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') + if anext_before > anext_after: + lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}') lines.append(' ipos, opos = func(input, ipos, output, opos)') aligned = anext_after else: @@ -304,30 +304,30 @@ def generate_cdr_to_ros1( lines.append(' opos += length') aligned = 1 else: - if aligned < (anext := align(subdesc)): + if aligned < (anext_before := align(subdesc)): lines.append(' if size:') - lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') + lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}') lines.append(f' length = size * {SIZEMAP[subdesc.args]}') if copy: lines.append(' output[opos:opos + length] = input[ipos:ipos + length]') lines.append(' ipos += length') lines.append(' opos += length') - aligned = anext + aligned = anext_before else: assert subdesc.valtype == Valtype.MESSAGE - anext = align(subdesc) + anext_before = align(subdesc) lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(' for _ in range(size):') - lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') + lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}') lines.append(' ipos, opos = func(input, ipos, output, opos)') aligned = align_after(subdesc) aligned = min([aligned, 4]) - if fnext and aligned < (anext := align(fnext.descriptor)): - lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') - aligned = anext + if fnext and aligned < (anext_before := align(fnext.descriptor)): + lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}') + aligned = anext_before lines.append(' return ipos, opos') return getattr(compile_lines(lines), funcname) # type: ignore diff --git a/src/rosbags/serde/serdes.py b/src/rosbags/serde/serdes.py index 57e92ceb..96d4aabe 100644 --- a/src/rosbags/serde/serdes.py +++ b/src/rosbags/serde/serdes.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from typing import Any -def deserialize_cdr(rawdata: bytes, typename: str) -> Any: +def deserialize_cdr(rawdata: bytes, typename: str) -> Any: # noqa: ANN401 """Deserialize raw data into a message object. Args: @@ -35,7 +35,7 @@ def deserialize_cdr(rawdata: bytes, typename: str) -> Any: def serialize_cdr( - message: Any, + message: object, typename: str, little_endian: bool = sys.byteorder == 'little', ) -> memoryview: diff --git a/src/rosbags/serde/typing.py b/src/rosbags/serde/typing.py index 413059a9..de4e27b5 100644 --- a/src/rosbags/serde/typing.py +++ b/src/rosbags/serde/typing.py @@ -13,8 +13,8 @@ if TYPE_CHECKING: BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]] CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]] - CDRSer = Callable[[bytes, int, type], int] - CDRSerSize = Callable[[int, type], int] + CDRSer = Callable[[bytes, int, object], int] + CDRSerSize = Callable[[int, object], int] class Descriptor(NamedTuple): diff --git a/src/rosbags/typesys/base.py b/src/rosbags/typesys/base.py index 33a56e1a..79814ab0 100644 --- a/src/rosbags/typesys/base.py +++ b/src/rosbags/typesys/base.py @@ -68,5 +68,5 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict: npos, trees = rule.parse(text, pos) assert npos == len(text), f'Could not parse: {text!r}' return visitor.visit(trees) # type: ignore - except Exception as err: # pylint: disable=broad-except + except Exception as err: raise TypesysError(f'Could not parse: {text!r}') from err diff --git a/src/rosbags/typesys/register.py b/src/rosbags/typesys/register.py index fbdd3e31..ebbb53bc 100644 --- a/src/rosbags/typesys/register.py +++ b/src/rosbags/typesys/register.py @@ -31,9 +31,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int """ if desc[0] == Nodetype.BASE: - if match := INTLIKE.match(desc[1]): # type: ignore - return match.group(1) - return 'str' + assert isinstance(desc[1], str) + return match.group(1) if (match := INTLIKE.match(desc[1])) else 'str' if desc[0] == Nodetype.NAME: assert isinstance(desc[1], str) @@ -43,7 +42,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int if INTLIKE.match(sub[1]): typ = 'bool8' if sub[1] == 'bool' else sub[1] return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]' - return f'list[{get_typehint(sub)}]' # type: ignore + assert isinstance(sub, tuple) + return f'list[{get_typehint(sub)}]' def generate_python_code(typs: Typesdict) -> str: @@ -142,6 +142,7 @@ def register_types(typs: Typesdict) -> None: Raises: TypesysError: Type already present with different definition. + """ code = generate_python_code(typs) name = 'rosbags.usertypes' @@ -150,7 +151,7 @@ def register_types(typs: Typesdict) -> None: module = module_from_spec(spec) sys.modules[name] = module exec(code, module.__dict__) # pylint: disable=exec-used - fielddefs: Typesdict = module.FIELDDEFS # type: ignore + fielddefs: Typesdict = module.FIELDDEFS for name, (_, fields) in fielddefs.items(): if name == 'std_msgs/msg/Header': diff --git a/tests/cdr.py b/tests/cdr.py index 036576da..562329f3 100644 --- a/tests/cdr.py +++ b/tests/cdr.py @@ -117,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des size = SIZEMAP[desc.args] pos = (pos + size - 1) & -size - ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore + ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'): ndarr = ndarr.byteswap() # no inplace on readonly array return ndarr, pos + num * SIZEMAP[desc.args] @@ -297,7 +297,7 @@ def serialize_message( rawdata: memoryview, bmap: BasetypeMap, pos: int, - message: Any, + message: object, msgdef: Msgdef, ) -> int: """Serialize a message. @@ -369,7 +369,7 @@ def get_array_size(desc: Descriptor, val: Array, size: int) -> int: raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover -def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int: +def get_size(message: object, msgdef: Msgdef, size: int = 0) -> int: """Calculate size of serialzied message. Args: @@ -413,7 +413,7 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int: def serialize( - message: Any, + message: object, typename: str, little_endian: bool = sys.byteorder == 'little', ) -> memoryview: diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index bb80b4d0..6743deed 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -38,6 +38,6 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None: rconnection, _, raw = next(gen) assert rconnection == wconnection msg = deserialize_cdr(raw, rconnection.msgtype) - assert msg.data == Foo.data + assert getattr(msg, 'data', None) == Foo.data with pytest.raises(StopIteration): next(gen) diff --git a/tests/test_roundtrip1.py b/tests/test_roundtrip1.py index 95d3ace7..60758b91 100644 --- a/tests/test_roundtrip1.py +++ b/tests/test_roundtrip1.py @@ -39,6 +39,6 @@ def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> N gen = rbag.messages() connection, _, raw = next(gen) msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype) - assert msg.data == Foo.data + assert getattr(msg, 'data', None) == Foo.data with pytest.raises(StopIteration): next(gen) diff --git a/tests/test_serde.py b/tests/test_serde.py index d834f34a..8359f52f 100644 --- a/tests/test_serde.py +++ b/tests/test_serde.py @@ -13,7 +13,10 @@ import pytest from rosbags.serde import SerdeError, cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr from rosbags.serde.messages import get_msgdef from rosbags.typesys import get_types_from_msg, register_types -from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__Header +from rosbags.typesys.types import builtin_interfaces__msg__Time as Time +from rosbags.typesys.types import geometry_msgs__msg__Polygon as Polygon +from rosbags.typesys.types import sensor_msgs__msg__MagneticField as MagneticField +from rosbags.typesys.types import std_msgs__msg__Header as Header from .cdr import deserialize, serialize @@ -184,6 +187,7 @@ def _comparable() -> Generator[None, None, None]: Notes: This solution is necessary as numpy.ndarray is not directly patchable. + """ frombuffer = numpy.frombuffer @@ -195,16 +199,16 @@ def _comparable() -> Generator[None, None, None]: class CNDArray(MagicMock): """Mock ndarray.""" - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401 super().__init__(*args, **kwargs) self.__eq__ = arreq # type: ignore - def byteswap(self, *args: Any) -> 'CNDArray': + def byteswap(self, *args: Any) -> CNDArray: # noqa: ANN401 """Wrap return value also in mock.""" return CNDArray(wraps=self._mock_wraps.byteswap(*args)) - def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: - return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore + def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: # noqa: ANN401 + return CNDArray(wraps=frombuffer(*args, **kwargs)) with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer): yield @@ -217,7 +221,7 @@ def test_serde(message: tuple[bytes, str, bool]) -> None: serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little) assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little) - assert serdeser == rawdata[0:len(serdeser)] + assert serdeser == rawdata[:len(serdeser)] assert len(rawdata) - len(serdeser) < 4 assert all(x == 0 for x in rawdata[len(serdeser):]) @@ -227,6 +231,7 @@ def test_deserializer() -> None: """Test deserializer.""" msg = deserialize_cdr(*MSG_POLY[:2]) assert msg == deserialize(*MSG_POLY[:2]) + assert isinstance(msg, Polygon) assert len(msg.points) == 2 assert msg.points[0].x == 1 assert msg.points[0].y == 2 @@ -237,6 +242,7 @@ def test_deserializer() -> None: msg = deserialize_cdr(*MSG_MAGN[:2]) assert msg == deserialize(*MSG_MAGN[:2]) + assert isinstance(msg, MagneticField) assert 'MagneticField' in repr(msg) assert msg.header.stamp.sec == 708 assert msg.header.stamp.nanosec == 256 @@ -248,6 +254,7 @@ def test_deserializer() -> None: msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2]) assert msg_big == deserialize(*MSG_MAGN_BIG[:2]) + assert isinstance(msg_big, MagneticField) assert msg.magnetic_field == msg_big.magnetic_field @@ -285,7 +292,7 @@ def test_serializer_errors() -> None: class Foo: # pylint: disable=too-few-public-methods """Dummy class.""" - coef = numpy.array([1, 2, 3, 4]) + coef: numpy.ndarray[Any, numpy.dtype[numpy.int_]] = numpy.array([1, 2, 3, 4]) msg = Foo() ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True) @@ -376,7 +383,8 @@ def test_custom_type() -> None: def test_ros1_to_cdr() -> None: """Test ROS1 to CDR conversion.""" register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) - msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') + msg_ros = (b'\x01\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_cdr = ( b'\x00\x01\x00\x00' b'\x01\x00' @@ -386,7 +394,8 @@ def test_ros1_to_cdr() -> None: assert ros1_to_cdr(msg_ros, 'test_msgs/msg/static_16_64') == msg_cdr register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64'))) - msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02') + msg_ros = (b'\x01\x00\x00\x00X' + b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_cdr = ( b'\x00\x01\x00\x00' b'\x02\x00\x00\x00X\x00' @@ -399,7 +408,8 @@ def test_ros1_to_cdr() -> None: def test_cdr_to_ros1() -> None: """Test CDR to ROS1 conversion.""" register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) - msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') + msg_ros = (b'\x01\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_cdr = ( b'\x00\x01\x00\x00' b'\x01\x00' @@ -409,7 +419,8 @@ def test_cdr_to_ros1() -> None: assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/static_16_64') == msg_ros register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64'))) - msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02') + msg_ros = (b'\x01\x00\x00\x00X' + b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_cdr = ( b'\x00\x01\x00\x00' b'\x02\x00\x00\x00X\x00' @@ -418,7 +429,7 @@ def test_cdr_to_ros1() -> None: ) assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/dynamic_s_64') == msg_ros - header = std_msgs__msg__Header(stamp=builtin_interfaces__msg__Time(42, 666), frame_id='frame') + header = Header(stamp=Time(42, 666), frame_id='frame') msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header') assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame' @@ -426,7 +437,6 @@ def test_cdr_to_ros1() -> None: @pytest.mark.usefixtures('_comparable') def test_padding_empty_sequence() -> None: """Test empty sequences do not add item padding.""" - # pylint: disable=protected-access register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b'))) su64_b = get_msgdef('test_msgs/msg/su64_b').cls @@ -446,7 +456,6 @@ def test_padding_empty_sequence() -> None: @pytest.mark.usefixtures('_comparable') def test_align_after_empty_sequence() -> None: """Test alignment after empty sequences.""" - # pylint: disable=protected-access register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64'))) su64_b = get_msgdef('test_msgs/msg/su64_u64').cls diff --git a/tests/test_writer1.py b/tests/test_writer1.py index ab7d34aa..28ba6cee 100644 --- a/tests/test_writer1.py +++ b/tests/test_writer1.py @@ -49,7 +49,7 @@ def test_add_connection(tmp_path: Path) -> None: with Writer(path) as writer: res = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH') - assert res.cid == 0 + assert res.cid == 0 data = path.read_bytes() assert data.count(b'MESSAGE_DEFINITION') == 2 assert data.count(b'HASH') == 2 @@ -57,7 +57,7 @@ def test_add_connection(tmp_path: Path) -> None: with Writer(path) as writer: res = writer.add_connection('/foo', 'std_msgs/msg/Int8') - assert res.cid == 0 + assert res.cid == 0 data = path.read_bytes() assert data.count(b'int8 data') == 2 assert data.count(b'27ffa0c9c4b8fb8492252bcad9e5c57b') == 2 @@ -85,7 +85,7 @@ def test_add_connection(tmp_path: Path) -> None: 'HASH', latching=1, ) - assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2) + assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2) def test_write_errors(tmp_path: Path) -> None: diff --git a/tools/bench/bench.py b/tools/bench/bench.py index 0a673541..b8e70b17 100644 --- a/tools/bench/bench.py +++ b/tools/bench/bench.py @@ -21,7 +21,14 @@ from rosbags.rosbag2 import Reader from rosbags.serde import deserialize_cdr if TYPE_CHECKING: - from typing import Any, Generator + from typing import Generator, Protocol + + class NativeMSG(Protocol): # pylint: disable=too-few-public-methods + """Minimal native ROS message interface used for benchmark.""" + + def get_fields_and_field_types(self) -> dict[str, str]: + """Introspect message type.""" + raise NotImplementedError class ReaderPy: # pylint: disable=too-few-public-methods @@ -42,13 +49,13 @@ class ReaderPy: # pylint: disable=too-few-public-methods yield topic, self.typemap[topic], timestamp, data -def deserialize_py(data: bytes, msgtype: str) -> Any: +def deserialize_py(data: bytes, msgtype: str) -> NativeMSG: """Deserialization helper for rosidl_runtime_py + rclpy.""" pytype = get_message(msgtype) - return deserialize_message(data, pytype) + return deserialize_message(data, pytype) # type: ignore -def compare_msg(lite: Any, native: Any) -> None: +def compare_msg(lite: object, native: NativeMSG) -> None: """Compare rosbag2 (lite) vs rosbag2_py (native) message content. Args: @@ -96,8 +103,8 @@ def compare(path: Path) -> None: msg = deserialize_cdr(data, connection.msgtype) compare_msg(msg, msg_py) - assert len(list(gens[0])) == 0 - assert len(list(gens[1])) == 0 + assert not list(gens[0]) + assert not list(gens[1]) def read_deser_rosbag2_py(path: Path) -> None: diff --git a/tools/compare/compare.py b/tools/compare/compare.py index cda5ee7e..f928737f 100644 --- a/tools/compare/compare.py +++ b/tools/compare/compare.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING from unittest.mock import Mock import genpy # type: ignore +import numpy import rosgraph_msgs.msg # type: ignore from rclpy.serialization import deserialize_message # type: ignore from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions # type: ignore @@ -25,9 +26,15 @@ rosgraph_msgs.msg.TopicStatistics = Mock() import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position if TYPE_CHECKING: - from typing import Any, Generator, List, Union + from typing import Generator, List, Protocol, Union, runtime_checkable - from rosbag.bag import _Connection_Info + @runtime_checkable + class NativeMSG(Protocol): # pylint: disable=too-few-public-methods + """Minimal native ROS message interface used for benchmark.""" + + def get_fields_and_field_types(self) -> dict[str, str]: + """Introspect message type.""" + raise NotImplementedError class Reader: # pylint: disable=too-few-public-methods @@ -47,7 +54,7 @@ class Reader: # pylint: disable=too-few-public-methods yield topic, timestamp, deserialize_message(data, pytype) -def fixup_ros1(conns: List[_Connection_Info]) -> None: +def fixup_ros1(conns: List[rosbag.bag._Connection_Info]) -> None: """Monkeypatch ROS2 fieldnames onto ROS1 objects. Args: @@ -61,7 +68,6 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None: if conn := next((x for x in conns if x.datatype == 'sensor_msgs/CameraInfo'), None): print('Patching CameraInfo') # noqa: T001 - # pylint: disable=assignment-from-no-return,too-many-function-args cls = rosbag.bag._get_message_type(conn) # pylint: disable=protected-access cls.d = property(lambda x: x.D, lambda x, y: setattr(x, 'D', y)) # noqa: B010 cls.k = property(lambda x: x.K, lambda x, y: setattr(x, 'K', y)) # noqa: B010 @@ -69,7 +75,7 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None: cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010 -def compare(ref: Any, msg: Any) -> None: +def compare(ref: object, msg: object) -> None: """Compare message to its reference. Args: @@ -77,7 +83,7 @@ def compare(ref: Any, msg: Any) -> None: msg: Converted ROS2 message. """ - if hasattr(msg, 'get_fields_and_field_types'): + if isinstance(msg, NativeMSG): for name in msg.get_fields_and_field_types(): refval = getattr(ref, name) msgval = getattr(msg, name) @@ -87,9 +93,11 @@ def compare(ref: Any, msg: Any) -> None: if isinstance(ref, bytes): assert msg.tobytes() == ref else: + assert isinstance(msg, numpy.ndarray) assert (msg == ref).all() elif isinstance(msg, list): + assert isinstance(ref, (list, numpy.ndarray)) assert len(msg) == len(ref) for refitem, msgitem in zip(ref, msg): compare(refitem, msgitem) @@ -97,8 +105,9 @@ def compare(ref: Any, msg: Any) -> None: elif isinstance(msg, str): assert msg == ref - elif isinstance(msg, float) and math.isnan(ref): - assert math.isnan(msg) + elif isinstance(msg, float) and math.isnan(msg): + assert isinstance(ref, float) + assert math.isnan(ref) else: assert ref == msg