Type generics and missing return types

This commit is contained in:
Marko Durkovic
2021-11-25 14:26:17 +01:00
parent ac704bd890
commit 52480e2bad
26 changed files with 263 additions and 175 deletions
+3 -2
View File
@@ -9,6 +9,7 @@ from struct import Struct, pack_into, unpack_from
from typing import TYPE_CHECKING, Dict, List, Union, cast
import numpy
from numpy.typing import NDArray
from rosbags.serde.messages import SerdeError, get_msgdef
from rosbags.serde.typing import Msgdef
@@ -116,7 +117,7 @@ def deserialize_array(rawdata: bytes, bmap: BasetypeMap, pos: int, num: int, des
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos)
ndarr = numpy.frombuffer(rawdata, dtype=desc.args, count=num, offset=pos) # type: ignore
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
ndarr = ndarr.byteswap() # no inplace on readonly array
return ndarr, pos + num * SIZEMAP[desc.args]
@@ -278,7 +279,7 @@ def serialize_array(
size = SIZEMAP[desc.args]
pos = (pos + size - 1) & -size
size *= len(val)
val = cast(numpy.ndarray, val)
val = cast(NDArray[numpy.int_], val)
if (bmap is BASETYPEMAP_LE) != (sys.byteorder == 'little'):
val = val.byteswap() # no inplace on readonly array
rawdata[pos:pos + size] = memoryview(val.tobytes())
+2 -2
View File
@@ -15,7 +15,7 @@ from rosbags.rosbag1 import ReaderError
from rosbags.rosbag2 import WriterError
def test_cliwrapper(tmp_path: Path):
def test_cliwrapper(tmp_path: Path) -> None:
"""Test cli wrapper."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'ros1.bag').write_text('')
@@ -62,7 +62,7 @@ def test_cliwrapper(tmp_path: Path):
mock_print.assert_called_with('ERROR: exc')
def test_convert(tmp_path: Path):
def test_convert(tmp_path: Path) -> None:
"""Test conversion function."""
(tmp_path / 'subdir').mkdir()
(tmp_path / 'foo.bag').write_text('')
+13 -13
View File
@@ -142,13 +142,13 @@ module test_msgs {
"""
def test_parse_empty_msg():
def test_parse_empty_msg() -> None:
"""Test msg parser with empty message."""
ret = get_types_from_msg('', 'std_msgs/msg/Empty')
assert ret == {'std_msgs/msg/Empty': ([], [])}
def test_parse_bounds_msg():
def test_parse_bounds_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_BOUNDS, 'test_msgs/msg/Foo')
assert ret == {
@@ -168,7 +168,7 @@ def test_parse_bounds_msg():
}
def test_parse_defaults_msg():
def test_parse_defaults_msg() -> None:
"""Test msg parser."""
ret = get_types_from_msg(MSG_DEFAULTS, 'test_msgs/msg/Foo')
assert ret == {
@@ -188,7 +188,7 @@ def test_parse_defaults_msg():
}
def test_parse_msg():
def test_parse_msg() -> None:
"""Test msg parser."""
with pytest.raises(TypesysError, match='Could not parse'):
get_types_from_msg('invalid', 'test_msgs/msg/Foo')
@@ -208,7 +208,7 @@ def test_parse_msg():
assert fields[6][1][0] == Nodetype.ARRAY
def test_parse_multi_msg():
def test_parse_multi_msg() -> None:
"""Test multi msg parser."""
ret = get_types_from_msg(MULTI_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 3
@@ -223,7 +223,7 @@ def test_parse_multi_msg():
assert consts == [('static', 'uint32', 42)]
def test_parse_cstring_confusion():
def test_parse_cstring_confusion() -> None:
"""Test if msg separator is confused with const string."""
ret = get_types_from_msg(CSTRING_CONFUSION_MSG, 'test_msgs/msg/Foo')
assert len(ret) == 2
@@ -235,7 +235,7 @@ def test_parse_cstring_confusion():
assert fields[1][1][1] == 'string'
def test_parse_relative_siblings_msg():
def test_parse_relative_siblings_msg() -> None:
"""Test relative siblings with msg parser."""
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
@@ -246,7 +246,7 @@ def test_parse_relative_siblings_msg():
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
def test_parse_idl():
def test_parse_idl() -> None:
"""Test idl parser."""
ret = get_types_from_idl(IDL_LANG)
assert ret == {}
@@ -267,21 +267,21 @@ def test_parse_idl():
assert fields[6][1][0] == Nodetype.ARRAY
def test_register_types():
def test_register_types() -> None:
"""Test type registeration."""
assert 'foo' not in FIELDDEFS
register_types({})
register_types({'foo': [[], [('b', (1, 'bool'))]]})
register_types({'foo': [[], [('b', (1, 'bool'))]]}) # type: ignore
assert 'foo' in FIELDDEFS
register_types({'std_msgs/msg/Header': [[], []]})
register_types({'std_msgs/msg/Header': [[], []]}) # type: ignore
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
with pytest.raises(TypesysError, match='different definition'):
register_types({'foo': [[], [('x', (1, 'bool'))]]})
register_types({'foo': [[], [('x', (1, 'bool'))]]}) # type: ignore
def test_generate_msgdef():
def test_generate_msgdef() -> None:
"""Test message definition generator."""
res = generate_msgdef('std_msgs/msg/Header')
assert res == ('uint32 seq\ntime stamp\nstring frame_id\n', '2176decaecbce78abc3b96ef049fabed')
+4 -4
View File
@@ -117,7 +117,7 @@ def bag(request: SubRequest, tmp_path: Path) -> Path:
return tmp_path
def test_reader(bag: Path):
def test_reader(bag: Path) -> None:
"""Test reader and deserializer on simple bag."""
with Reader(bag) as reader:
assert reader.duration == 43
@@ -151,7 +151,7 @@ def test_reader(bag: Path):
next(gen)
def test_message_filters(bag: Path):
def test_message_filters(bag: Path) -> None:
"""Test reader filters messages."""
with Reader(bag) as reader:
magn_connections = [x for x in reader.connections.values() if x.topic == '/magn']
@@ -188,14 +188,14 @@ def test_message_filters(bag: Path):
next(gen)
def test_user_errors(bag: Path):
def test_user_errors(bag: Path) -> None:
"""Test user errors."""
reader = Reader(bag)
with pytest.raises(ReaderError, match='Rosbag is not open'):
next(reader.messages())
def test_failure_cases(tmp_path: Path):
def test_failure_cases(tmp_path: Path) -> None:
"""Test bags with broken fs layout."""
with pytest.raises(ReaderError, match='not read metadata'):
Reader(tmp_path)
+35 -15
View File
@@ -2,8 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
"""Reader tests."""
from __future__ import annotations
from collections import defaultdict
from struct import pack
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
@@ -11,8 +14,12 @@ import pytest
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag1.reader import IndexData
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Sequence, Union
def ser(data):
def ser(data: Union[dict[str, Any], bytes]) -> bytes:
"""Serialize record header."""
if isinstance(data, dict):
fields = []
@@ -23,7 +30,7 @@ def ser(data):
return pack('<L', len(data)) + data
def create_default_header():
def create_default_header() -> dict[str, bytes]:
"""Create empty rosbag header."""
return {
'op': b'\x03',
@@ -32,7 +39,11 @@ def create_default_header():
}
def create_connection(cid=1, topic=0, typ=0):
def create_connection(
cid: int = 1,
topic: int = 0,
typ: int = 0,
) -> tuple[dict[str, bytes], dict[str, bytes]]:
"""Create connection record."""
return {
'op': b'\x07',
@@ -45,7 +56,11 @@ def create_connection(cid=1, topic=0, typ=0):
}
def create_message(cid=1, time=0, msg=0):
def create_message(
cid: int = 1,
time: int = 0,
msg: int = 0,
) -> tuple[dict[str, Union[bytes, int]], bytes]:
"""Create message record."""
return {
'op': b'\x02',
@@ -54,7 +69,12 @@ def create_message(cid=1, time=0, msg=0):
}, f'MSGCONTENT{msg}'.encode()
def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-many-statements
def write_bag( # pylint: disable=too-many-locals,too-many-statements
bag: Path,
header: dict[str, bytes],
chunks: Sequence[Any] = (),
) -> None:
"""Write bag file."""
magic = b'#ROSBAG V2.0\n'
@@ -70,7 +90,7 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
chunk_bytes = b''
start_time = 2**32 - 1
end_time = 0
counts = defaultdict(int)
counts: dict[int, int] = defaultdict(int)
index = {}
offset = 0
@@ -95,8 +115,8 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
'count': 0,
'msgs': b'',
}
index[conn]['count'] += 1
index[conn]['msgs'] += pack('<LLL', time, 0, offset)
index[conn]['count'] += 1 # type: ignore
index[conn]['msgs'] += pack('<LLL', time, 0, offset) # type: ignore
add = ser(head) + ser(data)
chunk_bytes += add
@@ -140,19 +160,19 @@ def write_bag(bag, header, chunks=None): # pylint: disable=too-many-locals,too-
if 'index_pos' not in header:
header['index_pos'] = pack('<Q', pos)
header = ser(header)
header += b'\x20' * (4096 - len(header))
header_bytes = ser(header)
header_bytes += b'\x20' * (4096 - len(header_bytes))
bag.write_bytes(b''.join([
magic,
header,
header_bytes,
chunks_bytes,
connections,
chunkinfos,
]))
def test_indexdata():
def test_indexdata() -> None:
"""Test IndexData sort sorder."""
x42_1_0 = IndexData(42, 1, 0)
x42_2_0 = IndexData(42, 2, 0)
@@ -175,7 +195,7 @@ def test_indexdata():
assert not x42_1_0 > x43_3_0
def test_reader(tmp_path): # pylint: disable=too-many-statements
def test_reader(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test reader and deserializer on simple bag."""
# empty bag
bag = tmp_path / 'test.bag'
@@ -268,7 +288,7 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert msgs[0][2] == b'MSGCONTENT5'
def test_user_errors(tmp_path):
def test_user_errors(tmp_path: Path) -> None:
"""Test user errors."""
bag = tmp_path / 'test.bag'
write_bag(bag, create_default_header(), chunks=[[
@@ -281,7 +301,7 @@ def test_user_errors(tmp_path):
next(reader.messages())
def test_failure_cases(tmp_path): # pylint: disable=too-many-statements
def test_failure_cases(tmp_path: Path) -> None: # pylint: disable=too-many-statements
"""Test failure cases."""
bag = tmp_path / 'test.bag'
with pytest.raises(ReaderError, match='does not exist'):
+1 -1
View File
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
@pytest.mark.parametrize('mode', [*Writer.CompressionMode])
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path):
def test_roundtrip(mode: Writer.CompressionMode, tmp_path: Path) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods
+1 -1
View File
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
def test_roundtrip(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test full data roundtrip."""
class Foo: # pylint: disable=too-few-public-methods
+13 -12
View File
@@ -18,7 +18,7 @@ from rosbags.typesys.types import builtin_interfaces__msg__Time, std_msgs__msg__
from .cdr import deserialize, serialize
if TYPE_CHECKING:
from typing import Any, Tuple, Union
from typing import Any, Generator, Union
MSG_POLY = (
(
@@ -169,7 +169,7 @@ test_msgs/msg/dynamic_s_64[] seq_msg_ds6
@pytest.fixture()
def _comparable():
def _comparable() -> Generator[None, None, None]:
"""Make messages containing numpy arrays comparable.
Notes:
@@ -180,7 +180,7 @@ def _comparable():
def arreq(self: MagicMock, other: Union[MagicMock, Any]) -> bool:
lhs = self._mock_wraps # pylint: disable=protected-access
rhs = getattr(other, '_mock_wraps', other)
return (lhs == rhs).all()
return (lhs == rhs).all() # type: ignore
class CNDArray(MagicMock):
"""Mock ndarray."""
@@ -194,14 +194,14 @@ def _comparable():
return CNDArray(wraps=self._mock_wraps.byteswap(*args))
def wrap_frombuffer(*args: Any, **kwargs: Any) -> CNDArray:
return CNDArray(wraps=frombuffer(*args, **kwargs))
return CNDArray(wraps=frombuffer(*args, **kwargs)) # type: ignore
with patch.object(numpy, 'frombuffer', side_effect=wrap_frombuffer):
yield
@pytest.mark.parametrize('message', MESSAGES)
def test_serde(message: Tuple[bytes, str, bool]):
def test_serde(message: tuple[bytes, str, bool]) -> None:
"""Test serialization deserialization roundtrip."""
rawdata, typ, is_little = message
@@ -213,7 +213,7 @@ def test_serde(message: Tuple[bytes, str, bool]):
@pytest.mark.usefixtures('_comparable')
def test_deserializer():
def test_deserializer() -> None:
"""Test deserializer."""
msg = deserialize_cdr(*MSG_POLY[:2])
assert msg == deserialize(*MSG_POLY[:2])
@@ -233,7 +233,8 @@ def test_deserializer():
assert msg.header.frame_id == 'foo42'
field = msg.magnetic_field
assert (field.x, field.y, field.z) == (128., 128., 128.)
assert (numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) == [1., 1., 1.]).all()
diag = numpy.diag(msg.magnetic_field_covariance.reshape(3, 3)) # type: ignore
assert (diag == [1., 1., 1.]).all()
msg_big = deserialize_cdr(*MSG_MAGN_BIG[:2])
assert msg_big == deserialize(*MSG_MAGN_BIG[:2])
@@ -241,7 +242,7 @@ def test_deserializer():
@pytest.mark.usefixtures('_comparable')
def test_serializer():
def test_serializer() -> None:
"""Test serializer."""
class Foo: # pylint: disable=too-few-public-methods
@@ -268,7 +269,7 @@ def test_serializer():
@pytest.mark.usefixtures('_comparable')
def test_serializer_errors():
def test_serializer_errors() -> None:
"""Test seralizer with broken messages."""
class Foo: # pylint: disable=too-few-public-methods
@@ -286,7 +287,7 @@ def test_serializer_errors():
@pytest.mark.usefixtures('_comparable')
def test_custom_type():
def test_custom_type() -> None:
"""Test custom type."""
cname = 'test_msgs/msg/custom'
register_types(dict(get_types_from_msg(STATIC_64_64, 'test_msgs/msg/static_64_64')))
@@ -362,7 +363,7 @@ def test_custom_type():
assert res == msg
def test_ros1_to_cdr():
def test_ros1_to_cdr() -> None:
"""Test ROS1 to CDR conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
@@ -385,7 +386,7 @@ def test_ros1_to_cdr():
assert ros1_to_cdr(msg_ros, 'test_msgs/msg/dynamic_s_64') == msg_cdr
def test_cdr_to_ros1():
def test_cdr_to_ros1() -> None:
"""Test CDR to ROS1 conversion."""
register_types(dict(get_types_from_msg(STATIC_16_64, 'test_msgs/msg/static_16_64')))
msg_ros = (b'\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x02')
+2 -2
View File
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from pathlib import Path
def test_writer(tmp_path: Path):
def test_writer(tmp_path: Path) -> None:
"""Test Writer."""
path = (tmp_path / 'rosbag2')
with Writer(path) as bag:
@@ -60,7 +60,7 @@ def test_writer(tmp_path: Path):
assert size > (path / 'compress_message.db3').stat().st_size
def test_failure_cases(tmp_path: Path):
def test_failure_cases(tmp_path: Path) -> None:
"""Test writer failure cases."""
with pytest.raises(WriterError, match='exists'):
Writer(tmp_path)
+7 -7
View File
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from typing import Optional
def test_no_overwrite(tmp_path: Path):
def test_no_overwrite(tmp_path: Path) -> None:
"""Test writer does not touch existing files."""
path = tmp_path / 'test.bag'
path.write_text('foo')
@@ -30,7 +30,7 @@ def test_no_overwrite(tmp_path: Path):
writer.open()
def test_empty(tmp_path: Path):
def test_empty(tmp_path: Path) -> None:
"""Test empty bag."""
path = tmp_path / 'test.bag'
@@ -40,7 +40,7 @@ def test_empty(tmp_path: Path):
assert len(data) == 13 + 4096
def test_add_connection(tmp_path: Path):
def test_add_connection(tmp_path: Path) -> None:
"""Test adding of connections."""
path = tmp_path / 'test.bag'
@@ -88,7 +88,7 @@ def test_add_connection(tmp_path: Path):
assert (res1.cid, res2.cid, res3.cid) == (0, 1, 2)
def test_write_errors(tmp_path: Path):
def test_write_errors(tmp_path: Path) -> None:
"""Test write errors."""
path = tmp_path / 'test.bag'
@@ -101,7 +101,7 @@ def test_write_errors(tmp_path: Path):
path.unlink()
def test_write_simple(tmp_path: Path):
def test_write_simple(tmp_path: Path) -> None:
"""Test writing of messages."""
path = tmp_path / 'test.bag'
@@ -179,7 +179,7 @@ def test_write_simple(tmp_path: Path):
path.unlink()
def test_compression_errors(tmp_path: Path):
def test_compression_errors(tmp_path: Path) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
with Writer(path) as writer, \
@@ -188,7 +188,7 @@ def test_compression_errors(tmp_path: Path):
@pytest.mark.parametrize('fmt', [None, Writer.CompressionFormat.BZ2, Writer.CompressionFormat.LZ4])
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]):
def test_compression_modes(tmp_path: Path, fmt: Optional[Writer.CompressionFormat]) -> None:
"""Test compression modes."""
path = tmp_path / 'test.bag'
writer = Writer(path)