Update lint

This commit is contained in:
Marko Durkovic 2022-04-11 00:07:53 +02:00
parent 7315a4ab4d
commit 19f0678645
22 changed files with 215 additions and 149 deletions

View File

@ -33,7 +33,7 @@ def save_images() -> None:
frame_id=FRAMEID,
),
format='jpeg', # could also be 'png'
data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore
data=numpy.fromfile(path, dtype=numpy.uint8),
)
writer.write(

View File

@ -33,7 +33,7 @@ def save_images() -> None:
frame_id=FRAMEID,
),
format='jpeg', # could also be 'png'
data=numpy.fromfile(path, dtype=numpy.uint8), # type: ignore
data=numpy.fromfile(path, dtype=numpy.uint8),
)
writer.write(

View File

@ -13,7 +13,7 @@ if TYPE_CHECKING:
NATIVE_CLASSES: dict[str, Any] = {}
def to_native(msg: Any) -> Any:
def to_native(msg: Any) -> Any: # noqa: ANN401
"""Convert rosbags message to native message.
Args:

View File

@ -171,5 +171,5 @@ def convert(src: Path, dst: Optional[Path]) -> None:
raise ConverterError(f'Reading source bag: {err}') from err
except (WriterError1, WriterError2) as err:
raise ConverterError(f'Writing destination bag: {err}') from err
except Exception as err: # pylint: disable=broad-except
except Exception as err:
raise ConverterError(f'Converting rosbag: {err!r}') from err

View File

@ -106,9 +106,9 @@ class IndexData(NamedTuple):
def __eq__(self, other: object) -> bool:
"""Compare by time only."""
if not isinstance(other, IndexData): # pragma: no cover
return NotImplemented
return self.time == other[0]
if isinstance(other, IndexData):
return self.time == other[0]
return NotImplemented # pragma: no cover
def __ge__(self, other: tuple[int, ...]) -> bool:
"""Compare by time only."""
@ -120,9 +120,9 @@ class IndexData(NamedTuple):
def __ne__(self, other: object) -> bool:
"""Compare by time only."""
if not isinstance(other, IndexData): # pragma: no cover
return NotImplemented
return self.time != other[0]
if isinstance(other, IndexData):
return self.time != other[0]
return NotImplemented # pragma: no cover
deserialize_uint8: Callable[[bytes], tuple[int]] = struct.Struct('<B').unpack # type: ignore
@ -369,7 +369,7 @@ class Reader:
self.current_chunk: tuple[int, BinaryIO] = (-1, BytesIO())
self.topics: dict[str, TopicInfo] = {}
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals,too-many-statements
def open(self) -> None: # pylint: disable=too-many-branches,too-many-locals
"""Open rosbag and read metadata."""
try:
self.bio = self.path.open('rb')
@ -394,13 +394,11 @@ class Reader:
conn_count = header.get_uint32('conn_count')
chunk_count = header.get_uint32('chunk_count')
try:
encryptor = header.get_string('encryptor')
if encryptor:
raise ValueError
except ValueError:
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
encryptor: Optional[str] = header.get_string('encryptor')
except ReaderError:
pass
encryptor = None
if encryptor:
raise ReaderError(f'Bag encryption {encryptor!r} is not supported.') from None
if index_pos == 0:
raise ReaderError('Bag is not indexed, reindex before reading.')

View File

@ -31,6 +31,7 @@ class WriterError(Exception):
@dataclass
class WriteChunk:
"""In progress chunk."""
data: BytesIO
pos: int
start: int
@ -126,7 +127,7 @@ class Header(Dict[str, Any]):
return size + 4
class Writer: # pylint: disable=too-many-instance-attributes
class Writer:
"""Rosbag1 writer.
This class implements writing of rosbag1 files in version 2.0. It should be
@ -212,7 +213,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
md5sum: Optional[str] = None,
callerid: Optional[str] = None,
latching: Optional[int] = None,
**_kw: Any,
**_kw: Any, # noqa: ANN401
) -> Connection:
"""Add a connection.

View File

@ -18,7 +18,44 @@ from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
from typing import Any, Generator, Iterable, Literal, Optional, Type, TypedDict, Union
class StartingTime(TypedDict):
"""Bag starting time."""
nanoseconds_since_epoch: int
class Duration(TypedDict):
"""Bag starting time."""
nanoseconds: int
class TopicMetadata(TypedDict):
"""Topic metadata."""
name: str
type: str
serialization_format: str
offered_qos_profiles: str
class TopicWithMessageCount(TypedDict):
"""Topic with message count."""
message_count: int
topic_metadata: TopicMetadata
class Metadata(TypedDict):
"""Rosbag2 metadata file."""
version: int
storage_identifier: str
relative_file_paths: list[str]
starting_time: StartingTime
duration: Duration
message_count: int
compression_format: str
compression_mode: str
topics_with_message_count: list[TopicWithMessageCount]
class ReaderError(Exception):
@ -72,13 +109,14 @@ class Reader:
Raises:
ReaderError: Bag not readable or bag metadata.
"""
path = Path(path)
self.path = Path
yamlpath = path / 'metadata.yaml'
self.path = path
self.bio = False
try:
yaml = YAML(typ='safe')
yamlpath = path / 'metadata.yaml'
dct = yaml.load(yamlpath.read_text())
except OSError as err:
raise ReaderError(f'Could not read metadata at {yamlpath}: {err}.') from None
@ -86,7 +124,7 @@ class Reader:
raise ReaderError(f'Could not load YAML from {yamlpath}: {exc}') from None
try:
self.metadata = dct['rosbag2_bagfile_information']
self.metadata: Metadata = dct['rosbag2_bagfile_information']
if (ver := self.metadata['version']) > 4:
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
if storageid := self.metadata['storage_identifier'] != 'sqlite3':
@ -95,8 +133,7 @@ class Reader:
)
self.paths = [path / Path(x).name for x in self.metadata['relative_file_paths']]
missing = [x for x in self.paths if not x.exists()]
if missing:
if missing := [x for x in self.paths if not x.exists()]:
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
self.connections = {
@ -110,7 +147,7 @@ class Reader:
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
}
noncdr = {
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr'
fmt for x in self.connections.values() if (fmt := x.serialization_format) != 'cdr'
}
if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
@ -140,8 +177,7 @@ class Reader:
@property
def start_time(self) -> int:
"""Timestamp in nanoseconds of the earliest message."""
nsecs: int = self.metadata['starting_time']['nanoseconds_since_epoch']
return nsecs
return self.metadata['starting_time']['nanoseconds_since_epoch']
@property
def end_time(self) -> int:
@ -151,8 +187,7 @@ class Reader:
@property
def message_count(self) -> int:
"""Total message count."""
count: int = self.metadata['message_count']
return count
return self.metadata['message_count']
@property
def compression_format(self) -> Optional[str]:

