Add const fields to type representations

This commit is contained in:
Marko Durkovic 2021-08-01 17:38:18 +02:00 committed by Florian Friesdorf
parent fa57b16765
commit 03b4d7e5c7
12 changed files with 1573 additions and 907 deletions

View File

@ -68,19 +68,19 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
size += SIZEMAP[desc.args]
elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1]
subdesc, length = desc.args
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':
lines.append(f' val = message.{fieldname}')
for idx in range(desc.args[0]):
for idx in range(length):
lines.append(' pos = (pos + 4 - 1) & -4')
lines.append(f' pos += 4 + len(val[{idx}].encode()) + 1')
aligned = 1
is_stat = False
else:
lines.append(f' pos += {desc.args[0] * SIZEMAP[subdesc.args]}')
size += desc.args[0] * SIZEMAP[subdesc.args]
lines.append(f' pos += {length * SIZEMAP[subdesc.args]}')
size += length * SIZEMAP[subdesc.args]
else:
assert subdesc.valtype == Valtype.MESSAGE
@ -88,7 +88,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
anext_after = align_after(subdesc)
if subdesc.args.size_cdr:
for _ in range(desc.args[0]):
for _ in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
size = (size + anext - 1) & -anext
@ -97,7 +97,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
else:
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
lines.append(f' val = message.{fieldname}')
for idx in range(desc.args[0]):
for idx in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = func(pos, val[{idx}])')
@ -107,7 +107,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
assert desc.valtype == Valtype.SEQUENCE
lines.append(' pos += 4')
aligned = 4
subdesc = desc.args
subdesc = desc.args[0]
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':
lines.append(f' for val in message.{fieldname}:')
@ -211,13 +211,13 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
aligned = SIZEMAP[desc.args]
elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1]
lines.append(f' if len(val) != {desc.args[0]}:')
subdesc, length = desc.args
lines.append(f' if len(val) != {length}:')
lines.append(' raise SerdeError(\'Unexpected array length\')')
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':
for idx in range(desc.args[0]):
for idx in range(length):
lines.append(f' bval = memoryview(val[{idx}].encode())')
lines.append(' length = len(bval) + 1')
lines.append(' pos = (pos + 4 - 1) & -4')
@ -229,7 +229,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
else:
if (endianess == 'le') != (sys.byteorder == 'little'):
lines.append(' val = val.byteswap()')
size = desc.args[0] * SIZEMAP[subdesc.args]
size = length * SIZEMAP[subdesc.args]
lines.append(f' rawdata[pos:pos + {size}] = val.view(numpy.uint8)')
lines.append(f' pos += {size}')
@ -240,7 +240,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
)
for idx in range(desc.args[0]):
for idx in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = func(rawdata, pos, val[{idx}])')
@ -250,7 +250,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(f' pack_int32_{endianess}(rawdata, pos, len(val))')
lines.append(' pos += 4')
aligned = 4
subdesc = desc.args
subdesc = desc.args[0]
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':
@ -350,11 +350,11 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
aligned = SIZEMAP[desc.args]
elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1]
subdesc, length = desc.args
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':
lines.append(' value = []')
for idx in range(desc.args[0]):
for idx in range(length):
if idx:
lines.append(' pos = (pos + 4 - 1) & -4')
lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]')
@ -365,10 +365,10 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(' values.append(value)')
aligned = 1
else:
size = desc.args[0] * SIZEMAP[subdesc.args]
size = length * SIZEMAP[subdesc.args]
lines.append(
f' val = numpy.frombuffer(rawdata, '
f'dtype=numpy.{subdesc.args}, count={desc.args[0]}, offset=pos)',
f'dtype=numpy.{subdesc.args}, count={length}, offset=pos)',
)
if (endianess == 'le') != (sys.byteorder == 'little'):
lines.append(' val = val.byteswap()')
@ -380,7 +380,7 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
anext_after = align_after(subdesc)
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
lines.append(' value = []')
for _ in range(desc.args[0]):
for _ in range(length):
if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
@ -393,7 +393,7 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(f' size = unpack_int32_{endianess}(rawdata, pos)[0]')
lines.append(' pos += 4')
aligned = 4
subdesc = desc.args
subdesc = desc.args[0]
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':

