Update lint

This commit is contained in:
Marko Durkovic 2022-04-11 00:07:53 +02:00
parent 7315a4ab4d
commit 19f0678645
22 changed files with 215 additions and 149 deletions

View File

@ -33,7 +33,7 @@ def save_images() -> None:
frame_id=FRAMEID, frame_id=FRAMEID,
), ),
format='jpeg', # could also be 'png' format='jpeg', # could also be 'png'
data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore data=numpy.fromfile(path, dtype=numpy.uint8),
) )
writer.write( writer.write(

View File

@ -33,7 +33,7 @@ def save_images() -> None:
frame_id=FRAMEID, frame_id=FRAMEID,
), ),
format='jpeg', # could also be 'png' format='jpeg', # could also be 'png'
data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore data=numpy.fromfile(path, dtype=numpy.uint8),
) )
writer.write( writer.write(

View File

@ -13,7 +13,7 @@ if TYPE_CHECKING:
NATIVE_CLASSES: dict[str, Any] = {} NATIVE_CLASSES: dict[str, Any] = {}
def to_native(msg: Any) -> Any: def to_native(msg: Any) -> Any: # noqa: ANN401
"""Convert rosbags message to native message. """Convert rosbags message to native message.
Args: Args:

View File

@ -171,5 +171,5 @@ def convert(src: Path, dst: Optional[Path]) -> None:
raise ConverterError(f'Reading source bag: {err}') from err raise ConverterError(f'Reading source bag: {err}') from err
except (WriterError1, WriterError2) as err: except (WriterError1, WriterError2) as err:
raise ConverterError(f'Writing destination bag: {err}') from err raise ConverterError(f'Writing destination bag: {err}') from err
except Exception as err: # pylint: disable=broad-except except Exception as err:
raise ConverterError(f'Converting rosbag: {err!r}') from err raise ConverterError(f'Converting rosbag: {err!r}') from err

View File

@ -106,9 +106,9 @@ class IndexData(NamedTuple):
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
"""Compare by time only.""" """Compare by time only."""
if not isinstance(other, IndexData): # pragma: no cover if isinstance(other, IndexData):
return NotImplemented
return self.time == other[0] return self.time == other[0]
return NotImplemented # pragma: no cover
def __ge__(self, other: tuple[int, ...]) -> bool: def __ge__(self, other: tuple[int, ...]) -> bool:
"""Compare by time only.""" """Compare by time only."""
@ -120,9 +120,9 @@ class IndexData(NamedTuple):
def __ne__(self, other: object) -> bool: def __ne__(self, other: object) -> bool:
"""Compare by time only.""" """Compare by time only."""
if not isinstance(other, IndexData): # pragma: no cover if isinstance(other, IndexData):
return NotImplemented
return self.time != other[0] return self.time != other[0]
return NotImplemented # pragma: no cover
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
@ -369,7 +369,7 @@ class Reader:
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO()) self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
self.topics: dict[str, TopicInfo] = {} self.topics: dict[str, TopicInfo] = {}
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals
"""Open rosbag and read metadata.""" """Open rosbag and read metadata."""
try: try:
self.bio = self.path.open('rb') self.bio = self.path.open('rb')
@ -394,13 +394,11 @@ class Reader:
conn_count = header.get_uint32('conn_count') conn_count = header.get_uint32('conn_count')
chunk_count = header.get_uint32('chunk_count') chunk_count = header.get_uint32('chunk_count')
try: try:
encryptor = header.get_string('encryptor') encryptor: Optional[str] = header.get_string('encryptor')
if encryptor:
raise ValueError
except ValueError:
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
except ReaderError: except ReaderError:
pass encryptor = None
if encryptor:
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
if index_pos == 0: if index_pos == 0:
raise ReaderError('Bag is not indexed, reindex before reading.') raise ReaderError('Bag is not indexed, reindex before reading.')

View File

@ -31,6 +31,7 @@ class WriterError(Exception):
@dataclass @dataclass
class WriteChunk: class WriteChunk:
"""In progress chunk.""" """In progress chunk."""
data: BytesIO data: BytesIO
pos: int pos: int
start: int start: int
@ -126,7 +127,7 @@ class Header(Dict[str, Any]):
return size + 4 return size + 4
class Writer: # pylint: disable=too-many-instance-attributes class Writer:
"""Rosbag1 writer. """Rosbag1 writer.
This class implements writing of rosbag1 files in version 2.0. It should be This class implements writing of rosbag1 files in version 2.0. It should be
@ -212,7 +213,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
md5sum: Optional[str] = None, md5sum: Optional[str] = None,
callerid: Optional[str] = None, callerid: Optional[str] = None,
latching: Optional[int] = None, latching: Optional[int] = None,
**_kw: Any, **_kw: Any, # noqa: ANN401
) -> Connection: ) -> Connection:
"""Add a connection. """Add a connection.

View File