View File

@ -18,6 +18,8 @@ if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Literal, Optional, Type, Union
from .reader import Metadata
class WriterError(Exception):
"""Writer Error."""
@ -125,7 +127,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
msgtype: str,
serialization_format: str = 'cdr',
offered_qos_profiles: str = '',
**_kw: Any,
**_kw: Any, # noqa: ANN401
) -> Connection:
"""Add a connection.
@ -218,7 +220,7 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compressor.copy_stream(infile, outfile)
src.unlink()
metadata = {
metadata: dict[str, Metadata] = {
'rosbag2_bagfile_information': {
'version': 4,
'storage_identifier': 'sqlite3',

View File

@ -86,22 +86,22 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
else:
assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc)
anext_before = align(subdesc)
anext_after = align_after(subdesc)
if subdesc.args.size_cdr:
for _ in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
size = (size + anext - 1) & -anext
if anext_before > anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
size = (size + anext_before - 1) & -anext_before
lines.append(f' pos += {subdesc.args.size_cdr}')
size += subdesc.args.size_cdr
else:
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
lines.append(f' val = message.{fieldname}')
for idx in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' pos = func(pos, val[{idx}])')
is_stat = False
aligned = align_after(subdesc)
@ -117,45 +117,45 @@ def generate_getsize_cdr(fields: list[Field]) -> tuple[CDRSerSize, int]:
lines.append(' pos += 4 + len(val.encode()) + 1')
aligned = 1
else:
anext = align(subdesc)
if aligned < anext:
anext_before = align(subdesc)
if aligned < anext_before:
lines.append(f' if len(message.{fieldname}):')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
aligned = anext
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext_before
lines.append(f' pos += len(message.{fieldname}) * {SIZEMAP[subdesc.args]}')
else:
assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc)
anext_before = align(subdesc)
anext_after = align_after(subdesc)
lines.append(f' val = message.{fieldname}')
if subdesc.args.size_cdr:
if aligned < anext <= anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if aligned < anext_before <= anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' for _ in val:')
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' pos += {subdesc.args.size_cdr}')
else:
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
if aligned < anext <= anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if aligned < anext_before <= anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' for item in val:')
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' pos = func(pos, item)')
aligned = align_after(subdesc)
aligned = min([aligned, 4])
is_stat = False
if fnext and aligned < (anext := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
aligned = anext
if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext_before
is_stat = False
lines.append(' return pos')
return compile_lines(lines).getsize_cdr, is_stat * size # type: ignore
return compile_lines(lines).getsize_cdr, is_stat * size
def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
@ -240,14 +240,14 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
else:
assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc)
anext_before = align(subdesc)
anext_after = align_after(subdesc)
lines.append(
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
)
for idx in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' pos = func(rawdata, pos, val[{idx}])')
aligned = align_after(subdesc)
else:
@ -272,28 +272,28 @@ def generate_serialize_cdr(fields: list[Field], endianess: str) -> CDRSer:
lines.append(f' size = len(val) * {SIZEMAP[subdesc.args]}')
if (endianess == 'le') != (sys.byteorder == 'little'):
lines.append(' val = val.byteswap()')
if aligned < (anext := align(subdesc)):
if aligned < (anext_before := align(subdesc)):
lines.append(' if size:')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' rawdata[pos:pos + size] = val.view(numpy.uint8)')
lines.append(' pos += size')
aligned = anext
aligned = anext_before
if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc)
anext_before = align(subdesc)
lines.append(
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
)
lines.append(' for item in val:')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(' pos = func(rawdata, pos, item)')
aligned = align_after(subdesc)
aligned = min([4, aligned])
if fnext and aligned < (anext := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
aligned = anext
if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext_before
lines.append(' return pos')
return compile_lines(lines).serialize_cdr # type: ignore
@ -384,13 +384,13 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
lines.append(f' pos += {size}')
else:
assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc)
anext_before = align(subdesc)
anext_after = align_after(subdesc)
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
lines.append(' value = []')
for _ in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
lines.append(' value.append(obj)')
lines.append(' values.append(value)')
@ -418,9 +418,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
aligned = 1
else:
lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
if aligned < (anext := align(subdesc)):
if aligned < (anext_before := align(subdesc)):
lines.append(' if size:')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(
f' val = numpy.frombuffer(rawdata, '
f'dtype=numpy.{subdesc.args}, count=size, offset=pos)',
@ -429,14 +429,14 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
lines.append(' val = val.byteswap()')
lines.append(' values.append(val)')
lines.append(' pos += length')
aligned = anext
aligned = anext_before
if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc)
anext_before = align(subdesc)
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
lines.append(' value = []')
lines.append(' for _ in range(size):')
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
lines.append(' value.append(obj)')
lines.append(' values.append(value)')
@ -444,9 +444,9 @@ def generate_deserialize_cdr(fields: list[Field], endianess: str) -> CDRDeser:
aligned = min([4, aligned])
if fnext and aligned < (anext := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
aligned = anext
if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' pos = (pos + {anext_before} - 1) & -{anext_before}')
aligned = anext_before
lines.append(' return cls(*values), pos')
return compile_lines(lines).deserialize_cdr # type: ignore

View File

@ -14,7 +14,7 @@ from .typing import Descriptor, Field, Msgdef
from .utils import Valtype
if TYPE_CHECKING:
from typing import Any
from rosbags.typesys.base import Fielddesc
MSGDEFCACHE: dict[str, Msgdef] = {}
@ -38,14 +38,18 @@ def get_msgdef(typename: str) -> Msgdef:
if typename not in MSGDEFCACHE:
entries = types.FIELDDEFS[typename][1]
def fixup(entry: Any) -> Descriptor:
def fixup(entry: Fielddesc) -> Descriptor:
if entry[0] == Valtype.BASE:
assert isinstance(entry[1], str)
return Descriptor(Valtype.BASE, entry[1])
if entry[0] == Valtype.MESSAGE:
assert isinstance(entry[1], str)
return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1]))
if entry[0] == Valtype.ARRAY:
assert not isinstance(entry[1][0], str)
return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1]))
if entry[0] == Valtype.SEQUENCE:
assert not isinstance(entry[1][0], str)
return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1]))
raise SerdeError( # pragma: no cover
f'Unknown field type {entry[0]!r} encountered.',

View File

@ -18,7 +18,7 @@ from .typing import Field
from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
if TYPE_CHECKING:
from typing import Union # pylint: disable=ungrouped-imports
from typing import Union
from .typing import Bitcvt, BitcvtSize
@ -114,13 +114,13 @@ def generate_ros1_to_cdr(
aligned = SIZEMAP[subdesc.args]
if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc)
anext_before = align(subdesc)
anext_after = align_after(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
for _ in range(length):
if anext > anext_after:
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = anext_after
else:
@ -150,30 +150,30 @@ def generate_ros1_to_cdr(
lines.append(' opos += length')
aligned = 1
else:
if aligned < (anext := align(subdesc)):
if aligned < (anext_before := align(subdesc)):
lines.append(' if size:')
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
if copy:
lines.append(' output[opos:opos + length] = input[ipos:ipos + length]')
lines.append(' ipos += length')
lines.append(' opos += length')
aligned = anext
aligned = anext_before
else:
assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc)
anext_before = align(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
lines.append(' for _ in range(size):')
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = align_after(subdesc)
aligned = min([aligned, 4])
if fnext and aligned < (anext := align(fnext.descriptor)):
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
aligned = anext
if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' opos = (opos + {anext_before} - 1) & -{anext_before}')
aligned = anext_before
lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname) # type: ignore
@ -270,13 +270,13 @@ def generate_cdr_to_ros1(
aligned = SIZEMAP[subdesc.args]
if subdesc.valtype == Valtype.MESSAGE:
anext = align(subdesc)
anext_before = align(subdesc)
anext_after = align_after(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
for _ in range(length):
if anext > anext_after:
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
if anext_before > anext_after:
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = anext_after
else:
@ -304,30 +304,30 @@ def generate_cdr_to_ros1(
lines.append(' opos += length')
aligned = 1
else:
if aligned < (anext := align(subdesc)):
if aligned < (anext_before := align(subdesc)):
lines.append(' if size:')
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
lines.append(f' length = size * {SIZEMAP[subdesc.args]}')
if copy:
lines.append(' output[opos:opos + length] = input[ipos:ipos + length]')
lines.append(' ipos += length')
lines.append(' opos += length')
aligned = anext
aligned = anext_before
else:
assert subdesc.valtype == Valtype.MESSAGE
anext = align(subdesc)
anext_before = align(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
lines.append(' for _ in range(size):')
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
lines.append(' ipos, opos = func(input, ipos, output, opos)')
aligned = align_after(subdesc)
aligned = min([aligned, 4])
if fnext and aligned < (anext := align(fnext.descriptor)):
lines.append(f' ipos = (ipos + {anext} - 1) & -{anext}')
aligned = anext
if fnext and aligned < (anext_before := align(fnext.descriptor)):
lines.append(f' ipos = (ipos + {anext_before} - 1) & -{anext_before}')
aligned = anext_before
lines.append(' return ipos, opos')
return getattr(compile_lines(lines), funcname) # type: ignore

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from typing import Any
def deserialize_cdr(rawdata: bytes, typename: str) -> Any:
def deserialize_cdr(rawdata: bytes, typename: str) -> Any: # noqa: ANN401
"""Deserialize raw data into a message object.
Args:
@ -35,7 +35,7 @@ def deserialize_cdr(rawdata: bytes, typename: str) -> Any:
def serialize_cdr(
message: Any,
message: object,
typename: str,
little_endian: bool = sys.byteorder == 'little',
) -> memoryview:

View File

@ -13,8 +13,8 @@ if TYPE_CHECKING:
BitcvtSize = Callable[[bytes, int, None, int], Tuple[int, int]]
CDRDeser = Callable[[bytes, int, type], Tuple[Any, int]]
CDRSer = Callable[[bytes, int, type], int]
CDRSerSize = Callable[[int, type], int]
CDRSer = Callable[[bytes, int, object], int]
CDRSerSize = Callable[[int, object], int]
class Descriptor(NamedTuple):

View File

@ -68,5 +68,5 @@ def parse_message_definition(visitor: Visitor, text: str) -> Typesdict:
npos, trees = rule.parse(text, pos)
assert npos == len(text), f'Could not parse: {text!r}'
return visitor.visit(trees) # type: ignore
except Exception as err: # pylint: disable=broad-except
except Exception as err:
raise TypesysError(f'Could not parse: {text!r}') from err

View File

@ -31,9 +31,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int
"""
if desc[0] == Nodetype.BASE:
if match := INTLIKE.match(desc[1]): # type: ignore
return match.group(1)
return 'str'
assert isinstance(desc[1], str)
return match.group(1) if (match := INTLIKE.match(desc[1])) else 'str'
if desc[0] == Nodetype.NAME:
assert isinstance(desc[1], str)
@ -43,7 +42,8 @@ def get_typehint(desc: tuple[int, Union[str, tuple[tuple[int, str], Optional[int
if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
return f'list[{get_typehint(sub)}]' # type: ignore
assert isinstance(sub, tuple)
return f'list[{get_typehint(sub)}]'
def generate_python_code(typs: Typesdict) -> str:
@ -142,6 +142,7 @@ def register_types(typs: Typesdict) -> None:
Raises:
TypesysError: Type already present with different definition.
"""
code = generate_python_code(typs)
name = 'rosbags.usertypes'
@ -150,7 +151,7 @@ def register_types(typs: Typesdict) -> None:
module = module_from_spec(spec)
sys.modules[name] = module
exec(code, module.__dict__) # pylint: disable=exec-used
fielddefs: Typesdict = module.FIELDDEFS # type: ignore
fielddefs: Typesdict = module.FIELDDEFS
for name, (_, fields) in fielddefs.items():
if name == 'std_msgs/msg/Header':