View File

@ -10,13 +10,12 @@ from rosbags.typesys import types
from .cdr import generate_deserialize_cdr, generate_getsize_cdr, generate_serialize_cdr
from .ros1 import generate_ros1_to_cdr
from .typing import Field, Msgdef
from .utils import Descriptor, Valtype
from .typing import Descriptor, Field, Msgdef
from .utils import Valtype
if TYPE_CHECKING:
from typing import Any, Dict
MSGDEFCACHE: Dict[str, Msgdef] = {}
@ -37,7 +36,7 @@ def get_msgdef(typename: str) -> Msgdef:
"""
if typename not in MSGDEFCACHE:
entries = types.FIELDDEFS[typename]
entries = types.FIELDDEFS[typename][1]
def fixup(entry: Any) -> Descriptor:
if entry[0] == Valtype.BASE:
@ -45,9 +44,9 @@ def get_msgdef(typename: str) -> Msgdef:
if entry[0] == Valtype.MESSAGE:
return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1]))
if entry[0] == Valtype.ARRAY:
return Descriptor(Valtype.ARRAY, (entry[1], fixup(entry[2])))
return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1]))
if entry[0] == Valtype.SEQUENCE:
return Descriptor(Valtype.SEQUENCE, fixup(entry[1]))
return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1]))
raise SerdeError( # pragma: no cover
f'Unknown field type {entry[0]!r} encountered.',
)

View File

@ -22,12 +22,12 @@ if TYPE_CHECKING:
def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Callable:
"""Generate CDR serialization function.
"""Generate ROS1 to CDR conversion function.
Args:
fields: Fields of message.
typename: Message type name.
copy: Generate serialization or sizing function.
copy: Generate conversion or sizing function.
Returns:
ROS1 to CDR conversion function.
@ -42,17 +42,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
'import sys',
'import numpy',
'from rosbags.serde.messages import SerdeError, get_msgdef',
'from rosbags.serde.primitives import pack_bool_le',
'from rosbags.serde.primitives import pack_int8_le',
'from rosbags.serde.primitives import pack_int16_le',
'from rosbags.serde.primitives import pack_int32_le',
'from rosbags.serde.primitives import pack_int64_le',
'from rosbags.serde.primitives import pack_uint8_le',
'from rosbags.serde.primitives import pack_uint16_le',
'from rosbags.serde.primitives import pack_uint32_le',
'from rosbags.serde.primitives import pack_uint64_le',
'from rosbags.serde.primitives import pack_float32_le',
'from rosbags.serde.primitives import pack_float64_le',
'from rosbags.serde.primitives import unpack_int32_le',
f'def {funcname}(input, ipos, output, opos):',
]
@ -89,11 +79,11 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
aligned = size
elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1]
subdesc, length = desc.args
if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string':
for _ in range(desc.args[0]):
for _ in range(length):
lines.append(' opos = (opos + 4 - 1) & -4')
lines.append(' length = unpack_int32_le(input, ipos)[0] + 1')
if copy:
@ -108,7 +98,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
lines.append(' opos += length')
aligned = 1
else:
size = desc.args[0] * SIZEMAP[subdesc.args]
size = length * SIZEMAP[subdesc.args]
if copy:
lines.append(f' output[opos:opos + {size}] = input[ipos:ipos + {size}]')
lines.append(f' ipos += {size}')
@ -120,7 +110,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
anext_after = align_after(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
for _ in range(desc.args[0]):
for _ in range(length):
if anext > anext_after:
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
lines.append(' ipos, opos = func(input, ipos, output, opos)')
@ -132,7 +122,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
lines.append(' pack_int32_le(output, opos, size)')
lines.append(' ipos += 4')
lines.append(' opos += 4')
subdesc = desc.args
subdesc = desc.args[0]
aligned = 4
if subdesc.valtype == Valtype.BASE:

View File

@ -7,9 +7,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Any, Callable, List # pylint: disable=ungrouped-imports
from typing import Any, Callable, List
from .utils import Descriptor
class Descriptor(NamedTuple):
"""Value type descriptor."""
valtype: int
args: Any
class Field(NamedTuple):

View File

@ -6,11 +6,13 @@ from __future__ import annotations
from enum import IntEnum
from importlib.util import module_from_spec, spec_from_loader
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from types import ModuleType
from typing import Any, Dict, List
from typing import Dict, List
from .typing import Descriptor
class Valtype(IntEnum):
@ -22,13 +24,6 @@ class Valtype(IntEnum):
SEQUENCE = 4
class Descriptor(NamedTuple):
"""Value type descriptor."""
valtype: Valtype
args: Any # Union[Descriptor, Msgdef, Tuple[int, Descriptor], str]
SIZEMAP: Dict[str, int] = {
'bool': 1,
'int8': 1,
@ -61,7 +56,7 @@ def align(entry: Descriptor) -> int:
if entry.valtype == Valtype.MESSAGE:
return align(entry.args.fields[0].descriptor)
if entry.valtype == Valtype.ARRAY:
return align(entry.args[1])
return align(entry.args[0])
assert entry.valtype == Valtype.SEQUENCE
return 4
@ -83,9 +78,9 @@ def align_after(entry: Descriptor) -> int:
if entry.valtype == Valtype.MESSAGE:
return align_after(entry.args.fields[-1].descriptor)
if entry.valtype == Valtype.ARRAY:
return align_after(entry.args[1])
return align_after(entry.args[0])
assert entry.valtype == Valtype.SEQUENCE
return min([4, align_after(entry.args)])
return min([4, align_after(entry.args[0])])
def compile_lines(lines: List[str]) -> ModuleType:

View File

@ -8,12 +8,14 @@ from enum import IntEnum, auto
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from .peg import Visitor
Fielddefs = List[Tuple[Any, Any]]
Typesdict = Dict[str, Fielddefs]
Constdefs = List[Tuple[str, str, Any]]
Fielddesc = Tuple[int, Union[str, Tuple[Tuple[int, str], Optional[int]]]]
Fielddefs = List[Tuple[str, Fielddesc]]
Typesdict = Dict[str, Tuple[Constdefs, Fielddefs]]
class TypesysError(Exception):

View File

@ -261,8 +261,20 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol, return only children of modules."""
children = [x[0] for x in children if x is not None]
modules = [y for t, x in children if t == Nodetype.MODULE for y in x]
return {x[1]: x[2] for x in modules if x[0] == Nodetype.STRUCT}
structs = {}
consts: dict[str, list[tuple[str, str, Any]]] = {}
for item in children:
if item[0] != Nodetype.MODULE:
continue
for subitem in item[1]:
if subitem[0] == Nodetype.STRUCT:
structs[subitem[1]] = subitem[2]
elif subitem[0] == Nodetype.CONST and '_Constants/' in subitem[1][1]:
structname, varname = subitem[1][1].split('_Constants/')
if structname not in consts:
consts[structname] = []
consts[structname].append((varname, subitem[1][0], subitem[1][2]))
return {k: (consts.get(k, []), v) for k, v in structs.items()}
def visit_comment(self, children: Any) -> Any:
"""Process comment, suppress output."""
@ -273,12 +285,6 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
def visit_include(self, children: Any) -> Any:
"""Process include, suppress output."""
def visit_type_dcl(self, children: Any) -> Any:
"""Process typedef, pass structs, suppress otherwise."""
if children[0] == Nodetype.STRUCT:
return children
return None
def visit_module_dcl(self, children: Any) -> Any:
"""Process module declaration."""
assert len(children) == 6
@ -288,7 +294,6 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
children = children[4]
consts = []
structs = []
modules = []
for item in children:
if not item or item[0] is None:
continue
@ -299,20 +304,23 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
structs.append(item)
else:
assert item[0] == Nodetype.MODULE
modules.append(item)
consts += [x for x in item[1] if x[0] == Nodetype.CONST]
structs += [x for x in item[1] if x[0] == Nodetype.STRUCT]
for _, module in modules:
consts += [x for x in module if x[0] == Nodetype.CONST]
structs += [x for x in module if x[0] == Nodetype.STRUCT]
consts = [(x[0], f'{name}/{x[1][0]}', *x[1][1:]) for x in consts]
consts = [(x[0], (x[1][0], f'{name}/{x[1][1]}', x[1][2])) for x in consts]
structs = [(x[0], f'{name}/{x[1]}', *x[2:]) for x in structs]
return (Nodetype.MODULE, consts + structs)
def visit_const_dcl(self, children: Any) -> Any:
"""Process const declaration."""
return (Nodetype.CONST, (children[1][1], *children[2:]))
return (Nodetype.CONST, (children[1][1], children[2][1], children[4][1]))
def visit_type_dcl(self, children: Any) -> Any:
"""Process type, pass structs, suppress otherwise."""
if children[0] == Nodetype.STRUCT:
return children
return None
def visit_type_declarator(self, children: Any) -> Any:
"""Process type declarator, register type mapping in instance typedef dictionary."""
@ -323,7 +331,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
declarators = [children[1][0], *[x[1:][0] for x in children[1][1]]]
for declarator in declarators:
if declarator[0] == Nodetype.ADECLARATOR:
value = (Nodetype.ARRAY, declarator[2][1], base)
value = (Nodetype.ARRAY, (base, declarator[2][1]))
else:
value = base
self.typedefs[declarator[1][1]] = value
@ -333,8 +341,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
assert len(children) in [4, 6]
if len(children) == 6:
assert children[4][0] == Nodetype.LITERAL_NUMBER
return (Nodetype.SEQUENCE, children[2])
return (Nodetype.SEQUENCE, children[2])
return (Nodetype.SEQUENCE, (children[2], None))
def create_struct_field(self, parts: Any) -> Any:
"""Create struct field and expand typedefs."""
@ -346,7 +353,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
name = self.typedefs[name[1]]
return name
yield from ((resolve_name(typename), x[1]) for x in params if x)
yield from ((x[1][1], resolve_name(typename)) for x in params if x)
def visit_struct_dcl(self, children: Any) -> Any:
"""Process struct declaration."""