@ -18,7 +18,44 @@ from .connection import Connection
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union from typing import Any, Generator, Iterable, Literal, Optional, Type, TypedDict, Union
class StartingTime(TypedDict):
"""Bag starting time."""
nanoseconds_since_epoch: int
class Duration(TypedDict):
"""Bag starting time."""
nanoseconds: int
class TopicMetadata(TypedDict):
"""Topic metadata."""
name: str
type: str
serialization_format: str
offered_qos_profiles: str
class TopicWithMessageCount(TypedDict):
"""Topic with message count."""
message_count: int
topic_metadata: TopicMetadata
class Metadata(TypedDict):
"""Rosbag2 metadata file."""
version: int
storage_identifier: str
relative_file_paths: list[str]
starting_time: StartingTime
duration: Duration
message_count: int
compression_format: str
compression_mode: str
topics_with_message_count: list[TopicWithMessageCount]
class ReaderError(Exception): class ReaderError(Exception):
@ -72,13 +109,14 @@ class Reader:
Raises: Raises:
ReaderError: Bag not readable or bag metadata. ReaderError: Bag not readable or bag metadata.
""" """
path = Path(path) path = Path(path)
self.path = Path yamlpath = path / 'metadata.yaml'
self.path = path
self.bio = False self.bio = False
try: try:
yaml = YAML(typ='safe') yaml = YAML(typ='safe')
yamlpath = path / 'metadata.yaml'
dct = yaml.load(yamlpath.read_text()) dct = yaml.load(yamlpath.read_text())
except OSError as err: except OSError as err:
raise ReaderError(f'Could not read metadata at {yamlpath}: {err}.') from None raise ReaderError(f'Could not read metadata at {yamlpath}: {err}.') from None
@ -86,7 +124,7 @@ class Reader:
raise ReaderError(f'Could not load YAML from {yamlpath}: {exc}') from None raise ReaderError(f'Could not load YAML from {yamlpath}: {exc}') from None
try: try:
self.metadata = dct['rosbag2_bagfile_information'] self.metadata: Metadata = dct['rosbag2_bagfile_information']
if (ver := self.metadata['version']) > 4: if (ver := self.metadata['version']) > 4:
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.') raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
if storageid := self.metadata['storage_identifier'] != 'sqlite3': if storageid := self.metadata['storage_identifier'] != 'sqlite3':
@ -95,8 +133,7 @@ class Reader:
) )
self.paths = [path / Path(x).name for x in self.metadata['relative_file_paths']] self.paths = [path / Path(x).name for x in self.metadata['relative_file_paths']]
missing = [x for x in self.paths if not x.exists()] if missing := [x for x in self.paths if not x.exists()]:
if missing:
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}') raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
self.connections = { self.connections = {
@ -110,7 +147,7 @@ class Reader:
) for idx, x in enumerate(self.metadata['topics_with_message_count']) ) for idx, x in enumerate(self.metadata['topics_with_message_count'])
} }
noncdr = { noncdr = {
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr' fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr'
} }
if noncdr: if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.') raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
@ -140,8 +177,7 @@ class Reader:
@property @property
def start_time(self) -> int: def start_time(self) -> int:
"""Timestamp in nanoseconds of the earliest message.""" """Timestamp in nanoseconds of the earliest message."""
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch'] return self.metadata['starting_time']['nanoseconds_since_epoch']
return nsecs
@property @property
def end_time(self) -> int: def end_time(self) -> int:
@ -151,8 +187,7 @@ class Reader:
@property @property
def message_count(self) -> int: def message_count(self) -> int:
"""Total message count.""" """Total message count."""
count: int = self.metadata['message_count'] return self.metadata['message_count']
return count
@property @property
def compression_format(self) -> Optional[str]: def compression_format(self) -> Optional[str]:

View File

