Update lint
This commit is contained in:
parent
7315a4ab4d
commit
19f0678645
@ -33,7 +33,7 @@ def save_images() -> None:
|
||||
frame_id=FRAMEID,
|
||||
),
|
||||
format='jpeg', # could also be 'png'
|
||||
data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore
|
||||
data=numpy.fromfile(path, dtype=numpy.uint8),
|
||||
)
|
||||
|
||||
writer.write(
|
||||
|
||||
@ -33,7 +33,7 @@ def save_images() -> None:
|
||||
frame_id=FRAMEID,
|
||||
),
|
||||
format='jpeg', # could also be 'png'
|
||||
data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore
|
||||
data=numpy.fromfile(path, dtype=numpy.uint8),
|
||||
)
|
||||
|
||||
writer.write(
|
||||
|
||||
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
NATIVE_CLASSES: dict[str, Any] = {}
|
||||
|
||||
|
||||
def to_native(msg: Any) -> Any:
|
||||
def to_native(msg: Any) -> Any: # noqa: ANN401
|
||||
"""Convert rosbags message to native message.
|
||||
|
||||
Args:
|
||||
|
||||
@ -171,5 +171,5 @@ def convert(src: Path, dst: Optional[Path]) -> None:
|
||||
raise ConverterError(f'Reading source bag: {err}') from err
|
||||
except (WriterError1, WriterError2) as err:
|
||||
raise ConverterError(f'Writing destination bag: {err}') from err
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
except Exception as err:
|
||||
raise ConverterError(f'Converting rosbag: {err!r}') from err
|
||||
|
||||
@ -106,9 +106,9 @@ class IndexData(NamedTuple):
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare by time only."""
|
||||
if not isinstance(other, IndexData): # pragma: no cover
|
||||
return NotImplemented
|
||||
return self.time == other[0]
|
||||
if isinstance(other, IndexData):
|
||||
return self.time == other[0]
|
||||
return NotImplemented # pragma: no cover
|
||||
|
||||
def __ge__(self, other: tuple[int, ...]) -> bool:
|
||||
"""Compare by time only."""
|
||||
@ -120,9 +120,9 @@ class IndexData(NamedTuple):
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
"""Compare by time only."""
|
||||
if not isinstance(other, IndexData): # pragma: no cover
|
||||
return NotImplemented
|
||||
return self.time != other[0]
|
||||
if isinstance(other, IndexData):
|
||||
return self.time != other[0]
|
||||
return NotImplemented # pragma: no cover
|
||||
|
||||
|
||||
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
|
||||
@ -369,7 +369,7 @@ class Reader:
|
||||
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
|
||||
self.topics: dict[str, TopicInfo] = {}
|
||||
|
||||
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
|
||||
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals
|
||||
"""Open rosbag and read metadata."""
|
||||
try:
|
||||
self.bio = self.path.open('rb')
|
||||
@ -394,13 +394,11 @@ class Reader:
|
||||
conn_count = header.get_uint32('conn_count')
|
||||
chunk_count = header.get_uint32('chunk_count')
|
||||
try:
|
||||
encryptor = header.get_string('encryptor')
|
||||
if encryptor:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
|
||||
encryptor: Optional[str] = header.get_string('encryptor')
|
||||
except ReaderError:
|
||||
pass
|
||||
encryptor = None
|
||||
if encryptor:
|
||||
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
|
||||
|
||||
if index_pos == 0:
|
||||
raise ReaderError('Bag is not indexed, reindex before reading.')
|
||||
|
||||
@ -31,6 +31,7 @@ class WriterError(Exception):
|
||||
@dataclass
|
||||
class WriteChunk:
|
||||
"""In progress chunk."""
|
||||
|
||||
data: BytesIO
|
||||
pos: int
|
||||
start: int
|
||||
@ -126,7 +127,7 @@ class Header(Dict[str, Any]):
|
||||
return size + 4
|
||||
|
||||
|
||||
class Writer: # pylint: disable=too-many-instance-attributes
|
||||
class Writer:
|
||||
"""Rosbag1 writer.
|
||||
|
||||
This class implements writing of rosbag1 files in version 2.0. It should be
|
||||
@ -212,7 +213,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
md5sum: Optional[str] = None,
|
||||
callerid: Optional[str] = None,
|
||||
latching: Optional[int] = None,
|
||||
**_kw: Any,
|
||||
**_kw: Any, # noqa: ANN401
|
||||
) -> Connection:
|
||||
"""Add a connection.
|
||||
|
||||
|
||||
@ -18,7 +18,44 @@ from .connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
|
||||
from typing import Any, Generator, Iterable, Literal, Optional, Type, TypedDict, Union
|
||||
|
||||
class StartingTime(TypedDict):
|
||||
"""Bag starting time."""
|
||||
|
||||
nanoseconds_since_epoch: int
|
||||
|
||||
class Duration(TypedDict):
|
||||
"""Bag starting time."""
|
||||
|
||||
nanoseconds: int
|
||||
|
||||
class TopicMetadata(TypedDict):
|
||||
"""Topic metadata."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
serialization_format: str
|
||||
offered_qos_profiles: str
|
||||
|
||||
class TopicWithMessageCount(TypedDict):
|
||||
"""Topic with message count."""
|
||||
|
||||
message_count: int
|
||||
topic_metadata: TopicMetadata
|
||||
|
||||
class Metadata(TypedDict):
|
||||
"""Rosbag2 metadata file."""
|
||||
|
||||
version: int
|
||||
storage_identifier: str
|
||||
relative_file_paths: list[str]
|
||||
starting_time: StartingTime
|
||||
duration: Duration
|
||||
message_count: int
|
||||
compression_format: str
|
||||
compression_mode: str
|
||||
topics_with_message_count: list[TopicWithMessageCount]
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
@ -72,13 +109,14 @@ class Reader:
|
||||
|
||||
Raises:
|
||||
ReaderError: Bag not readable or bag metadata.
|
||||
|
||||
"""
|
||||
path = Path(path)
|
||||
self.path = Path
|
||||
yamlpath = path / 'metadata.yaml'
|
||||
self.path = path
|
||||
self.bio = False
|
||||
try:
|
||||
yaml = YAML(typ='safe')
|
||||
yamlpath = path / 'metadata.yaml'
|
||||
dct = yaml.load(yamlpath.read_text())
|
||||
except OSError as err:
|
||||
raise ReaderError(f'Could not read metadata at {yamlpath}: {err}.') from None
|
||||
@ -86,7 +124,7 @@ class Reader:
|
||||
raise ReaderError(f'Could not load YAML from {yamlpath}: {exc}') from None
|
||||
|
||||
try:
|
||||
self.metadata = dct['rosbag2_bagfile_information']
|
||||
self.metadata: Metadata = dct['rosbag2_bagfile_information']
|
||||
if (ver := self.metadata['version']) > 4:
|
||||
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
|
||||
if storageid := self.metadata['storage_identifier'] != 'sqlite3':
|
||||
@ -95,8 +133,7 @@ class Reader:
|
||||
)
|
||||
|
||||
self.paths = [path / Path(x).name for x in self.metadata['relative_file_paths']]
|
||||
missing = [x for x in self.paths if not x.exists()]
|
||||
if missing:
|
||||
if missing := [x for x in self.paths if not x.exists()]:
|
||||
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
|
||||
|
||||
self.connections = {
|
||||
@ -110,7 +147,7 @@ class Reader:
|
||||
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
|
||||
}
|
||||
noncdr = {
|
||||
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr'
|
||||
fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr'
|
||||
}
|
||||
if noncdr:
|
||||
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
|
||||
@ -140,8 +177,7 @@ class Reader:
|
||||
@property
|
||||
def start_time(self) -> int:
|
||||
"""Timestamp in nanoseconds of the earliest message."""
|
||||
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||
return nsecs
|
||||
return self.metadata['starting_time']['nanoseconds_since_epoch']
|
||||
|
||||
@property
|
||||
def end_time(self) -> int:
|
||||
@ -151,8 +187,7 @@ class Reader:
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""Total message count."""
|
||||
count: int = self.metadata['message_count']
|
||||
return count
|
||||
return self.metadata['message_count']
|
||||
|
||||
@property
|
||||
def compression_format(self) -> Optional[str]:
|
||||
|
||||
@ -18,6 +18,8 @@ if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Literal, Optional, Type, Union
|
||||
|
||||
from .reader import Metadata
|
||||
|
||||
|
||||
class WriterError(Exception):
|
||||
"""Writer Error."""
|
||||
@ -125,7 +127,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
msgtype: str,
|
||||
serialization_format: str = 'cdr',
|
||||
offered_qos_profiles: str = '',
|
||||
**_kw: Any,
|
||||
**_kw: Any, # noqa: ANN401
|
||||
) -> Connection:
|
||||
"""Add a connection.
|
||||
|
||||
@ -218,7 +220,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compressor.copy_stream(infile, outfile)
|
||||
src.unlink()
|
||||
|
||||
metadata = {
|
||||
metadata: dict[str, Metadata] = {
|
||||
'rosbag2_bagfile_information': {
|
||||
'version': 4,
|
||||
'storage_identifier': 'sqlite3',
|
||||
|
||||
@ -86,22 +86,22 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
|
||||
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
anext_after = align_after(subdesc)
|
||||
|
||||
if subdesc.args.size_cdr:
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
size = (size + anext - 1) & -anext
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
size = (size + anext_before - 1) & -anext_before
|
||||
lines.append(f' pos += {subdesc.args.size_cdr}')
|
||||
size += subdesc.args.size_cdr
|
||||
else:
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
|
||||
lines.append(f' val = message.{fieldname}')
|
||||
for idx in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' pos = func(pos, val[{idx}])')
|
||||
is_stat = False
|
||||
aligned = align_after(subdesc)
|
||||
@ -117,45 +117,45 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
|
||||
lines.append(' pos += 4 + len(val.encode()) + 1')
|
||||
aligned = 1
|
||||
else:
|
||||
anext = align(subdesc)
|
||||
if aligned < anext:
|
||||
anext_before = align(subdesc)
|
||||
if aligned < anext_before:
|
||||
lines.append(f' if len(message.{fieldname}):')
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
aligned = anext
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
aligned = anext_before
|
||||
lines.append(f' pos += len(message.{fieldname}) * {SIZEMAP[subdesc.args]}')
|
||||
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
anext_after = align_after(subdesc)
|
||||
lines.append(f' val = message.{fieldname}')
|
||||
if subdesc.args.size_cdr:
|
||||
if aligned < anext <= anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if aligned < anext_before <= anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' for _ in val:')
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' pos += {subdesc.args.size_cdr}')
|
||||
|
||||
else:
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
|
||||
if aligned < anext <= anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if aligned < anext_before <= anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' for item in val:')
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' pos = func(pos, item)')
|
||||
aligned = align_after(subdesc)
|
||||
|
||||
aligned = min([aligned, 4])
|
||||
is_stat = False
|
||||
|
||||
if fnext and aligned < (anext := align(fnext.descriptor)):
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
aligned = anext
|
||||
if fnext and aligned < (anext_before := align(fnext.descriptor)):
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
aligned = anext_before
|
||||
is_stat = False
|
||||
lines.append(' return pos')
|
||||
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
|
||||
return compile_lines(lines).getsize_cdr, is_stat * size
|
||||
|
||||
|
||||
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
|
||||
@ -240,14 +240,14 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
|
||||
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
anext_after = align_after(subdesc)
|
||||
lines.append(
|
||||
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
|
||||
)
|
||||
for idx in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' pos = func(rawdata, pos, val[{idx}])')
|
||||
aligned = align_after(subdesc)
|
||||
else:
|
||||
@ -272,28 +272,28 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
|
||||
lines.append(f' size = len(val) * {SIZEMAP[subdesc.args]}')
|
||||
if (endianess == 'le') != (sys.byteorder == 'little'):
|
||||
lines.append(' val = val.byteswap()')
|
||||
if aligned < (anext := align(subdesc)):
|
||||
if aligned < (anext_before := align(subdesc)):
|
||||
lines.append(' if size:')
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' rawdata[pos:pos + size] = val.view(numpy.uint8)')
|
||||
lines.append(' pos += size')
|
||||
aligned = anext
|
||||
aligned = anext_before
|
||||
|
||||
if subdesc.valtype == Valtype.MESSAGE:
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
lines.append(
|
||||
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
|
||||
)
|
||||
lines.append(' for item in val:')
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' pos = func(rawdata, pos, item)')
|
||||
aligned = align_after(subdesc)
|
||||
|
||||
aligned = min([4, aligned])
|
||||
|
||||
if fnext and aligned < (anext := align(fnext.descriptor)):
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
aligned = anext
|
||||
if fnext and aligned < (anext_before := align(fnext.descriptor)):
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
aligned = anext_before
|
||||
lines.append(' return pos')
|
||||
return compile_lines(lines).serialize_cdr # type: ignore
|
||||
|
||||
@ -384,13 +384,13 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||
lines.append(f' pos += {size}')
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
anext_after = align_after(subdesc)
|
||||
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
|
||||
lines.append(' value = []')
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
|
||||
lines.append(' value.append(obj)')
|
||||
lines.append(' values.append(value)')
|
||||
@ -418,9 +418,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||
aligned = 1
|
||||
else:
|
||||
lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
|
||||
if aligned < (anext := align(subdesc)):
|
||||
if aligned < (anext_before := align(subdesc)):
|
||||
lines.append(' if size:')
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(
|
||||
f' val = numpy.frombuffer(rawdata, '
|
||||
f'dtype=numpy.{subdesc.args}, count=size, offset=pos)',
|
||||
@ -429,14 +429,14 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||
lines.append(' val = val.byteswap()')
|
||||
lines.append(' values.append(val)')
|
||||
lines.append(' pos += length')
|
||||
aligned = anext
|
||||
aligned = anext_before
|
||||
|
||||
if subdesc.valtype == Valtype.MESSAGE:
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
|
||||
lines.append(' value = []')
|
||||
lines.append(' for _ in range(size):')
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
|
||||
lines.append(' value.append(obj)')
|
||||
lines.append(' values.append(value)')
|
||||
@ -444,9 +444,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
|
||||
|
||||
aligned = min([4, aligned])
|
||||
|
||||
if fnext and aligned < (anext := align(fnext.descriptor)):
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
aligned = anext
|
||||
if fnext and aligned < (anext_before := align(fnext.descriptor)):
|
||||
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
|
||||
aligned = anext_before
|
||||
|
||||
lines.append(' return cls(*values), pos')
|
||||
return compile_lines(lines).deserialize_cdr # type: ignore
|
||||
|
||||
@ -14,7 +14,7 @@ from .typing import Descriptor, Field, Msgdef
|
||||
from .utils import Valtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from rosbags.typesys.base import Fielddesc
|
||||
|
||||
MSGDEFCACHE: dict[str, Msgdef] = {}
|
||||
|
||||
@ -38,14 +38,18 @@ def get_msgdef(typename: str) -> Msgdef:
|
||||
if typename not in MSGDEFCACHE:
|
||||
entries = types.FIELDDEFS[typename][1]
|
||||
|
||||
def fixup(entry: Any) -> Descriptor:
|
||||
def fixup(entry: Fielddesc) -> Descriptor:
|
||||
if entry[0] == Valtype.BASE:
|
||||
assert isinstance(entry[1], str)
|
||||
return Descriptor(Valtype.BASE, entry[1])
|
||||
if entry[0] == Valtype.MESSAGE:
|
||||
assert isinstance(entry[1], str)
|
||||
return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1]))
|
||||
if entry[0] == Valtype.ARRAY:
|
||||
assert not isinstance(entry[1][0], str)
|
||||
return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1]))
|
||||
if entry[0] == Valtype.SEQUENCE:
|
||||
assert not isinstance(entry[1][0], str)
|
||||
return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1]))
|
||||
raise SerdeError( # pragma: no cover
|
||||
f'Unknown field type {entry[0]!r} encountered.',
|
||||
|
||||
@ -18,7 +18,7 @@ from .typing import Field
|
||||
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Union # pylint: disable=ungrouped-imports
|
||||
from typing import Union
|
||||
|
||||
from .typing import Bitcvt, BitcvtSize
|
||||
|
||||
@ -114,13 +114,13 @@ def generate_ros1_to_cdr(
|
||||
aligned = SIZEMAP[subdesc.args]
|
||||
|
||||
if subdesc.valtype == Valtype.MESSAGE:
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
anext_after = align_after(subdesc)
|
||||
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' ipos, opos = func(input, ipos, output, opos)')
|
||||
aligned = anext_after
|
||||
else:
|
||||
@ -150,30 +150,30 @@ def generate_ros1_to_cdr(
|
||||
lines.append(' opos += length')
|
||||
aligned = 1
|
||||
else:
|
||||
if aligned < (anext := align(subdesc)):
|
||||
if aligned < (anext_before := align(subdesc)):
|
||||
lines.append(' if size:')
|
||||
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
|
||||
if copy:
|
||||
lines.append(' output[opos:opos + length] = input[ipos:ipos + length]')
|
||||
lines.append(' ipos += length')
|
||||
lines.append(' opos += length')
|
||||
aligned = anext
|
||||
aligned = anext_before
|
||||
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
|
||||
lines.append(' for _ in range(size):')
|
||||
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' ipos, opos = func(input, ipos, output, opos)')
|
||||
aligned = align_after(subdesc)
|
||||
|
||||
aligned = min([aligned, 4])
|
||||
|
||||
if fnext and aligned < (anext := align(fnext.descriptor)):
|
||||
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
|
||||
aligned = anext
|
||||
if fnext and aligned < (anext_before := align(fnext.descriptor)):
|
||||
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
|
||||
aligned = anext_before
|
||||
|
||||
lines.append(' return ipos, opos')
|
||||
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||
@ -270,13 +270,13 @@ def generate_cdr_to_ros1(
|
||||
aligned = SIZEMAP[subdesc.args]
|
||||
|
||||
if subdesc.valtype == Valtype.MESSAGE:
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
anext_after = align_after(subdesc)
|
||||
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
|
||||
if anext_before > anext_after:
|
||||
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' ipos, opos = func(input, ipos, output, opos)')
|
||||
aligned = anext_after
|
||||
else:
|
||||
@ -304,30 +304,30 @@ def generate_cdr_to_ros1(
|
||||
lines.append(' opos += length')
|
||||
aligned = 1
|
||||
else:
|
||||
if aligned < (anext := align(subdesc)):
|
||||
if aligned < (anext_before := align(subdesc)):
|
||||
lines.append(' if size:')
|
||||
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
|
||||
if copy:
|
||||
lines.append(' output[opos:opos + length] = input[ipos:ipos + length]')
|
||||
lines.append(' ipos += length')
|
||||
lines.append(' opos += length')
|
||||
aligned = anext
|
||||
aligned = anext_before
|
||||
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
anext = align(subdesc)
|
||||
anext_before = align(subdesc)
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
|
||||
lines.append(' for _ in range(size):')
|
||||
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
|
||||
lines.append(' ipos, opos = func(input, ipos, output, opos)')
|
||||
aligned = align_after(subdesc)
|
||||
|
||||
aligned = min([aligned, 4])
|
||||
|
||||
if fnext and aligned < (anext := align(fnext.descriptor)):
|
||||
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
|
||||
aligned = anext
|
||||
if fnext and aligned < (anext_before := align(fnext.descriptor)):
|
||||
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
|
||||
aligned = anext_before
|
||||
|
||||
lines.append(' return ipos, opos')
|
||||
return getattr(compile_lines(lines), funcname) # type: ignore
|
||||
|
||||
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
def deserialize_cdr(rawdata: bytes, typename: str) -> Any:
|
||||
def deserialize_cdr(rawdata: bytes, typename: str) -> Any: # noqa: ANN401
|
||||
"""Deserialize raw data into a message object.
|
||||
|
||||
Args:
|
||||
@ -35,7 +35,7 @@ def deserialize_cdr(rawdata: bytes, typename: str) -> Any:
|
||||
|
||||
|
||||
def serialize_cdr(
|
||||
message: Any,
|
||||
message: object,
|
||||
typename: str,
|
||||
little_endian: bool = sys.byteorder == 'little',
|
||||
) -> memoryview:
|
||||
|
||||
@ -13,8 +13,8 @@ if TYPE_CHECKING:
|
||||
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
|
||||
|
||||
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
|
||||
CDRSer = Callable[[bytes, int, type], int]
|
||||
CDRSerSize = Callable[[int, type], int]
|
||||
CDRSer = Callable[[bytes, int, object], int]
|
||||
CDRSerSize = Callable[[int, object], int]
|
||||
|
||||
|
||||
class Descriptor(NamedTuple):
|
||||
|
||||
@ -68,5 +68,5 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
|
||||
npos, trees = rule.parse(text, pos)
|
||||
assert npos == len(text), f'Could not parse: {text!r}'
|
||||
return visitor.visit(trees) # type: ignore
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
except Exception as err:
|
||||
raise TypesysError(f'Could not parse: {text!r}') from err
|
||||
|
||||
@ -31,9 +31,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int
|
||||
|
||||
"""
|
||||
if desc[0] == Nodetype.BASE:
|
||||
if match := INTLIKE.match(desc[1]): # type: ignore
|
||||
return match.group(1)
|
||||
return 'str'
|
||||
assert isinstance(desc[1], str)
|
||||
return match.group(1) if (match := INTLIKE.match(desc[1])) else 'str'
|
||||
|
||||
if desc[0] == Nodetype.NAME:
|
||||
assert isinstance(desc[1], str)
|
||||
@ -43,7 +42,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int
|
||||
if INTLIKE.match(sub[1]):
|
||||
typ = 'bool8' if sub[1] == 'bool' else sub[1]
|
||||
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
|
||||
return f'list[{get_typehint(sub)}]' # type: ignore
|
||||
assert isinstance(sub, tuple)
|
||||
return f'list[{get_typehint(sub)}]'
|
||||
|
||||
|
||||
def generate_python_code(typs: Typesdict) -> str:
|
||||
@ -142,6 +142,7 @@ def register_types(typs: Typesdict) -> None:
|
||||
|
||||
Raises:
|
||||
TypesysError: Type already present with different definition.
|
||||
|
||||
"""
|
||||
code = generate_python_code(typs)
|
||||
name = 'rosbags.usertypes'
|
||||
@ -150,7 +151,7 @@ def register_types(typs: Typesdict) -> None:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[name] = module
|
||||
exec(code, module.__dict__) # pylint: disable=exec-used
|
||||
fielddefs: Typesdict = module.FIELDDEFS # type: ignore
|
||||
fielddefs: Typesdict = module.FIELDDEFS
|
||||
|
||||
for name, (_, fields) in fielddefs.items():
|
||||
if name == 'std_msgs/msg/Header':
|
||||
|
||||
@ -117,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
|
||||
|
||||
size = SIZEMAP[desc.args]
|
||||
pos = (pos + size - 1) & -size
|
||||
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
|
||||
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
|
||||
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
|
||||
ndarr = ndarr.byteswap() # no inplace on readonly array
|
||||
return ndarr, pos + num * SIZEMAP[desc.args]
|
||||
@ -297,7 +297,7 @@ def serialize_message(
|
||||
rawdata: memoryview,
|
||||
bmap: BasetypeMap,
|
||||
pos: int,
|
||||
message: Any,
|
||||
message: object,
|
||||
msgdef: Msgdef,
|
||||
) -> int:
|
||||
"""Serialize a message.
|
||||
@ -369,7 +369,7 @@ def get_array_size(desc: Descriptor, val: Array, size: int) -> int:
|
||||
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
|
||||
|
||||
|
||||
def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
|
||||
def get_size(message: object, msgdef: Msgdef, size: int = 0) -> int:
|
||||
"""Calculate size of serialzied message.
|
||||
|
||||
Args:
|
||||
@ -413,7 +413,7 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
|
||||
|
||||
|
||||
def serialize(
|
||||
message: Any,
|
||||
message: object,
|
||||
typename: str,
|
||||
little_endian: bool = sys.byteorder == 'little',
|
||||
) -> memoryview:
|
||||
|
||||
@ -38,6 +38,6 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
|
||||
rconnection, _, raw = next(gen)
|
||||
assert rconnection == wconnection
|
||||
msg = deserialize_cdr(raw, rconnection.msgtype)
|
||||
assert msg.data == Foo.data
|
||||
assert getattr(msg, 'data', None) == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
@ -39,6 +39,6 @@ def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> N
|
||||
gen = rbag.messages()
|
||||
connection, _, raw = next(gen)
|
||||
msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype)
|
||||
assert msg.data == Foo.data
|
||||
assert getattr(msg, 'data', None) == Foo.data
|
||||
with pytest.raises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
@ -13,7 +13,10 @@ import pytest
|
||||
from rosbags.serde import SerdeError, cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr
|
||||
from rosbags.serde.messages import get_msgdef
|
||||
from rosbags.typesys import get_types_from_msg, register_types
|
||||
from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__Header
|
||||
from rosbags.typesys.types import builtin_interfaces__msg__Time as Time
|
||||
from rosbags.typesys.types import geometry_msgs__msg__Polygon as Polygon
|
||||
from rosbags.typesys.types import sensor_msgs__msg__MagneticField as MagneticField
|
||||
from rosbags.typesys.types import std_msgs__msg__Header as Header
|
||||
|
||||
from .cdr import deserialize, serialize
|
||||
|
||||
@ -184,6 +187,7 @@ def _comparable() -> Generator[None, None, None]:
|
||||
|
||||
Notes:
|
||||
This solution is necessary as numpy.ndarray is not directly patchable.
|
||||
|
||||
"""
|
||||
frombuffer = numpy.frombuffer
|
||||
|
||||
@ -195,16 +199,16 @@ def _comparable() -> Generator[None, None, None]:
|
||||
class CNDArray(MagicMock):
|
||||
"""Mock ndarray."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__eq__ = arreq # type: ignore
|
||||
|
||||
def byteswap(self, *args: Any) -> 'CNDArray':
|
||||
def byteswap(self, *args: Any) -> CNDArray: # noqa: ANN401
|
||||
"""Wrap return value also in mock."""
|
||||
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
|
||||
|
||||
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
|
||||
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
|
||||
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: # noqa: ANN401
|
||||
return CNDArray(wraps=frombuffer(*args, **kwargs))
|
||||
|
||||
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
|
||||
yield
|
||||
@ -217,7 +221,7 @@ def test_serde(message: tuple[bytes, str, bool]) -> None:
|
||||
|
||||
serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little)
|
||||
assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little)
|
||||
assert serdeser == rawdata[0:len(serdeser)]
|
||||
assert serdeser == rawdata[:len(serdeser)]
|
||||
assert len(rawdata) - len(serdeser) < 4
|
||||
assert all(x == 0 for x in rawdata[len(serdeser):])
|
||||
|
||||
@ -227,6 +231,7 @@ def test_deserializer() -> None:
|
||||
"""Test deserializer."""
|
||||
msg = deserialize_cdr(*MSG_POLY[:2])
|
||||
assert msg == deserialize(*MSG_POLY[:2])
|
||||
assert isinstance(msg, Polygon)
|
||||
assert len(msg.points) == 2
|
||||
assert msg.points[0].x == 1
|
||||
assert msg.points[0].y == 2
|
||||
@ -237,6 +242,7 @@ def test_deserializer() -> None:
|
||||
|
||||
msg = deserialize_cdr(*MSG_MAGN[:2])
|
||||
assert msg == deserialize(*MSG_MAGN[:2])
|
||||
assert isinstance(msg, MagneticField)
|
||||
assert 'MagneticField' in repr(msg)
|
||||
assert msg.header.stamp.sec == 708
|
||||
assert msg.header.stamp.nanosec == 256
|
||||
@ -248,6 +254,7 @@ def test_deserializer() -> None:
|
||||
|
||||
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
|
||||
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
|
||||
assert isinstance(msg_big, MagneticField)
|
||||
assert msg.magnetic_field == msg_big.magnetic_field
|
||||
|
||||
|
||||
@ -285,7 +292,7 @@ def test_serializer_errors() -> None:
|
||||
class Foo: # pylint: disable=too-few-public-methods
|
||||
"""Dummy class."""
|
||||
|
||||
coef = numpy.array([1, 2, 3, 4])
|
||||
coef: numpy.ndarray[Any, numpy.dtype[numpy.int_]] = numpy.array([1, 2, 3, 4])
|
||||
|
||||
msg = Foo()
|
||||
ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
|
||||
@ -376,7 +383,8 @@ def test_custom_type() -> None:
|
||||
def test_ros1_to_cdr() -> None:
|
||||
"""Test ROS1 to CDR conversion."""
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_ros = (b'\x01\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x01\x00'
|
||||
@ -386,7 +394,8 @@ def test_ros1_to_cdr() -> None:
|
||||
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/static_16_64') == msg_cdr
|
||||
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
|
||||
msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_ros = (b'\x01\x00\x00\x00X'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x02\x00\x00\x00X\x00'
|
||||
@ -399,7 +408,8 @@ def test_ros1_to_cdr() -> None:
|
||||
def test_cdr_to_ros1() -> None:
|
||||
"""Test CDR to ROS1 conversion."""
|
||||
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
|
||||
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_ros = (b'\x01\x00'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x01\x00'
|
||||
@ -409,7 +419,8 @@ def test_cdr_to_ros1() -> None:
|
||||
assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/static_16_64') == msg_ros
|
||||
|
||||
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
|
||||
msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_ros = (b'\x01\x00\x00\x00X'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x02')
|
||||
msg_cdr = (
|
||||
b'\x00\x01\x00\x00'
|
||||
b'\x02\x00\x00\x00X\x00'
|
||||
@ -418,7 +429,7 @@ def test_cdr_to_ros1() -> None:
|
||||
)
|
||||
assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/dynamic_s_64') == msg_ros
|
||||
|
||||
header = std_msgs__msg__Header(stamp=builtin_interfaces__msg__Time(42, 666), frame_id='frame')
|
||||
header = Header(stamp=Time(42, 666), frame_id='frame')
|
||||
msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header')
|
||||
assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame'
|
||||
|
||||
@ -426,7 +437,6 @@ def test_cdr_to_ros1() -> None:
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_padding_empty_sequence() -> None:
|
||||
"""Test empty sequences do not add item padding."""
|
||||
# pylint: disable=protected-access
|
||||
register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b')))
|
||||
|
||||
su64_b = get_msgdef('test_msgs/msg/su64_b').cls
|
||||
@ -446,7 +456,6 @@ def test_padding_empty_sequence() -> None:
|
||||
@pytest.mark.usefixtures('_comparable')
|
||||
def test_align_after_empty_sequence() -> None:
|
||||
"""Test alignment after empty sequences."""
|
||||
# pylint: disable=protected-access
|
||||
register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64')))
|
||||
|
||||
su64_b = get_msgdef('test_msgs/msg/su64_u64').cls
|
||||
|
||||
@ -49,7 +49,7 @@ def test_add_connection(tmp_path: Path) -> None:
|
||||
|
||||
with Writer(path) as writer:
|
||||
res = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
|
||||
assert res.cid == 0
|
||||
assert res.cid == 0
|
||||
data = path.read_bytes()
|
||||
assert data.count(b'MESSAGE_DEFINITION') == 2
|
||||
assert data.count(b'HASH') == 2
|
||||
@ -57,7 +57,7 @@ def test_add_connection(tmp_path: Path) -> None:
|
||||
|
||||
with Writer(path) as writer:
|
||||
res = writer.add_connection('/foo', 'std_msgs/msg/Int8')
|
||||
assert res.cid == 0
|
||||
assert res.cid == 0
|
||||
data = path.read_bytes()
|
||||
assert data.count(b'int8 data') == 2
|
||||
assert data.count(b'27ffa0c9c4b8fb8492252bcad9e5c57b') == 2
|
||||
@ -85,7 +85,7 @@ def test_add_connection(tmp_path: Path) -> None:
|
||||
'HASH',
|
||||
latching=1,
|
||||
)
|
||||
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
|
||||
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
|
||||
|
||||
|
||||
def test_write_errors(tmp_path: Path) -> None:
|
||||
|
||||
@ -21,7 +21,14 @@ from rosbags.rosbag2 import Reader
|
||||
from rosbags.serde import deserialize_cdr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Generator
|
||||
from typing import Generator, Protocol
|
||||
|
||||
class NativeMSG(Protocol): # pylint: disable=too-few-public-methods
|
||||
"""Minimal native ROS message interface used for benchmark."""
|
||||
|
||||
def get_fields_and_field_types(self) -> dict[str, str]:
|
||||
"""Introspect message type."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ReaderPy: # pylint: disable=too-few-public-methods
|
||||
@ -42,13 +49,13 @@ class ReaderPy: # pylint: disable=too-few-public-methods
|
||||
yield topic, self.typemap[topic], timestamp, data
|
||||
|
||||
|
||||
def deserialize_py(data: bytes, msgtype: str) -> Any:
|
||||
def deserialize_py(data: bytes, msgtype: str) -> NativeMSG:
|
||||
"""Deserialization helper for rosidl_runtime_py + rclpy."""
|
||||
pytype = get_message(msgtype)
|
||||
return deserialize_message(data, pytype)
|
||||
return deserialize_message(data, pytype) # type: ignore
|
||||
|
||||
|
||||
def compare_msg(lite: Any, native: Any) -> None:
|
||||
def compare_msg(lite: object, native: NativeMSG) -> None:
|
||||
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
|
||||
|
||||
Args:
|
||||
@ -96,8 +103,8 @@ def compare(path: Path) -> None:
|
||||
msg = deserialize_cdr(data, connection.msgtype)
|
||||
|
||||
compare_msg(msg, msg_py)
|
||||
assert len(list(gens[0])) == 0
|
||||
assert len(list(gens[1])) == 0
|
||||
assert not list(gens[0])
|
||||
assert not list(gens[1])
|
||||
|
||||
|
||||
def read_deser_rosbag2_py(path: Path) -> None:
|
||||
|
||||
@ -14,6 +14,7 @@ from typing import TYPE_CHECKING
|
||||
from unittest.mock import Mock
|
||||
|
||||
import genpy # type: ignore
|
||||
import numpy
|
||||
import rosgraph_msgs.msg # type: ignore
|
||||
from rclpy.serialization import deserialize_message # type: ignore
|
||||
from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions # type: ignore
|
||||
@ -25,9 +26,15 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
|
||||
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Generator, List, Union
|
||||
from typing import Generator, List, Protocol, Union, runtime_checkable
|
||||
|
||||
from rosbag.bag import _Connection_Info
|
||||
@runtime_checkable
|
||||
class NativeMSG(Protocol): # pylint: disable=too-few-public-methods
|
||||
"""Minimal native ROS message interface used for benchmark."""
|
||||
|
||||
def get_fields_and_field_types(self) -> dict[str, str]:
|
||||
"""Introspect message type."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Reader: # pylint: disable=too-few-public-methods
|
||||
@ -47,7 +54,7 @@ class Reader: # pylint: disable=too-few-public-methods
|
||||
yield topic, timestamp, deserialize_message(data, pytype)
|
||||
|
||||
|
||||
def fixup_ros1(conns: List[_Connection_Info]) -> None:
|
||||
def fixup_ros1(conns: List[rosbag.bag._Connection_Info]) -> None:
|
||||
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
|
||||
|
||||
Args:
|
||||
@ -61,7 +68,6 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None:
|
||||
|
||||
if conn := next((x for x in conns if x.datatype == 'sensor_msgs/CameraInfo'), None):
|
||||
print('Patching CameraInfo') # noqa: T001
|
||||
# pylint: disable=assignment-from-no-return,too-many-function-args
|
||||
cls = rosbag.bag._get_message_type(conn) # pylint: disable=protected-access
|
||||
cls.d = property(lambda x: x.D, lambda x, y: setattr(x, 'D', y)) # noqa: B010
|
||||
cls.k = property(lambda x: x.K, lambda x, y: setattr(x, 'K', y)) # noqa: B010
|
||||
@ -69,7 +75,7 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None:
|
||||
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
|
||||
|
||||
|
||||
def compare(ref: Any, msg: Any) -> None:
|
||||
def compare(ref: object, msg: object) -> None:
|
||||
"""Compare message to its reference.
|
||||
|
||||
Args:
|
||||
@ -77,7 +83,7 @@ def compare(ref: Any, msg: Any) -> None:
|
||||
msg: Converted ROS2 message.
|
||||
|
||||
"""
|
||||
if hasattr(msg, 'get_fields_and_field_types'):
|
||||
if isinstance(msg, NativeMSG):
|
||||
for name in msg.get_fields_and_field_types():
|
||||
refval = getattr(ref, name)
|
||||
msgval = getattr(msg, name)
|
||||
@ -87,9 +93,11 @@ def compare(ref: Any, msg: Any) -> None:
|
||||
if isinstance(ref, bytes):
|
||||
assert msg.tobytes() == ref
|
||||
else:
|
||||
assert isinstance(msg, numpy.ndarray)
|
||||
assert (msg == ref).all()
|
||||
|
||||
elif isinstance(msg, list):
|
||||
assert isinstance(ref, (list, numpy.ndarray))
|
||||
assert len(msg) == len(ref)
|
||||
for refitem, msgitem in zip(ref, msg):
|
||||
compare(refitem, msgitem)
|
||||
@ -97,8 +105,9 @@ def compare(ref: Any, msg: Any) -> None:
|
||||
elif isinstance(msg, str):
|
||||
assert msg == ref
|
||||
|
||||
elif isinstance(msg, float) and math.isnan(ref):
|
||||
assert math.isnan(msg)
|
||||
elif isinstance(msg, float) and math.isnan(msg):
|
||||
assert isinstance(ref, float)
|
||||
assert math.isnan(ref)
|
||||
|
||||
else:
|
||||
assert ref == msg
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user