View File

@ -21,7 +21,7 @@ from .peg import Rule, Visitor, parse_grammar
if TYPE_CHECKING:
from typing import Any, List
from .base import Typesdict
from .base import Fielddesc, Typesdict
GRAMMAR_MSG = r"""
specification
@ -42,7 +42,7 @@ comment
= r'#[^\n]*'
const_dcl
= type_spec identifier '=' r'[^=][^\n]*'
= type_spec identifier '=' integer_literal
field_dcl
= type_spec identifier
@ -89,7 +89,7 @@ def normalize_msgtype(name: str) -> str:
return str(path)
def normalize_fieldtype(typename: str, field: Any, names: List[str]):
def normalize_fieldtype(typename: str, field: Fielddesc, names: List[str]) -> Fielddesc:
"""Normalize field typename.
Args:
@ -97,18 +97,20 @@ def normalize_fieldtype(typename: str, field: Any, names: List[str]):
field: Field definition.
names: Valid message names.
Returns:
Normalized fieldtype.
"""
dct = {Path(name).name: name for name in names}
namedef = field[0]
if namedef[0] == Nodetype.NAME:
name = namedef[1]
elif namedef[0] == Nodetype.SEQUENCE:
name = namedef[1][1]
ftype, args = field
if ftype == Nodetype.NAME:
name = args
else:
name = namedef[2][1]
name = args[0][1]
assert isinstance(name, str)
if name in VisitorMSG.BASETYPES:
inamedef = (Nodetype.BASE, name)
ifield = (Nodetype.BASE, name)
else:
if name in dct:
name = dct[name]
@ -118,16 +120,13 @@ def normalize_fieldtype(typename: str, field: Any, names: List[str]):
name = str(Path(typename).parent / name)
elif '/msg/' not in name:
name = str((path := Path(name)).parent / 'msg' / path.name)
inamedef = (Nodetype.NAME, name)
ifield = (Nodetype.NAME, name)
if namedef[0] == Nodetype.NAME:
namedef = inamedef
elif namedef[0] == Nodetype.SEQUENCE:
namedef = (Nodetype.SEQUENCE, inamedef)
else:
namedef = (Nodetype.ARRAY, namedef[1], inamedef)
if ftype == Nodetype.NAME:
return ifield
field[0] = namedef
assert not isinstance(args, str)
return (ftype, (ifield, args[1]))
class VisitorMSG(Visitor):
@ -157,6 +156,7 @@ class VisitorMSG(Visitor):
def visit_const_dcl(self, children: Any) -> Any:
"""Process const declaration, suppress output."""
return Nodetype.CONST, (children[0][1], children[1][1], children[3])
def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol."""
@ -164,8 +164,10 @@ class VisitorMSG(Visitor):
typedict = dict(typelist)
names = list(typedict.keys())
for name, fields in typedict.items():
for field in fields:
normalize_fieldtype(name, field, names)
consts = [(x[1][1], x[1][0], x[1][2]) for x in fields if x[0] == Nodetype.CONST]
fields = [x for x in fields if x[0] != Nodetype.CONST]
fields = [(field[1][1], normalize_fieldtype(name, field[0], names)) for field in fields]
typedict[name] = consts, fields
return typedict
def visit_msgdef(self, children: Any) -> Any:
@ -180,8 +182,8 @@ class VisitorMSG(Visitor):
"""Process array type specifier."""
length = children[1][1]
if length:
return (Nodetype.ARRAY, int(length[0]), children[0])
return (Nodetype.SEQUENCE, children[0])
return Nodetype.ARRAY, (children[0], length[0])
return Nodetype.SEQUENCE, (children[0], None)
def visit_simple_type_spec(self, children: Any) -> Any:
"""Process simple type specifier."""
@ -204,6 +206,10 @@ class VisitorMSG(Visitor):
"""Process identifier."""
return (Nodetype.NAME, children)
def visit_integer_literal(self, children: Any) -> Any:
"""Process integer literal."""
return int(children)
def get_types_from_msg(text: str, name: str) -> Typesdict:
"""Get type from msg message definition.

View File

@ -4,7 +4,6 @@
from __future__ import annotations
import json
import re
import sys
from importlib.util import module_from_spec, spec_from_loader
@ -37,7 +36,7 @@ def get_typehint(desc: tuple) -> str:
if desc[0] == Nodetype.NAME:
return desc[1].replace('/', '__')
sub = desc[2 if desc[0] == Nodetype.ARRAY else 1]
sub = desc[1][0]
if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
@ -71,21 +70,27 @@ def generate_python_code(typs: Typesdict) -> str:
'from typing import TYPE_CHECKING',
'',
'if TYPE_CHECKING:',
' from typing import Any',
' from typing import Any, ClassVar',
'',
' import numpy',
'',
' from .base import Typesdict',
'',
]
for name, fields in typs.items():
for name, (consts, fields) in typs.items():
pyname = name.replace('/', '__')
lines += [
'@dataclass',
f'class {pyname}:',
f' """Class for {name}."""',
'',
*[f' {fname[1]}: {get_typehint(desc)}' for desc, fname in fields],
*[f' {fname}: {get_typehint(desc)}' for fname, desc in fields],
*[
f' {fname}: ClassVar[{get_typehint((1, ftype))}] = {fvalue}'
for fname, ftype, fvalue in consts
],
f' __msgtype__: ClassVar[str] = {name!r}',
]
lines += [
@ -93,16 +98,20 @@ def generate_python_code(typs: Typesdict) -> str:
'',
]
lines += ['FIELDDEFS = {']
for name, fields in typs.items():
def get_ftype(ftype: tuple) -> tuple:
if ftype[0] <= 2:
return int(ftype[0]), ftype[1]
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
lines += ['FIELDDEFS: Typesdict = {']
for name, (consts, fields) in typs.items():
pyname = name.replace('/', '__')
lines += [
f' \'{name}\': [',
*[
f' ({repr(fname[1])}, {json.loads(json.dumps(ftype))}),'
for ftype, fname in fields
],
' ],',
f' \'{name}\': ([',
*[f' ({fname!r}, {ftype!r}, {fvalue!r}),' for fname, ftype, fvalue in consts],
' ], [',
*[f' ({fname!r}, {get_ftype(ftype)!r}),' for fname, ftype in fields],
' ]),',
]
lines += [
'}',
@ -127,15 +136,16 @@ def register_types(typs: Typesdict) -> None:
module = module_from_spec(spec)
sys.modules[name] = module
exec(code, module.__dict__) # pylint: disable=exec-used
fielddefs = module.FIELDDEFS # type: ignore
fielddefs: Typesdict = module.FIELDDEFS # type: ignore
for name, fields in fielddefs.items():
for name, (_, fields) in fielddefs.items():
if name == 'std_msgs/msg/Header':
continue
if have := types.FIELDDEFS.get(name):
have = [(x[0].lower(), x[1]) for x in have]
_, have_fields = have
have_fields = [(x[0].lower(), x[1]) for x in have_fields]
fields = [(x[0].lower(), x[1]) for x in fields]
if have != fields:
if have_fields != fields:
raise TypesysError(f'Type {name!r} is already present with different definition.')
for name in fielddefs.keys() - types.FIELDDEFS.keys():

File diff suppressed because it is too large Load Diff

View File

@ -161,12 +161,13 @@ def deserialize_message(rawdata: bytes, bmap: BasetypeMap, pos: int, msgdef: Msg
values.append(num)
elif desc.valtype == Valtype.ARRAY:
arr, pos = deserialize_array(rawdata, bmap, pos, *desc.args)
subdesc, length = desc.args
arr, pos = deserialize_array(rawdata, bmap, pos, length, subdesc)
values.append(arr)
elif desc.valtype == Valtype.SEQUENCE:
size, pos = deserialize_number(rawdata, bmap, pos, 'int32')
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args)
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args[0])
values.append(arr)
return msgdef.cls(*values), pos
@ -323,12 +324,12 @@ def serialize_message(
pos = serialize_number(rawdata, bmap, pos, desc.args, val)
elif desc.valtype == Valtype.ARRAY:
pos = serialize_array(rawdata, bmap, pos, desc.args[1], val)
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
elif desc.valtype == Valtype.SEQUENCE:
size = len(val)
pos = serialize_number(rawdata, bmap, pos, 'int32', size)
pos = serialize_array(rawdata, bmap, pos, desc.args, val)
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
return pos
@ -397,14 +398,15 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
size += isize
elif desc.valtype == Valtype.ARRAY:
if len(val) != desc.args[0]:
raise SerdeError(f'Unexpected array length: {len(val)} != {desc.args[0]}.')
size = get_array_size(desc.args[1], val, size)
subdesc, length = desc.args
if len(val) != length:
raise SerdeError(f'Unexpected array length: {len(val)} != {length}.')
size = get_array_size(subdesc, val, size)
elif desc.valtype == Valtype.SEQUENCE:
size = (size + 4 - 1) & -4
size += 4
size = get_array_size(desc.args, val, size)
size = get_array_size(desc.args[0], val, size)
return size

View File

@ -35,6 +35,7 @@ time time
================================================================================
MSG: test_msgs/Other
uint64[3] Header
uint32 static = 42
"""
RELSIBLING_MSG = """
@ -81,6 +82,11 @@ module test_msgs {
typedef test_msgs::msg::Bar Bar;
typedef double d4[4];
module Foo_Constants {
const int32 FOO = 32;
const int64 BAR = 64;
};
@comment(type="text", text="ignore")
struct Foo {
std_msgs::msg::Header header;
@ -102,17 +108,18 @@ def test_parse_msg():
get_types_from_msg('', 'test_msgs/msg/Foo')
ret = get_types_from_msg(MSG, 'test_msgs/msg/Foo')
assert 'test_msgs/msg/Foo' in ret
fields = ret['test_msgs/msg/Foo']
assert fields[0][0][1] == 'std_msgs/msg/Header'
assert fields[0][1][1] == 'header'
assert fields[1][0][1] == 'std_msgs/msg/Bool'
assert fields[1][1][1] == 'bool'
assert fields[2][0][1] == 'test_msgs/msg/Bar'
assert fields[2][1][1] == 'sibling'
assert fields[3][0][0] == Nodetype.BASE
assert fields[4][0][0] == Nodetype.SEQUENCE
assert fields[5][0][0] == Nodetype.SEQUENCE
assert fields[6][0][0] == Nodetype.ARRAY
consts, fields = ret['test_msgs/msg/Foo']
assert consts == [('global', 'int32', 42)]
assert fields[0][0] == 'header'
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][0] == 'bool'
assert fields[1][1][1] == 'std_msgs/msg/Bool'
assert fields[2][0] == 'sibling'
assert fields[2][1][1] == 'test_msgs/msg/Bar'
assert fields[3][1][0] == Nodetype.BASE
assert fields[4][1][0] == Nodetype.SEQUENCE
assert fields[5][1][0] == Nodetype.SEQUENCE
assert fields[6][1][0] == Nodetype.ARRAY
def test_parse_multi_msg():
@ -122,20 +129,23 @@ def test_parse_multi_msg():
assert 'test_msgs/msg/Foo' in ret
assert 'std_msgs/msg/Header' in ret
assert 'test_msgs/msg/Other' in ret
assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header'
assert ret['test_msgs/msg/Foo'][1][0][1] == 'uint8'
assert ret['test_msgs/msg/Foo'][2][0][1] == 'uint8'
fields = ret['test_msgs/msg/Foo'][1]
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][1][1] == 'uint8'
assert fields[2][1][1] == 'uint8'
consts = ret['test_msgs/msg/Other'][0]
assert consts == [('static', 'uint32', 42)]
def test_parse_relative_siblings_msg():
"""Test relative siblings with msg parser."""
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header'
assert ret['test_msgs/msg/Foo'][1][0][1] == 'test_msgs/msg/Other'
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
assert ret['test_msgs/msg/Foo'][1][1][1][1] == 'test_msgs/msg/Other'
ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo')
assert ret['rel_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header'
assert ret['rel_msgs/msg/Foo'][1][0][1] == 'rel_msgs/msg/Other'
assert ret['rel_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
def test_parse_idl():
@ -145,28 +155,29 @@ def test_parse_idl():
ret = get_types_from_idl(IDL)
assert 'test_msgs/msg/Foo' in ret
fields = ret['test_msgs/msg/Foo']
assert fields[0][0][1] == 'std_msgs/msg/Header'
assert fields[0][1][1] == 'header'
assert fields[1][0][1] == 'std_msgs/msg/Bool'
assert fields[1][1][1] == 'bool'
assert fields[2][0][1] == 'test_msgs/msg/Bar'
assert fields[2][1][1] == 'sibling'
assert fields[3][0][0] == Nodetype.BASE
assert fields[4][0][0] == Nodetype.SEQUENCE
assert fields[5][0][0] == Nodetype.SEQUENCE
assert fields[6][0][0] == Nodetype.ARRAY
consts, fields = ret['test_msgs/msg/Foo']
assert consts == [('FOO', 'int32', 32), ('BAR', 'int64', 64)]
assert fields[0][0] == 'header'
assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][0] == 'bool'
assert fields[1][1][1] == 'std_msgs/msg/Bool'
assert fields[2][0] == 'sibling'
assert fields[2][1][1] == 'test_msgs/msg/Bar'
assert fields[3][1][0] == Nodetype.BASE
assert fields[4][1][0] == Nodetype.SEQUENCE
assert fields[5][1][0] == Nodetype.SEQUENCE
assert fields[6][1][0] == Nodetype.ARRAY
def test_register_types():
"""Test type registeration."""
assert 'foo' not in FIELDDEFS
register_types({})
register_types({'foo': [[(1, 'bool'), (2, 'b')]]})
register_types({'foo': [[], [('b', (1, 'bool'))]]})
assert 'foo' in FIELDDEFS
register_types({'std_msgs/msg/Header': []})
assert len(FIELDDEFS['std_msgs/msg/Header']) == 2
register_types({'std_msgs/msg/Header': [[], []]})
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
with pytest.raises(TypesysError, match='different definition'):
register_types({'foo': [[(1, 'bool'), (2, 'x')]]})
register_types({'foo': [[], [('x', (1, 'bool'))]]})