@ -18,6 +18,8 @@ if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
from typing import Any, Literal, Optional, Type, Union from typing import Any, Literal, Optional, Type, Union
from .reader import Metadata
class WriterError(Exception): class WriterError(Exception):
"""Writer Error.""" """Writer Error."""
@ -125,7 +127,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
msgtype: str, msgtype: str,
serialization_format: str = 'cdr', serialization_format: str = 'cdr',
offered_qos_profiles: str = '', offered_qos_profiles: str = '',
**_kw: Any, **_kw: Any, # noqa: ANN401
) -> Connection: ) -> Connection:
"""Add a connection. """Add a connection.
@ -218,7 +220,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compressor.copy_stream(infile, outfile) self.compressor.copy_stream(infile, outfile)
src.unlink() src.unlink()
metadata = { metadata: dict[str, Metadata] = {
'rosbag2_bagfile_information': { 'rosbag2_bagfile_information': {
'version': 4, 'version': 4,
'storage_identifier': 'sqlite3', 'storage_identifier': 'sqlite3',

View File

@ -86,22 +86,22 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc) anext_before = align(subdesc)
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
if subdesc.args.size_cdr: if subdesc.args.size_cdr:
for _ in range(length): for _ in range(length):
if anext > anext_after: if anext_before > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
size = (size + anext - 1) & -anext size = (size + anext_before - 1) & -anext_before
lines.append(f' pos += {subdesc.args.size_cdr}') lines.append(f' pos += {subdesc.args.size_cdr}')
size += subdesc.args.size_cdr size += subdesc.args.size_cdr
else: else:
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
lines.append(f' val = message.{fieldname}') lines.append(f' val = message.{fieldname}')
for idx in range(length): for idx in range(length):
if anext > anext_after: if anext_before > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' pos = func(pos, val[{idx}])') lines.append(f' pos = func(pos, val[{idx}])')
is_stat = False is_stat = False
aligned = align_after(subdesc) aligned = align_after(subdesc)
@ -117,45 +117,45 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
lines.append(' pos += 4 + len(val.encode()) + 1') lines.append(' pos += 4 + len(val.encode()) + 1')
aligned = 1 aligned = 1
else: else:
anext = align(subdesc) anext_before = align(subdesc)
if aligned < anext: if aligned < anext_before:
lines.append(f' if len(message.{fieldname}):') lines.append(f' if len(message.{fieldname}):')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext aligned = anext_before
lines.append(f' pos += len(message.{fieldname}) * {SIZEMAP[subdesc.args]}') lines.append(f' pos += len(message.{fieldname}) * {SIZEMAP[subdesc.args]}')
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc) anext_before = align(subdesc)
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append(f' val = message.{fieldname}') lines.append(f' val = message.{fieldname}')
if subdesc.args.size_cdr: if subdesc.args.size_cdr:
if aligned < anext <= anext_after: if aligned < anext_before <= anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' for _ in val:') lines.append(' for _ in val:')
if anext > anext_after: if anext_before > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' pos += {subdesc.args.size_cdr}') lines.append(f' pos += {subdesc.args.size_cdr}')
else: else:
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
if aligned < anext <= anext_after: if aligned < anext_before <= anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' for item in val:') lines.append(' for item in val:')
if anext > anext_after: if anext_before > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' pos = func(pos, item)') lines.append(' pos = func(pos, item)')
aligned = align_after(subdesc) aligned = align_after(subdesc)
aligned = min([aligned, 4]) aligned = min([aligned, 4])
is_stat = False is_stat = False
if fnext and aligned < (anext := align(fnext.descriptor)): if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext aligned = anext_before
is_stat = False is_stat = False
lines.append(' return pos') lines.append(' return pos')
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore return compile_lines(lines).getsize_cdr, is_stat * size
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer: def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
@ -240,14 +240,14 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc) anext_before = align(subdesc)
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append( lines.append(
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
) )
for idx in range(length): for idx in range(length):
if anext > anext_after: if anext_before > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' pos = func(rawdata, pos, val[{idx}])') lines.append(f' pos = func(rawdata, pos, val[{idx}])')
aligned = align_after(subdesc) aligned = align_after(subdesc)
else: else:
@ -272,28 +272,28 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
lines.append(f' size = len(val) * {SIZEMAP[subdesc.args]}') lines.append(f' size = len(val) * {SIZEMAP[subdesc.args]}')
if (endianess == 'le') != (sys.byteorder == 'little'): if (endianess == 'le') != (sys.byteorder == 'little'):
lines.append(' val = val.byteswap()') lines.append(' val = val.byteswap()')
if aligned < (anext := align(subdesc)): if aligned < (anext_before := align(subdesc)):
lines.append(' if size:') lines.append(' if size:')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' rawdata[pos:pos + size] = val.view(numpy.uint8)') lines.append(' rawdata[pos:pos + size] = val.view(numpy.uint8)')
lines.append(' pos += size') lines.append(' pos += size')
aligned = anext aligned = anext_before
if subdesc.valtype == Valtype.MESSAGE: if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc) anext_before = align(subdesc)
lines.append( lines.append(
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
) )
lines.append(' for item in val:') lines.append(' for item in val:')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' pos = func(rawdata, pos, item)') lines.append(' pos = func(rawdata, pos, item)')
aligned = align_after(subdesc) aligned = align_after(subdesc)
aligned = min([4, aligned]) aligned = min([4, aligned])
if fnext and aligned < (anext := align(fnext.descriptor)): if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext aligned = anext_before
lines.append(' return pos') lines.append(' return pos')
return compile_lines(lines).serialize_cdr # type: ignore return compile_lines(lines).serialize_cdr # type: ignore
@ -384,13 +384,13 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
lines.append(f' pos += {size}') lines.append(f' pos += {size}')
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc) anext_before = align(subdesc)
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
lines.append(' value = []') lines.append(' value = []')
for _ in range(length): for _ in range(length):
if anext > anext_after: if anext_before > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
lines.append(' value.append(obj)') lines.append(' value.append(obj)')
lines.append(' values.append(value)') lines.append(' values.append(value)')
@ -418,9 +418,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
aligned = 1 aligned = 1
else: else:
lines.append(f' length = size * {SIZEMAP[subdesc.args]}') lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
if aligned < (anext := align(subdesc)): if aligned < (anext_before := align(subdesc)):
lines.append(' if size:') lines.append(' if size:')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append( lines.append(
f' val = numpy.frombuffer(rawdata, ' f' val = numpy.frombuffer(rawdata, '
f'dtype=numpy.{subdesc.args}, count=size, offset=pos)', f'dtype=numpy.{subdesc.args}, count=size, offset=pos)',
@ -429,14 +429,14 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
lines.append(' val = val.byteswap()') lines.append(' val = val.byteswap()')
lines.append(' values.append(val)') lines.append(' values.append(val)')
lines.append(' pos += length') lines.append(' pos += length')
aligned = anext aligned = anext_before
if subdesc.valtype == Valtype.MESSAGE: if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc) anext_before = align(subdesc)
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
lines.append(' value = []') lines.append(' value = []')
lines.append(' for _ in range(size):') lines.append(' for _ in range(size):')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
lines.append(' value.append(obj)') lines.append(' value.append(obj)')
lines.append(' values.append(value)') lines.append(' values.append(value)')
@ -444,9 +444,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
aligned = min([4, aligned]) aligned = min([4, aligned])
if fnext and aligned < (anext := align(fnext.descriptor)): if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext aligned = anext_before
lines.append(' return cls(*values), pos') lines.append(' return cls(*values), pos')
return compile_lines(lines).deserialize_cdr # type: ignore return compile_lines(lines).deserialize_cdr # type: ignore

View File

@ -14,7 +14,7 @@ from .typing import Descriptor, Field, Msgdef
from .utils import Valtype from .utils import Valtype
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from rosbags.typesys.base import Fielddesc
MSGDEFCACHE: dict[str, Msgdef] = {} MSGDEFCACHE: dict[str, Msgdef] = {}
@ -38,14 +38,18 @@ def get_msgdef(typename: str) -> Msgdef:
if typename not in MSGDEFCACHE: if typename not in MSGDEFCACHE:
entries = types.FIELDDEFS[typename][1] entries = types.FIELDDEFS[typename][1]
def fixup(entry: Any) -> Descriptor: def fixup(entry: Fielddesc) -> Descriptor:
if entry[0] == Valtype.BASE: if entry[0] == Valtype.BASE:
assert isinstance(entry[1], str)
return Descriptor(Valtype.BASE, entry[1]) return Descriptor(Valtype.BASE, entry[1])
if entry[0] == Valtype.MESSAGE: if entry[0] == Valtype.MESSAGE:
assert isinstance(entry[1], str)
return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1])) return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1]))
if entry[0] == Valtype.ARRAY: if entry[0] == Valtype.ARRAY:
assert not isinstance(entry[1][0], str)
return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1])) return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1]))
if entry[0] == Valtype.SEQUENCE: if entry[0] == Valtype.SEQUENCE:
assert not isinstance(entry[1][0], str)
return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1])) return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1]))
raise SerdeError( # pragma: no cover raise SerdeError( # pragma: no cover
f'Unknown field type {entry[0]!r} encountered.', f'Unknown field type {entry[0]!r} encountered.',

View File

@ -18,7 +18,7 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Union # pylint: disable=ungrouped-imports from typing import Union
from .typing import Bitcvt, BitcvtSize from .typing import Bitcvt, BitcvtSize
@ -114,13 +114,13 @@ def generate_ros1_to_cdr(
aligned = SIZEMAP[subdesc.args] aligned = SIZEMAP[subdesc.args]
if subdesc.valtype == Valtype.MESSAGE: if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc) anext_before = align(subdesc)
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
for _ in range(length): for _ in range(length):
if anext > anext_after: if anext_before > anext_after:
lines.append(f' opos = (opos + {anext} - 1) & -{anext}') lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)') lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = anext_after aligned = anext_after
else: else:
@ -150,30 +150,30 @@ def generate_ros1_to_cdr(
lines.append(' opos += length') lines.append(' opos += length')
aligned = 1 aligned = 1
else: else:
if aligned < (anext := align(subdesc)): if aligned < (anext_before := align(subdesc)):
lines.append(' if size:') lines.append(' if size:')
lines.append(f' opos = (opos + {anext} - 1) & -{anext}') lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
lines.append(f' length = size * {SIZEMAP[subdesc.args]}') lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
if copy: if copy:
lines.append(' output[opos:opos + length] = input[ipos:ipos + length]') lines.append(' output[opos:opos + length] = input[ipos:ipos + length]')
lines.append(' ipos += length') lines.append(' ipos += length')
lines.append(' opos += length') lines.append(' opos += length')
aligned = anext aligned = anext_before
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc) anext_before = align(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
lines.append(' for _ in range(size):') lines.append(' for _ in range(size):')
lines.append(f' opos = (opos + {anext} - 1) & -{anext}') lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)') lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = align_after(subdesc) aligned = align_after(subdesc)
aligned = min([aligned, 4]) aligned = min([aligned, 4])
if fnext and aligned < (anext := align(fnext.descriptor)): if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' opos = (opos + {anext} - 1) & -{anext}') lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
aligned = anext aligned = anext_before
lines.append(' return ipos, opos') lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname) # type: ignore return getattr(compile_lines(lines), funcname) # type: ignore
@ -270,13 +270,13 @@ def generate_cdr_to_ros1(
aligned = SIZEMAP[subdesc.args] aligned = SIZEMAP[subdesc.args]
if subdesc.valtype == Valtype.MESSAGE: if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc) anext_before = align(subdesc)
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
for _ in range(length): for _ in range(length):
if anext > anext_after: if anext_before > anext_after:
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)') lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = anext_after aligned = anext_after
else: else:
@ -304,30 +304,30 @@ def generate_cdr_to_ros1(
lines.append(' opos += length') lines.append(' opos += length')
aligned = 1 aligned = 1
else: else:
if aligned < (anext := align(subdesc)): if aligned < (anext_before := align(subdesc)):
lines.append(' if size:') lines.append(' if size:')
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
lines.append(f' length = size * {SIZEMAP[subdesc.args]}') lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
if copy: if copy:
lines.append(' output[opos:opos + length] = input[ipos:ipos + length]') lines.append(' output[opos:opos + length] = input[ipos:ipos + length]')
lines.append(' ipos += length') lines.append(' ipos += length')
lines.append(' opos += length') lines.append(' opos += length')
aligned = anext aligned = anext_before
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc) anext_before = align(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
lines.append(' for _ in range(size):') lines.append(' for _ in range(size):')
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)') lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = align_after(subdesc) aligned = align_after(subdesc)
aligned = min([aligned, 4]) aligned = min([aligned, 4])
if fnext and aligned < (anext := align(fnext.descriptor)): if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}') lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
aligned = anext aligned = anext_before
lines.append(' return ipos, opos') lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname) # type: ignore return getattr(compile_lines(lines), funcname) # type: ignore

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from typing import Any from typing import Any
def deserialize_cdr(rawdata: bytes, typename: str) -> Any: def deserialize_cdr(rawdata: bytes, typename: str) -> Any: # noqa: ANN401
"""Deserialize raw data into a message object. """Deserialize raw data into a message object.
Args: Args:
@ -35,7 +35,7 @@ def deserialize_cdr(rawdata: bytes, typename: str) -> Any:
def serialize_cdr( def serialize_cdr(
message: Any, message: object,
typename: str, typename: str,
little_endian: bool = sys.byteorder == 'little', little_endian: bool = sys.byteorder == 'little',
) -> memoryview: ) -> memoryview:

View File

@ -13,8 +13,8 @@ if TYPE_CHECKING:
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]] BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]] CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
CDRSer = Callable[[bytes, int, type], int] CDRSer = Callable[[bytes, int, object], int]
CDRSerSize = Callable[[int, type], int] CDRSerSize = Callable[[int, object], int]
class Descriptor(NamedTuple): class Descriptor(NamedTuple):

View File

@ -68,5 +68,5 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
npos, trees = rule.parse(text, pos) npos, trees = rule.parse(text, pos)
assert npos == len(text), f'Could not parse: {text!r}' assert npos == len(text), f'Could not parse: {text!r}'
return visitor.visit(trees) # type: ignore return visitor.visit(trees) # type: ignore
except Exception as err: # pylint: disable=broad-except except Exception as err:
raise TypesysError(f'Could not parse: {text!r}') from err raise TypesysError(f'Could not parse: {text!r}') from err

View File

@ -31,9 +31,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int
""" """
if desc[0] == Nodetype.BASE: if desc[0] == Nodetype.BASE:
if match := INTLIKE.match(desc[1]): # type: ignore assert isinstance(desc[1], str)
return match.group(1) return match.group(1) if (match := INTLIKE.match(desc[1])) else 'str'
return 'str'
if desc[0] == Nodetype.NAME: if desc[0] == Nodetype.NAME:
assert isinstance(desc[1], str) assert isinstance(desc[1], str)
@ -43,7 +42,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int
if INTLIKE.match(sub[1]): if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1] typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]' return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
return f'list[{get_typehint(sub)}]' # type: ignore assert isinstance(sub, tuple)
return f'list[{get_typehint(sub)}]'
def generate_python_code(typs: Typesdict) -> str: def generate_python_code(typs: Typesdict) -> str:
@ -142,6 +142,7 @@ def register_types(typs: Typesdict) -> None:
Raises: Raises:
TypesysError: Type already present with different definition. TypesysError: Type already present with different definition.
""" """
code = generate_python_code(typs) code = generate_python_code(typs)
name = 'rosbags.usertypes' name = 'rosbags.usertypes'
@ -150,7 +151,7 @@ def register_types(typs: Typesdict) -> None:
module = module_from_spec(spec) module = module_from_spec(spec)
sys.modules[name] = module sys.modules[name] = module
exec(code, module.__dict__) # pylint: disable=exec-used exec(code, module.__dict__) # pylint: disable=exec-used
fielddefs: Typesdict = module.FIELDDEFS # type: ignore fielddefs: Typesdict = module.FIELDDEFS
for name, (_, fields) in fielddefs.items(): for name, (_, fields) in fielddefs.items():
if name == 'std_msgs/msg/Header': if name == 'std_msgs/msg/Header':