View File

@ -117,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
ndarr = ndarr.byteswap() # no inplace on readonly array
return ndarr, pos + num * SIZEMAP[desc.args]
@ -297,7 +297,7 @@ def serialize_message(
rawdata: memoryview,
bmap: BasetypeMap,
pos: int,
message: Any,
message: object,
msgdef: Msgdef,
) -> int:
"""Serialize a message.
@ -369,7 +369,7 @@ def get_array_size(desc: Descriptor, val: Array, size: int) -> int:
raise SerdeError(f'Nested arrays {desc!r} are not supported.') # pragma: no cover
def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
def get_size(message: object, msgdef: Msgdef, size: int = 0) -> int:
"""Calculate size of serialzied message.
Args:
@ -413,7 +413,7 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
def serialize(
message: Any,
message: object,
typename: str,
little_endian: bool = sys.byteorder == 'little',
) -> memoryview:

View File

@ -38,6 +38,6 @@ def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
rconnection, _, raw = next(gen)
assert rconnection == wconnection
msg = deserialize_cdr(raw, rconnection.msgtype)
assert msg.data == Foo.data
assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration):
next(gen)

View File

@ -39,6 +39,6 @@ def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> N
gen = rbag.messages()
connection, _, raw = next(gen)
msg = deserialize_cdr(ros1_to_cdr(raw, connection.msgtype), connection.msgtype)
assert msg.data == Foo.data
assert getattr(msg, 'data', None) == Foo.data
with pytest.raises(StopIteration):
next(gen)