View File

@ -117,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
size = SIZEMAP[desc.args] size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size pos = (pos + size - 1) & -size
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'): if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
ndarr = ndarr.byteswap() # no inplace on readonly array ndarr = ndarr.byteswap() # no inplace on readonly array
return ndarr, pos + num * SIZEMAP[desc.args] return ndarr, pos + num * SIZEMAP[desc.args]
@ -297,7 +297,7 @@ def serialize_message(
rawdata: memoryview, rawdata: memoryview,
bmap: BasetypeMap, bmap: BasetypeMap,
pos: int, pos: int,
message: Any, message: object,
msgdef: Msgdef, msgdef: Msgdef,
) -> int: ) -> int:
"""Serialize a message. """Serialize a message.
@ -369,7 +369,7 @@ def get_array_size(desc: Descriptor, val: Array, size: int) -> int:
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int: def get_size(message: object, msgdef: Msgdef, size: int = 0) -> int:
"""Calculate size of serialzied message. """Calculate size of serialzied message.
Args: Args:
@ -413,7 +413,7 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
def serialize( def serialize(
message: Any, message: object,
typename: str, typename: str,
little_endian: bool = sys.byteorder == 'little', little_endian: bool = sys.byteorder == 'little',
) -> memoryview: ) -> memoryview:

View File

@ -38,6 +38,6 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
rconnection, _, raw = next(gen) rconnection, _, raw = next(gen)
assert rconnection == wconnection assert rconnection == wconnection
msg = deserialize_cdr(raw, rconnection.msgtype) msg = deserialize_cdr(raw, rconnection.msgtype)
assert msg.data == Foo.data assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)

View File

@ -39,6 +39,6 @@ def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> N
gen = rbag.messages() gen = rbag.messages()
connection, _, raw = next(gen) connection, _, raw = next(gen)
msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype) msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype)
assert msg.data == Foo.data assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(gen) next(gen)

View File

@ -13,7 +13,10 @@ import pytest
from rosbags.serde import SerdeError, cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr from rosbags.serde import SerdeError, cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr
from rosbags.serde.messages import get_msgdef from rosbags.serde.messages import get_msgdef
from rosbags.typesys import get_types_from_msg, register_types from rosbags.typesys import get_types_from_msg, register_types
from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__Header from rosbags.typesys.types import builtin_interfaces__msg__Time as Time
from rosbags.typesys.types import geometry_msgs__msg__Polygon as Polygon
from rosbags.typesys.types import sensor_msgs__msg__MagneticField as MagneticField
from rosbags.typesys.types import std_msgs__msg__Header as Header
from .cdr import deserialize, serialize from .cdr import deserialize, serialize
@ -184,6 +187,7 @@ def _comparable() -> Generator[None, None, None]:
Notes: Notes:
This solution is necessary as numpy.ndarray is not directly patchable. This solution is necessary as numpy.ndarray is not directly patchable.
""" """
frombuffer = numpy.frombuffer frombuffer = numpy.frombuffer
@ -195,16 +199,16 @@ def _comparable() -> Generator[None, None, None]:
class CNDArray(MagicMock): class CNDArray(MagicMock):
"""Mock ndarray.""" """Mock ndarray."""
def __init__(self, *args: Any, **kwargs: Any): def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.__eq__ = arreq # type: ignore self.__eq__ = arreq # type: ignore
def byteswap(self, *args: Any) -> 'CNDArray': def byteswap(self, *args: Any) -> CNDArray: # noqa: ANN401
"""Wrap return value also in mock.""" """Wrap return value also in mock."""
return CNDArray(wraps=self._mock_wraps.byteswap(*args)) return CNDArray(wraps=self._mock_wraps.byteswap(*args))
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: # noqa: ANN401
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore return CNDArray(wraps=frombuffer(*args, **kwargs))
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer): with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
yield yield
@ -217,7 +221,7 @@ def test_serde(message: tuple[bytes, str, bool]) -> None:
serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little) serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little)
assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little) assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little)
assert serdeser == rawdata[0:len(serdeser)] assert serdeser == rawdata[:len(serdeser)]
assert len(rawdata) - len(serdeser) < 4 assert len(rawdata) - len(serdeser) < 4
assert all(x == 0 for x in rawdata[len(serdeser):]) assert all(x == 0 for x in rawdata[len(serdeser):])
@ -227,6 +231,7 @@ def test_deserializer() -> None:
"""Test deserializer.""" """Test deserializer."""
msg = deserialize_cdr(*MSG_POLY[:2]) msg = deserialize_cdr(*MSG_POLY[:2])
assert msg == deserialize(*MSG_POLY[:2]) assert msg == deserialize(*MSG_POLY[:2])
assert isinstance(msg, Polygon)
assert len(msg.points) == 2 assert len(msg.points) == 2
assert msg.points[0].x == 1 assert msg.points[0].x == 1
assert msg.points[0].y == 2 assert msg.points[0].y == 2
@ -237,6 +242,7 @@ def test_deserializer() -> None:
msg = deserialize_cdr(*MSG_MAGN[:2]) msg = deserialize_cdr(*MSG_MAGN[:2])
assert msg == deserialize(*MSG_MAGN[:2]) assert msg == deserialize(*MSG_MAGN[:2])
assert isinstance(msg, MagneticField)
assert 'MagneticField' in repr(msg) assert 'MagneticField' in repr(msg)
assert msg.header.stamp.sec == 708 assert msg.header.stamp.sec == 708
assert msg.header.stamp.nanosec == 256 assert msg.header.stamp.nanosec == 256
@ -248,6 +254,7 @@ def test_deserializer() -> None:
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2]) msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
assert msg_big == deserialize(*MSG_MAGN_BIG[:2]) assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
assert isinstance(msg_big, MagneticField)
assert msg.magnetic_field == msg_big.magnetic_field assert msg.magnetic_field == msg_big.magnetic_field
@ -285,7 +292,7 @@ def test_serializer_errors() -> None:
class Foo: # pylint: disable=too-few-public-methods class Foo: # pylint: disable=too-few-public-methods
"""Dummy class.""" """Dummy class."""
coef = numpy.array([1, 2, 3, 4]) coef: numpy.ndarray[Any, numpy.dtype[numpy.int_]] = numpy.array([1, 2, 3, 4])
msg = Foo() msg = Foo()
ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True) ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
@ -376,7 +383,8 @@ def test_custom_type() -> None:
def test_ros1_to_cdr() -> None: def test_ros1_to_cdr() -> None:
"""Test ROS1 to CDR conversion.""" """Test ROS1 to CDR conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_ros = (b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = ( msg_cdr = (
b'\x00\x01\x00\x00' b'\x00\x01\x00\x00'
b'\x01\x00' b'\x01\x00'
@ -386,7 +394,8 @@ def test_ros1_to_cdr() -> None:
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/static_16_64') == msg_cdr assert ros1_to_cdr(msg_ros, 'test_msgs/msg/static_16_64') == msg_cdr
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64'))) register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_ros = (b'\x01\x00\x00\x00X'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = ( msg_cdr = (
b'\x00\x01\x00\x00' b'\x00\x01\x00\x00'
b'\x02\x00\x00\x00X\x00' b'\x02\x00\x00\x00X\x00'
@ -399,7 +408,8 @@ def test_ros1_to_cdr() -> None:
def test_cdr_to_ros1() -> None: def test_cdr_to_ros1() -> None:
"""Test CDR to ROS1 conversion.""" """Test CDR to ROS1 conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64'))) register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_ros = (b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = ( msg_cdr = (
b'\x00\x01\x00\x00' b'\x00\x01\x00\x00'
b'\x01\x00' b'\x01\x00'
@ -409,7 +419,8 @@ def test_cdr_to_ros1() -> None:
assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/static_16_64') == msg_ros assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/static_16_64') == msg_ros
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64'))) register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02') msg_ros = (b'\x01\x00\x00\x00X'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = ( msg_cdr = (
b'\x00\x01\x00\x00' b'\x00\x01\x00\x00'
b'\x02\x00\x00\x00X\x00' b'\x02\x00\x00\x00X\x00'
@ -418,7 +429,7 @@ def test_cdr_to_ros1() -> None:
) )
assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/dynamic_s_64') == msg_ros assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/dynamic_s_64') == msg_ros
header = std_msgs__msg__Header(stamp=builtin_interfaces__msg__Time(42, 666), frame_id='frame') header = Header(stamp=Time(42, 666), frame_id='frame')
msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header') msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header')
assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame' assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame'
@ -426,7 +437,6 @@ def test_cdr_to_ros1() -> None:
@pytest.mark.usefixtures('_comparable') @pytest.mark.usefixtures('_comparable')
def test_padding_empty_sequence() -> None: def test_padding_empty_sequence() -> None:
"""Test empty sequences do not add item padding.""" """Test empty sequences do not add item padding."""
# pylint: disable=protected-access
register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b'))) register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b')))
su64_b = get_msgdef('test_msgs/msg/su64_b').cls su64_b = get_msgdef('test_msgs/msg/su64_b').cls
@ -446,7 +456,6 @@ def test_padding_empty_sequence() -> None:
@pytest.mark.usefixtures('_comparable') @pytest.mark.usefixtures('_comparable')
def test_align_after_empty_sequence() -> None: def test_align_after_empty_sequence() -> None:
"""Test alignment after empty sequences.""" """Test alignment after empty sequences."""
# pylint: disable=protected-access
register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64'))) register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64')))
su64_b = get_msgdef('test_msgs/msg/su64_u64').cls su64_b = get_msgdef('test_msgs/msg/su64_u64').cls

View File

@ -21,7 +21,14 @@ from rosbags.rosbag2 import Reader
from rosbags.serde import deserialize_cdr from rosbags.serde import deserialize_cdr
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Generator from typing import Generator, Protocol
class NativeMSG(Protocol): # pylint: disable=too-few-public-methods
"""Minimal native ROS message interface used for benchmark."""
def get_fields_and_field_types(self) -> dict[str, str]:
"""Introspect message type."""
raise NotImplementedError
class ReaderPy: # pylint: disable=too-few-public-methods class ReaderPy: # pylint: disable=too-few-public-methods
@ -42,13 +49,13 @@ class ReaderPy: # pylint: disable=too-few-public-methods
yield topic, self.typemap[topic], timestamp, data yield topic, self.typemap[topic], timestamp, data
def deserialize_py(data: bytes, msgtype: str) -> Any: def deserialize_py(data: bytes, msgtype: str) -> NativeMSG:
"""Deserialization helper for rosidl_runtime_py + rclpy.""" """Deserialization helper for rosidl_runtime_py + rclpy."""
pytype = get_message(msgtype) pytype = get_message(msgtype)
return deserialize_message(data, pytype) return deserialize_message(data, pytype) # type: ignore
def compare_msg(lite: Any, native: Any) -> None: def compare_msg(lite: object, native: NativeMSG) -> None:
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content. """Compare rosbag2 (lite) vs rosbag2_py (native) message content.
Args: Args:
@ -96,8 +103,8 @@ def compare(path: Path) -> None:
msg = deserialize_cdr(data, connection.msgtype) msg = deserialize_cdr(data, connection.msgtype)
compare_msg(msg, msg_py) compare_msg(msg, msg_py)
assert len(list(gens[0])) == 0 assert not list(gens[0])
assert len(list(gens[1])) == 0 assert not list(gens[1])
def read_deser_rosbag2_py(path: Path) -> None: def read_deser_rosbag2_py(path: Path) -> None:

View File

@ -14,6 +14,7 @@ from typing import TYPE_CHECKING
from unittest.mock import Mock from unittest.mock import Mock
import genpy # type: ignore import genpy # type: ignore
import numpy
import rosgraph_msgs.msg # type: ignore import rosgraph_msgs.msg # type: ignore
from rclpy.serialization import deserialize_message # type: ignore from rclpy.serialization import deserialize_message # type: ignore
from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions # type: ignore from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions # type: ignore
@ -25,9 +26,15 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Generator, List, Union from typing import Generator, List, Protocol, Union, runtime_checkable
from rosbag.bag import _Connection_Info @runtime_checkable
class NativeMSG(Protocol): # pylint: disable=too-few-public-methods
"""Minimal native ROS message interface used for benchmark."""
def get_fields_and_field_types(self) -> dict[str, str]:
"""Introspect message type."""
raise NotImplementedError
class Reader: # pylint: disable=too-few-public-methods class Reader: # pylint: disable=too-few-public-methods
@ -47,7 +54,7 @@ class Reader: # pylint: disable=too-few-public-methods
yield topic, timestamp, deserialize_message(data, pytype) yield topic, timestamp, deserialize_message(data, pytype)
def fixup_ros1(conns: List[_Connection_Info]) -> None: def fixup_ros1(conns: List[rosbag.bag._Connection_Info]) -> None:
"""Monkeypatch ROS2 fieldnames onto ROS1 objects. """Monkeypatch ROS2 fieldnames onto ROS1 objects.
Args: Args:
@ -61,7 +68,6 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None:
if conn := next((x for x in conns if x.datatype == 'sensor_msgs/CameraInfo'), None): if conn := next((x for x in conns if x.datatype == 'sensor_msgs/CameraInfo'), None):
print('Patching CameraInfo') # noqa: T001 print('Patching CameraInfo') # noqa: T001
# pylint: disable=assignment-from-no-return,too-many-function-args
cls = rosbag.bag._get_message_type(conn) # pylint: disable=protected-access cls = rosbag.bag._get_message_type(conn) # pylint: disable=protected-access
cls.d = property(lambda x: x.D, lambda x, y: setattr(x, 'D', y)) # noqa: B010 cls.d = property(lambda x: x.D, lambda x, y: setattr(x, 'D', y)) # noqa: B010
cls.k = property(lambda x: x.K, lambda x, y: setattr(x, 'K', y)) # noqa: B010 cls.k = property(lambda x: x.K, lambda x, y: setattr(x, 'K', y)) # noqa: B010
@ -69,7 +75,7 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None:
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010 cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
def compare(ref: Any, msg: Any) -> None: def compare(ref: object, msg: object) -> None:
"""Compare message to its reference. """Compare message to its reference.
Args: Args:
@ -77,7 +83,7 @@ def compare(ref: Any, msg: Any) -> None:
msg: Converted ROS2 message. msg: Converted ROS2 message.
""" """
if hasattr(msg, 'get_fields_and_field_types'): if isinstance(msg, NativeMSG):
for name in msg.get_fields_and_field_types(): for name in msg.get_fields_and_field_types():
refval = getattr(ref, name) refval = getattr(ref, name)
msgval = getattr(msg, name) msgval = getattr(msg, name)
@ -87,9 +93,11 @@ def compare(ref: Any, msg: Any) -> None:
if isinstance(ref, bytes): if isinstance(ref, bytes):
assert msg.tobytes() == ref assert msg.tobytes() == ref
else: else:
assert isinstance(msg, numpy.ndarray)
assert (msg == ref).all() assert (msg == ref).all()
elif isinstance(msg, list): elif isinstance(msg, list):
assert isinstance(ref, (list, numpy.ndarray))
assert len(msg) == len(ref) assert len(msg) == len(ref)
for refitem, msgitem in zip(ref, msg): for refitem, msgitem in zip(ref, msg):
compare(refitem, msgitem) compare(refitem, msgitem)
@ -97,8 +105,9 @@ def compare(ref: Any, msg: Any) -> None:
elif isinstance(msg, str): elif isinstance(msg, str):
assert msg == ref assert msg == ref
elif isinstance(msg, float) and math.isnan(ref): elif isinstance(msg, float) and math.isnan(msg):
assert math.isnan(msg) assert isinstance(ref, float)
assert math.isnan(ref)
else: else:
assert ref == msg assert ref == msg