View File

@ -13,7 +13,10 @@ import pytest
from rosbags.serde import SerdeError, cdr_to_ros1, deserialize_cdr, ros1_to_cdr, serialize_cdr
from rosbags.serde.messages import get_msgdef
from rosbags.typesys import get_types_from_msg, register_types
from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__Header
from rosbags.typesys.types import builtin_interfaces__msg__Time as Time
from rosbags.typesys.types import geometry_msgs__msg__Polygon as Polygon
from rosbags.typesys.types import sensor_msgs__msg__MagneticField as MagneticField
from rosbags.typesys.types import std_msgs__msg__Header as Header
from .cdr import deserialize, serialize
@ -184,6 +187,7 @@ def _comparable() -> Generator[None, None, None]:
Notes:
This solution is necessary as numpy.ndarray is not directly patchable.
"""
frombuffer = numpy.frombuffer
@ -195,16 +199,16 @@ def _comparable() -> Generator[None, None, None]:
class CNDArray(MagicMock):
"""Mock ndarray."""
def __init__(self, *args: Any, **kwargs: Any):
def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401
super().__init__(*args, **kwargs)
self.__eq__ = arreq # type: ignore
def byteswap(self, *args: Any) -> 'CNDArray':
def byteswap(self, *args: Any) -> CNDArray: # noqa: ANN401
"""Wrap return value also in mock."""
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray: # noqa: ANN401
return CNDArray(wraps=frombuffer(*args, **kwargs))
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
yield
@ -217,7 +221,7 @@ def test_serde(message: tuple[bytes, str, bool]) -> None:
serdeser = serialize_cdr(deserialize_cdr(rawdata, typ), typ, is_little)
assert serdeser == serialize(deserialize(rawdata, typ), typ, is_little)
assert serdeser == rawdata[0:len(serdeser)]
assert serdeser == rawdata[:len(serdeser)]
assert len(rawdata) - len(serdeser) < 4
assert all(x == 0 for x in rawdata[len(serdeser):])
@ -227,6 +231,7 @@ def test_deserializer() -> None:
"""Test deserializer."""
msg = deserialize_cdr(*MSG_POLY[:2])
assert msg == deserialize(*MSG_POLY[:2])
assert isinstance(msg, Polygon)
assert len(msg.points) == 2
assert msg.points[0].x == 1
assert msg.points[0].y == 2
@ -237,6 +242,7 @@ def test_deserializer() -> None:
msg = deserialize_cdr(*MSG_MAGN[:2])
assert msg == deserialize(*MSG_MAGN[:2])
assert isinstance(msg, MagneticField)
assert 'MagneticField' in repr(msg)
assert msg.header.stamp.sec == 708
assert msg.header.stamp.nanosec == 256
@ -248,6 +254,7 @@ def test_deserializer() -> None:
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
assert isinstance(msg_big, MagneticField)
assert msg.magnetic_field == msg_big.magnetic_field
@ -285,7 +292,7 @@ def test_serializer_errors() -> None:
class Foo: # pylint: disable=too-few-public-methods
"""Dummy class."""
coef = numpy.array([1, 2, 3, 4])
coef: numpy.ndarray[Any, numpy.dtype[numpy.int_]] = numpy.array([1, 2, 3, 4])
msg = Foo()
ret = serialize_cdr(msg, 'shape_msgs/msg/Plane', True)
@ -376,7 +383,8 @@ def test_custom_type() -> None:
def test_ros1_to_cdr() -> None:
"""Test ROS1 to CDR conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_ros = (b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x01\x00'
@ -386,7 +394,8 @@ def test_ros1_to_cdr() -> None:
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/static_16_64') == msg_cdr
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_ros = (b'\x01\x00\x00\x00X'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x02\x00\x00\x00X\x00'
@ -399,7 +408,8 @@ def test_ros1_to_cdr() -> None:
def test_cdr_to_ros1() -> None:
"""Test CDR to ROS1 conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_ros = (b'\x01\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x01\x00'
@ -409,7 +419,8 @@ def test_cdr_to_ros1() -> None:
assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/static_16_64') == msg_ros
register_types(dict(get_types_from_msg(DYNAMIC_S_64, 'test_msgs/msg/dynamic_s_64')))
msg_ros = (b'\x01\x00\x00\x00X' b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_ros = (b'\x01\x00\x00\x00X'
b'\x00\x00\x00\x00\x00\x00\x00\x02')
msg_cdr = (
b'\x00\x01\x00\x00'
b'\x02\x00\x00\x00X\x00'
@ -418,7 +429,7 @@ def test_cdr_to_ros1() -> None:
)
assert cdr_to_ros1(msg_cdr, 'test_msgs/msg/dynamic_s_64') == msg_ros
header = std_msgs__msg__Header(stamp=builtin_interfaces__msg__Time(42, 666), frame_id='frame')
header = Header(stamp=Time(42, 666), frame_id='frame')
msg_ros = cdr_to_ros1(serialize_cdr(header, 'std_msgs/msg/Header'), 'std_msgs/msg/Header')
assert msg_ros == b'\x00\x00\x00\x00*\x00\x00\x00\x9a\x02\x00\x00\x05\x00\x00\x00frame'
@ -426,7 +437,6 @@ def test_cdr_to_ros1() -> None:
@pytest.mark.usefixtures('_comparable')
def test_padding_empty_sequence() -> None:
"""Test empty sequences do not add item padding."""
# pylint: disable=protected-access
register_types(dict(get_types_from_msg(SU64_B, 'test_msgs/msg/su64_b')))
su64_b = get_msgdef('test_msgs/msg/su64_b').cls
@ -446,7 +456,6 @@ def test_padding_empty_sequence() -> None:
@pytest.mark.usefixtures('_comparable')
def test_align_after_empty_sequence() -> None:
"""Test alignment after empty sequences."""
# pylint: disable=protected-access
register_types(dict(get_types_from_msg(SU64_U64, 'test_msgs/msg/su64_u64')))
su64_b = get_msgdef('test_msgs/msg/su64_u64').cls

View File

@ -49,7 +49,7 @@ def test_add_connection(tmp_path: Path) -> None:
with Writer(path) as writer:
res = writer.add_connection('/foo', 'test_msgs/msg/Test', 'MESSAGE_DEFINITION', 'HASH')
assert res.cid == 0
assert res.cid == 0
data = path.read_bytes()
assert data.count(b'MESSAGE_DEFINITION') == 2
assert data.count(b'HASH') == 2
@ -57,7 +57,7 @@ def test_add_connection(tmp_path: Path) -> None:
with Writer(path) as writer:
res = writer.add_connection('/foo', 'std_msgs/msg/Int8')
assert res.cid == 0
assert res.cid == 0
data = path.read_bytes()
assert data.count(b'int8 data') == 2
assert data.count(b'27ffa0c9c4b8fb8492252bcad9e5c57b') == 2
@ -85,7 +85,7 @@ def test_add_connection(tmp_path: Path) -> None:
'HASH',
latching=1,
)
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
def test_write_errors(tmp_path: Path) -> None:

View File

@ -21,7 +21,14 @@ from rosbags.rosbag2 import Reader
from rosbags.serde import deserialize_cdr
if TYPE_CHECKING:
from typing import Any, Generator
from typing import Generator, Protocol
class NativeMSG(Protocol): # pylint: disable=too-few-public-methods
"""Minimal native ROS message interface used for benchmark."""
def get_fields_and_field_types(self) -> dict[str, str]:
"""Introspect message type."""
raise NotImplementedError
class ReaderPy: # pylint: disable=too-few-public-methods
@ -42,13 +49,13 @@ class ReaderPy: # pylint: disable=too-few-public-methods
yield topic, self.typemap[topic], timestamp, data
def deserialize_py(data: bytes, msgtype: str) -> Any:
def deserialize_py(data: bytes, msgtype: str) -> NativeMSG:
"""Deserialization helper for rosidl_runtime_py + rclpy."""
pytype = get_message(msgtype)
return deserialize_message(data, pytype)
return deserialize_message(data, pytype) # type: ignore
def compare_msg(lite: Any, native: Any) -> None:
def compare_msg(lite: object, native: NativeMSG) -> None:
"""Compare rosbag2 (lite) vs rosbag2_py (native) message content.
Args:
@ -96,8 +103,8 @@ def compare(path: Path) -> None:
msg = deserialize_cdr(data, connection.msgtype)
compare_msg(msg, msg_py)
assert len(list(gens[0])) == 0
assert len(list(gens[1])) == 0
assert not list(gens[0])
assert not list(gens[1])
def read_deser_rosbag2_py(path: Path) -> None:

View File

@ -14,6 +14,7 @@ from typing import TYPE_CHECKING
from unittest.mock import Mock
import genpy # type: ignore
import numpy
import rosgraph_msgs.msg # type: ignore
from rclpy.serialization import deserialize_message # type: ignore
from rosbag2_py import ConverterOptions, SequentialReader, StorageOptions # type: ignore
@ -25,9 +26,15 @@ rosgraph_msgs.msg.TopicStatistics = Mock()
import rosbag.bag # type:ignore # noqa: E402 pylint: disable=wrong-import-position
if TYPE_CHECKING:
from typing import Any, Generator, List, Union
from typing import Generator, List, Protocol, Union, runtime_checkable
from rosbag.bag import _Connection_Info
@runtime_checkable
class NativeMSG(Protocol): # pylint: disable=too-few-public-methods
"""Minimal native ROS message interface used for benchmark."""
def get_fields_and_field_types(self) -> dict[str, str]:
"""Introspect message type."""
raise NotImplementedError
class Reader: # pylint: disable=too-few-public-methods
@ -47,7 +54,7 @@ class Reader: # pylint: disable=too-few-public-methods
yield topic, timestamp, deserialize_message(data, pytype)
def fixup_ros1(conns: List[_Connection_Info]) -> None:
def fixup_ros1(conns: List[rosbag.bag._Connection_Info]) -> None:
"""Monkeypatch ROS2 fieldnames onto ROS1 objects.
Args:
@ -61,7 +68,6 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None:
if conn := next((x for x in conns if x.datatype == 'sensor_msgs/CameraInfo'), None):
print('Patching CameraInfo') # noqa: T001
# pylint: disable=assignment-from-no-return,too-many-function-args
cls = rosbag.bag._get_message_type(conn) # pylint: disable=protected-access
cls.d = property(lambda x: x.D, lambda x, y: setattr(x, 'D', y)) # noqa: B010
cls.k = property(lambda x: x.K, lambda x, y: setattr(x, 'K', y)) # noqa: B010
@ -69,7 +75,7 @@ def fixup_ros1(conns: List[_Connection_Info]) -> None:
cls.p = property(lambda x: x.P, lambda x, y: setattr(x, 'P', y)) # noqa: B010
def compare(ref: Any, msg: Any) -> None:
def compare(ref: object, msg: object) -> None:
"""Compare message to its reference.
Args:
@ -77,7 +83,7 @@ def compare(ref: Any, msg: Any) -> None:
msg: Converted ROS2 message.
"""
if hasattr(msg, 'get_fields_and_field_types'):
if isinstance(msg, NativeMSG):
for name in msg.get_fields_and_field_types():
refval = getattr(ref, name)
msgval = getattr(msg, name)
@ -87,9 +93,11 @@ def compare(ref: Any, msg: Any) -> None:
if isinstance(ref, bytes):
assert msg.tobytes() == ref
else:
assert isinstance(msg, numpy.ndarray)
assert (msg == ref).all()
elif isinstance(msg, list):
assert isinstance(ref, (list, numpy.ndarray))
assert len(msg) == len(ref)
for refitem, msgitem in zip(ref, msg):
compare(refitem, msgitem)
@ -97,8 +105,9 @@ def compare(ref: Any, msg: Any) -> None:
elif isinstance(msg, str):
assert msg == ref
elif isinstance(msg, float) and math.isnan(ref):
assert math.isnan(msg)
elif isinstance(msg, float) and math.isnan(msg):
assert isinstance(ref, float)
assert math.isnan(ref)
else:
assert ref == msg