Add const fields to type representations

This commit is contained in:
Marko Durkovic 2021-08-01 17:38:18 +02:00 committed by Florian Friesdorf
parent fa57b16765
commit 03b4d7e5c7
12 changed files with 1573 additions and 907 deletions

View File

@ -68,19 +68,19 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
size += SIZEMAP[desc.args] size += SIZEMAP[desc.args]
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1] subdesc, length = desc.args
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':
lines.append(f' val = message.{fieldname}') lines.append(f' val = message.{fieldname}')
for idx in range(desc.args[0]): for idx in range(length):
lines.append(' pos = (pos + 4 - 1) & -4') lines.append(' pos = (pos + 4 - 1) & -4')
lines.append(f' pos += 4 + len(val[{idx}].encode()) + 1') lines.append(f' pos += 4 + len(val[{idx}].encode()) + 1')
aligned = 1 aligned = 1
is_stat = False is_stat = False
else: else:
lines.append(f' pos += {desc.args[0] * SIZEMAP[subdesc.args]}') lines.append(f' pos += {length * SIZEMAP[subdesc.args]}')
size += desc.args[0] * SIZEMAP[subdesc.args] size += length * SIZEMAP[subdesc.args]
else: else:
assert subdesc.valtype == Valtype.MESSAGE assert subdesc.valtype == Valtype.MESSAGE
@ -88,7 +88,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
if subdesc.args.size_cdr: if subdesc.args.size_cdr:
for _ in range(desc.args[0]): for _ in range(length):
if anext > anext_after: if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
size = (size + anext - 1) & -anext size = (size + anext - 1) & -anext
@ -97,7 +97,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
else: else:
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr') lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
lines.append(f' val = message.{fieldname}') lines.append(f' val = message.{fieldname}')
for idx in range(desc.args[0]): for idx in range(length):
if anext > anext_after: if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = func(pos, val[{idx}])') lines.append(f' pos = func(pos, val[{idx}])')
@ -107,7 +107,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
assert desc.valtype == Valtype.SEQUENCE assert desc.valtype == Valtype.SEQUENCE
lines.append(' pos += 4') lines.append(' pos += 4')
aligned = 4 aligned = 4
subdesc = desc.args subdesc = desc.args[0]
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':
lines.append(f' for val in message.{fieldname}:') lines.append(f' for val in message.{fieldname}:')
@ -211,13 +211,13 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
aligned = SIZEMAP[desc.args] aligned = SIZEMAP[desc.args]
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1] subdesc, length = desc.args
lines.append(f' if len(val) != {desc.args[0]}:') lines.append(f' if len(val) != {length}:')
lines.append(' raise SerdeError(\'Unexpected array length\')') lines.append(' raise SerdeError(\'Unexpected array length\')')
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':
for idx in range(desc.args[0]): for idx in range(length):
lines.append(f' bval = memoryview(val[{idx}].encode())') lines.append(f' bval = memoryview(val[{idx}].encode())')
lines.append(' length = len(bval) + 1') lines.append(' length = len(bval) + 1')
lines.append(' pos = (pos + 4 - 1) & -4') lines.append(' pos = (pos + 4 - 1) & -4')
@ -229,7 +229,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
else: else:
if (endianess == 'le') != (sys.byteorder == 'little'): if (endianess == 'le') != (sys.byteorder == 'little'):
lines.append(' val = val.byteswap()') lines.append(' val = val.byteswap()')
size = desc.args[0] * SIZEMAP[subdesc.args] size = length * SIZEMAP[subdesc.args]
lines.append(f' rawdata[pos:pos + {size}] = val.view(numpy.uint8)') lines.append(f' rawdata[pos:pos + {size}] = val.view(numpy.uint8)')
lines.append(f' pos += {size}') lines.append(f' pos += {size}')
@ -240,7 +240,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append( lines.append(
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}', f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
) )
for idx in range(desc.args[0]): for idx in range(length):
if anext > anext_after: if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' pos = func(rawdata, pos, val[{idx}])') lines.append(f' pos = func(rawdata, pos, val[{idx}])')
@ -250,7 +250,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(f' pack_int32_{endianess}(rawdata, pos, len(val))') lines.append(f' pack_int32_{endianess}(rawdata, pos, len(val))')
lines.append(' pos += 4') lines.append(' pos += 4')
aligned = 4 aligned = 4
subdesc = desc.args subdesc = desc.args[0]
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':
@ -350,11 +350,11 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
aligned = SIZEMAP[desc.args] aligned = SIZEMAP[desc.args]
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1] subdesc, length = desc.args
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':
lines.append(' value = []') lines.append(' value = []')
for idx in range(desc.args[0]): for idx in range(length):
if idx: if idx:
lines.append(' pos = (pos + 4 - 1) & -4') lines.append(' pos = (pos + 4 - 1) & -4')
lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]') lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]')
@ -365,10 +365,10 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(' values.append(value)') lines.append(' values.append(value)')
aligned = 1 aligned = 1
else: else:
size = desc.args[0] * SIZEMAP[subdesc.args] size = length * SIZEMAP[subdesc.args]
lines.append( lines.append(
f' val = numpy.frombuffer(rawdata, ' f' val = numpy.frombuffer(rawdata, '
f'dtype=numpy.{subdesc.args}, count={desc.args[0]}, offset=pos)', f'dtype=numpy.{subdesc.args}, count={length}, offset=pos)',
) )
if (endianess == 'le') != (sys.byteorder == 'little'): if (endianess == 'le') != (sys.byteorder == 'little'):
lines.append(' val = val.byteswap()') lines.append(' val = val.byteswap()')
@ -380,7 +380,7 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")') lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
lines.append(' value = []') lines.append(' value = []')
for _ in range(desc.args[0]): for _ in range(length):
if anext > anext_after: if anext > anext_after:
lines.append(f' pos = (pos + {anext} - 1) & -{anext}') lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)') lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
@ -393,7 +393,7 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
lines.append(f' size = unpack_int32_{endianess}(rawdata, pos)[0]') lines.append(f' size = unpack_int32_{endianess}(rawdata, pos)[0]')
lines.append(' pos += 4') lines.append(' pos += 4')
aligned = 4 aligned = 4
subdesc = desc.args subdesc = desc.args[0]
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':

View File

@ -10,13 +10,12 @@ from rosbags.typesys import types
from .cdr import generate_deserialize_cdr, generate_getsize_cdr, generate_serialize_cdr from .cdr import generate_deserialize_cdr, generate_getsize_cdr, generate_serialize_cdr
from .ros1 import generate_ros1_to_cdr from .ros1 import generate_ros1_to_cdr
from .typing import Field, Msgdef from .typing import Descriptor, Field, Msgdef
from .utils import Descriptor, Valtype from .utils import Valtype
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Dict from typing import Any, Dict
MSGDEFCACHE: Dict[str, Msgdef] = {} MSGDEFCACHE: Dict[str, Msgdef] = {}
@ -37,7 +36,7 @@ def get_msgdef(typename: str) -> Msgdef:
""" """
if typename not in MSGDEFCACHE: if typename not in MSGDEFCACHE:
entries = types.FIELDDEFS[typename] entries = types.FIELDDEFS[typename][1]
def fixup(entry: Any) -> Descriptor: def fixup(entry: Any) -> Descriptor:
if entry[0] == Valtype.BASE: if entry[0] == Valtype.BASE:
@ -45,9 +44,9 @@ def get_msgdef(typename: str) -> Msgdef:
if entry[0] == Valtype.MESSAGE: if entry[0] == Valtype.MESSAGE:
return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1])) return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1]))
if entry[0] == Valtype.ARRAY: if entry[0] == Valtype.ARRAY:
return Descriptor(Valtype.ARRAY, (entry[1], fixup(entry[2]))) return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1]))
if entry[0] == Valtype.SEQUENCE: if entry[0] == Valtype.SEQUENCE:
return Descriptor(Valtype.SEQUENCE, fixup(entry[1])) return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1]))
raise SerdeError( # pragma: no cover raise SerdeError( # pragma: no cover
f'Unknown field type {entry[0]!r} encountered.', f'Unknown field type {entry[0]!r} encountered.',
) )

View File

@ -22,12 +22,12 @@ if TYPE_CHECKING:
def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Callable: def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Callable:
"""Generate CDR serialization function. """Generate ROS1 to CDR conversion function.
Args: Args:
fields: Fields of message. fields: Fields of message.
typename: Message type name. typename: Message type name.
copy: Generate serialization or sizing function. copy: Generate conversion or sizing function.
Returns: Returns:
ROS1 to CDR conversion function. ROS1 to CDR conversion function.
@ -42,17 +42,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
'import sys', 'import sys',
'import numpy', 'import numpy',
'from rosbags.serde.messages import SerdeError, get_msgdef', 'from rosbags.serde.messages import SerdeError, get_msgdef',
'from rosbags.serde.primitives import pack_bool_le',
'from rosbags.serde.primitives import pack_int8_le',
'from rosbags.serde.primitives import pack_int16_le',
'from rosbags.serde.primitives import pack_int32_le', 'from rosbags.serde.primitives import pack_int32_le',
'from rosbags.serde.primitives import pack_int64_le',
'from rosbags.serde.primitives import pack_uint8_le',
'from rosbags.serde.primitives import pack_uint16_le',
'from rosbags.serde.primitives import pack_uint32_le',
'from rosbags.serde.primitives import pack_uint64_le',
'from rosbags.serde.primitives import pack_float32_le',
'from rosbags.serde.primitives import pack_float64_le',
'from rosbags.serde.primitives import unpack_int32_le', 'from rosbags.serde.primitives import unpack_int32_le',
f'def {funcname}(input, ipos, output, opos):', f'def {funcname}(input, ipos, output, opos):',
] ]
@ -89,11 +79,11 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
aligned = size aligned = size
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
subdesc = desc.args[1] subdesc, length = desc.args
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:
if subdesc.args == 'string': if subdesc.args == 'string':
for _ in range(desc.args[0]): for _ in range(length):
lines.append(' opos = (opos + 4 - 1) & -4') lines.append(' opos = (opos + 4 - 1) & -4')
lines.append(' length = unpack_int32_le(input, ipos)[0] + 1') lines.append(' length = unpack_int32_le(input, ipos)[0] + 1')
if copy: if copy:
@ -108,7 +98,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
lines.append(' opos += length') lines.append(' opos += length')
aligned = 1 aligned = 1
else: else:
size = desc.args[0] * SIZEMAP[subdesc.args] size = length * SIZEMAP[subdesc.args]
if copy: if copy:
lines.append(f' output[opos:opos + {size}] = input[ipos:ipos + {size}]') lines.append(f' output[opos:opos + {size}] = input[ipos:ipos + {size}]')
lines.append(f' ipos += {size}') lines.append(f' ipos += {size}')
@ -120,7 +110,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
anext_after = align_after(subdesc) anext_after = align_after(subdesc)
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}') lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
for _ in range(desc.args[0]): for _ in range(length):
if anext > anext_after: if anext > anext_after:
lines.append(f' opos = (opos + {anext} - 1) & -{anext}') lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
lines.append(' ipos, opos = func(input, ipos, output, opos)') lines.append(' ipos, opos = func(input, ipos, output, opos)')
@ -132,7 +122,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
lines.append(' pack_int32_le(output, opos, size)') lines.append(' pack_int32_le(output, opos, size)')
lines.append(' ipos += 4') lines.append(' ipos += 4')
lines.append(' opos += 4') lines.append(' opos += 4')
subdesc = desc.args subdesc = desc.args[0]
aligned = 4 aligned = 4
if subdesc.valtype == Valtype.BASE: if subdesc.valtype == Valtype.BASE:

View File

@ -7,9 +7,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Callable, List # pylint: disable=ungrouped-imports from typing import Any, Callable, List
from .utils import Descriptor
class Descriptor(NamedTuple):
"""Value type descriptor."""
valtype: int
args: Any
class Field(NamedTuple): class Field(NamedTuple):

View File

@ -6,11 +6,13 @@ from __future__ import annotations
from enum import IntEnum from enum import IntEnum
from importlib.util import module_from_spec, spec_from_loader from importlib.util import module_from_spec, spec_from_loader
from typing import TYPE_CHECKING, NamedTuple from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from types import ModuleType from types import ModuleType
from typing import Any, Dict, List from typing import Dict, List
from .typing import Descriptor
class Valtype(IntEnum): class Valtype(IntEnum):
@ -22,13 +24,6 @@ class Valtype(IntEnum):
SEQUENCE = 4 SEQUENCE = 4
class Descriptor(NamedTuple):
"""Value type descriptor."""
valtype: Valtype
args: Any # Union[Descriptor, Msgdef, Tuple[int, Descriptor], str]
SIZEMAP: Dict[str, int] = { SIZEMAP: Dict[str, int] = {
'bool': 1, 'bool': 1,
'int8': 1, 'int8': 1,
@ -61,7 +56,7 @@ def align(entry: Descriptor) -> int:
if entry.valtype == Valtype.MESSAGE: if entry.valtype == Valtype.MESSAGE:
return align(entry.args.fields[0].descriptor) return align(entry.args.fields[0].descriptor)
if entry.valtype == Valtype.ARRAY: if entry.valtype == Valtype.ARRAY:
return align(entry.args[1]) return align(entry.args[0])
assert entry.valtype == Valtype.SEQUENCE assert entry.valtype == Valtype.SEQUENCE
return 4 return 4
@ -83,9 +78,9 @@ def align_after(entry: Descriptor) -> int:
if entry.valtype == Valtype.MESSAGE: if entry.valtype == Valtype.MESSAGE:
return align_after(entry.args.fields[-1].descriptor) return align_after(entry.args.fields[-1].descriptor)
if entry.valtype == Valtype.ARRAY: if entry.valtype == Valtype.ARRAY:
return align_after(entry.args[1]) return align_after(entry.args[0])
assert entry.valtype == Valtype.SEQUENCE assert entry.valtype == Valtype.SEQUENCE
return min([4, align_after(entry.args)]) return min([4, align_after(entry.args[0])])
def compile_lines(lines: List[str]) -> ModuleType: def compile_lines(lines: List[str]) -> ModuleType:

View File

@ -8,12 +8,14 @@ from enum import IntEnum, auto
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from .peg import Visitor from .peg import Visitor
Fielddefs = List[Tuple[Any, Any]] Constdefs = List[Tuple[str, str, Any]]
Typesdict = Dict[str, Fielddefs] Fielddesc = Tuple[int, Union[str, Tuple[Tuple[int, str], Optional[int]]]]
Fielddefs = List[Tuple[str, Fielddesc]]
Typesdict = Dict[str, Tuple[Constdefs, Fielddefs]]
class TypesysError(Exception): class TypesysError(Exception):

View File

@ -261,8 +261,20 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
def visit_specification(self, children: Any) -> Typesdict: def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol, return only children of modules.""" """Process start symbol, return only children of modules."""
children = [x[0] for x in children if x is not None] children = [x[0] for x in children if x is not None]
modules = [y for t, x in children if t == Nodetype.MODULE for y in x] structs = {}
return {x[1]: x[2] for x in modules if x[0] == Nodetype.STRUCT} consts: dict[str, list[tuple[str, str, Any]]] = {}
for item in children:
if item[0] != Nodetype.MODULE:
continue
for subitem in item[1]:
if subitem[0] == Nodetype.STRUCT:
structs[subitem[1]] = subitem[2]
elif subitem[0] == Nodetype.CONST and '_Constants/' in subitem[1][1]:
structname, varname = subitem[1][1].split('_Constants/')
if structname not in consts:
consts[structname] = []
consts[structname].append((varname, subitem[1][0], subitem[1][2]))
return {k: (consts.get(k, []), v) for k, v in structs.items()}
def visit_comment(self, children: Any) -> Any: def visit_comment(self, children: Any) -> Any:
"""Process comment, suppress output.""" """Process comment, suppress output."""
@ -273,12 +285,6 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
def visit_include(self, children: Any) -> Any: def visit_include(self, children: Any) -> Any:
"""Process include, suppress output.""" """Process include, suppress output."""
def visit_type_dcl(self, children: Any) -> Any:
"""Process typedef, pass structs, suppress otherwise."""
if children[0] == Nodetype.STRUCT:
return children
return None
def visit_module_dcl(self, children: Any) -> Any: def visit_module_dcl(self, children: Any) -> Any:
"""Process module declaration.""" """Process module declaration."""
assert len(children) == 6 assert len(children) == 6
@ -288,7 +294,6 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
children = children[4] children = children[4]
consts = [] consts = []
structs = [] structs = []
modules = []
for item in children: for item in children:
if not item or item[0] is None: if not item or item[0] is None:
continue continue
@ -299,20 +304,23 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
structs.append(item) structs.append(item)
else: else:
assert item[0] == Nodetype.MODULE assert item[0] == Nodetype.MODULE
modules.append(item) consts += [x for x in item[1] if x[0] == Nodetype.CONST]
structs += [x for x in item[1] if x[0] == Nodetype.STRUCT]
for _, module in modules: consts = [(x[0], (x[1][0], f'{name}/{x[1][1]}', x[1][2])) for x in consts]
consts += [x for x in module if x[0] == Nodetype.CONST]
structs += [x for x in module if x[0] == Nodetype.STRUCT]
consts = [(x[0], f'{name}/{x[1][0]}', *x[1][1:]) for x in consts]
structs = [(x[0], f'{name}/{x[1]}', *x[2:]) for x in structs] structs = [(x[0], f'{name}/{x[1]}', *x[2:]) for x in structs]
return (Nodetype.MODULE, consts + structs) return (Nodetype.MODULE, consts + structs)
def visit_const_dcl(self, children: Any) -> Any: def visit_const_dcl(self, children: Any) -> Any:
"""Process const declaration.""" """Process const declaration."""
return (Nodetype.CONST, (children[1][1], *children[2:])) return (Nodetype.CONST, (children[1][1], children[2][1], children[4][1]))
def visit_type_dcl(self, children: Any) -> Any:
"""Process type, pass structs, suppress otherwise."""
if children[0] == Nodetype.STRUCT:
return children
return None
def visit_type_declarator(self, children: Any) -> Any: def visit_type_declarator(self, children: Any) -> Any:
"""Process type declarator, register type mapping in instance typedef dictionary.""" """Process type declarator, register type mapping in instance typedef dictionary."""
@ -323,7 +331,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
declarators = [children[1][0], *[x[1:][0] for x in children[1][1]]] declarators = [children[1][0], *[x[1:][0] for x in children[1][1]]]
for declarator in declarators: for declarator in declarators:
if declarator[0] == Nodetype.ADECLARATOR: if declarator[0] == Nodetype.ADECLARATOR:
value = (Nodetype.ARRAY, declarator[2][1], base) value = (Nodetype.ARRAY, (base, declarator[2][1]))
else: else:
value = base value = base
self.typedefs[declarator[1][1]] = value self.typedefs[declarator[1][1]] = value
@ -333,8 +341,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
assert len(children) in [4, 6] assert len(children) in [4, 6]
if len(children) == 6: if len(children) == 6:
assert children[4][0] == Nodetype.LITERAL_NUMBER assert children[4][0] == Nodetype.LITERAL_NUMBER
return (Nodetype.SEQUENCE, children[2]) return (Nodetype.SEQUENCE, (children[2], None))
return (Nodetype.SEQUENCE, children[2])
def create_struct_field(self, parts: Any) -> Any: def create_struct_field(self, parts: Any) -> Any:
"""Create struct field and expand typedefs.""" """Create struct field and expand typedefs."""
@ -346,7 +353,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
name = self.typedefs[name[1]] name = self.typedefs[name[1]]
return name return name
yield from ((resolve_name(typename), x[1]) for x in params if x) yield from ((x[1][1], resolve_name(typename)) for x in params if x)
def visit_struct_dcl(self, children: Any) -> Any: def visit_struct_dcl(self, children: Any) -> Any:
"""Process struct declaration.""" """Process struct declaration."""

View File

@ -21,7 +21,7 @@ from .peg import Rule, Visitor, parse_grammar
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, List from typing import Any, List
from .base import Typesdict from .base import Fielddesc, Typesdict
GRAMMAR_MSG = r""" GRAMMAR_MSG = r"""
specification specification
@ -42,7 +42,7 @@ comment
= r'#[^\n]*' = r'#[^\n]*'
const_dcl const_dcl
= type_spec identifier '=' r'[^=][^\n]*' = type_spec identifier '=' integer_literal
field_dcl field_dcl
= type_spec identifier = type_spec identifier
@ -89,7 +89,7 @@ def normalize_msgtype(name: str) -> str:
return str(path) return str(path)
def normalize_fieldtype(typename: str, field: Any, names: List[str]): def normalize_fieldtype(typename: str, field: Fielddesc, names: List[str]) -> Fielddesc:
"""Normalize field typename. """Normalize field typename.
Args: Args:
@ -97,18 +97,20 @@ def normalize_fieldtype(typename: str, field: Any, names: List[str]):
field: Field definition. field: Field definition.
names: Valid message names. names: Valid message names.
Returns:
Normalized fieldtype.
""" """
dct = {Path(name).name: name for name in names} dct = {Path(name).name: name for name in names}
namedef = field[0] ftype, args = field
if namedef[0] == Nodetype.NAME: if ftype == Nodetype.NAME:
name = namedef[1] name = args
elif namedef[0] == Nodetype.SEQUENCE:
name = namedef[1][1]
else: else:
name = namedef[2][1] name = args[0][1]
assert isinstance(name, str)
if name in VisitorMSG.BASETYPES: if name in VisitorMSG.BASETYPES:
inamedef = (Nodetype.BASE, name) ifield = (Nodetype.BASE, name)
else: else:
if name in dct: if name in dct:
name = dct[name] name = dct[name]
@ -118,16 +120,13 @@ def normalize_fieldtype(typename: str, field: Any, names: List[str]):
name = str(Path(typename).parent / name) name = str(Path(typename).parent / name)
elif '/msg/' not in name: elif '/msg/' not in name:
name = str((path := Path(name)).parent / 'msg' / path.name) name = str((path := Path(name)).parent / 'msg' / path.name)
inamedef = (Nodetype.NAME, name) ifield = (Nodetype.NAME, name)
if namedef[0] == Nodetype.NAME: if ftype == Nodetype.NAME:
namedef = inamedef return ifield
elif namedef[0] == Nodetype.SEQUENCE:
namedef = (Nodetype.SEQUENCE, inamedef)
else:
namedef = (Nodetype.ARRAY, namedef[1], inamedef)
field[0] = namedef assert not isinstance(args, str)
return (ftype, (ifield, args[1]))
class VisitorMSG(Visitor): class VisitorMSG(Visitor):
@ -157,6 +156,7 @@ class VisitorMSG(Visitor):
def visit_const_dcl(self, children: Any) -> Any: def visit_const_dcl(self, children: Any) -> Any:
"""Process const declaration, suppress output.""" """Process const declaration, suppress output."""
return Nodetype.CONST, (children[0][1], children[1][1], children[3])
def visit_specification(self, children: Any) -> Typesdict: def visit_specification(self, children: Any) -> Typesdict:
"""Process start symbol.""" """Process start symbol."""
@ -164,8 +164,10 @@ class VisitorMSG(Visitor):
typedict = dict(typelist) typedict = dict(typelist)
names = list(typedict.keys()) names = list(typedict.keys())
for name, fields in typedict.items(): for name, fields in typedict.items():
for field in fields: consts = [(x[1][1], x[1][0], x[1][2]) for x in fields if x[0] == Nodetype.CONST]
normalize_fieldtype(name, field, names) fields = [x for x in fields if x[0] != Nodetype.CONST]
fields = [(field[1][1], normalize_fieldtype(name, field[0], names)) for field in fields]
typedict[name] = consts, fields
return typedict return typedict
def visit_msgdef(self, children: Any) -> Any: def visit_msgdef(self, children: Any) -> Any:
@ -180,8 +182,8 @@ class VisitorMSG(Visitor):
"""Process array type specifier.""" """Process array type specifier."""
length = children[1][1] length = children[1][1]
if length: if length:
return (Nodetype.ARRAY, int(length[0]), children[0]) return Nodetype.ARRAY, (children[0], length[0])
return (Nodetype.SEQUENCE, children[0]) return Nodetype.SEQUENCE, (children[0], None)
def visit_simple_type_spec(self, children: Any) -> Any: def visit_simple_type_spec(self, children: Any) -> Any:
"""Process simple type specifier.""" """Process simple type specifier."""
@ -204,6 +206,10 @@ class VisitorMSG(Visitor):
"""Process identifier.""" """Process identifier."""
return (Nodetype.NAME, children) return (Nodetype.NAME, children)
def visit_integer_literal(self, children: Any) -> Any:
"""Process integer literal."""
return int(children)
def get_types_from_msg(text: str, name: str) -> Typesdict: def get_types_from_msg(text: str, name: str) -> Typesdict:
"""Get type from msg message definition. """Get type from msg message definition.

View File

@ -4,7 +4,6 @@
from __future__ import annotations from __future__ import annotations
import json
import re import re
import sys import sys
from importlib.util import module_from_spec, spec_from_loader from importlib.util import module_from_spec, spec_from_loader
@ -37,7 +36,7 @@ def get_typehint(desc: tuple) -> str:
if desc[0] == Nodetype.NAME: if desc[0] == Nodetype.NAME:
return desc[1].replace('/', '__') return desc[1].replace('/', '__')
sub = desc[2 if desc[0] == Nodetype.ARRAY else 1] sub = desc[1][0]
if INTLIKE.match(sub[1]): if INTLIKE.match(sub[1]):
typ = 'bool8' if sub[1] == 'bool' else sub[1] typ = 'bool8' if sub[1] == 'bool' else sub[1]
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]' return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
@ -71,21 +70,27 @@ def generate_python_code(typs: Typesdict) -> str:
'from typing import TYPE_CHECKING', 'from typing import TYPE_CHECKING',
'', '',
'if TYPE_CHECKING:', 'if TYPE_CHECKING:',
' from typing import Any', ' from typing import Any, ClassVar',
'', '',
' import numpy', ' import numpy',
'', '',
' from .base import Typesdict',
'', '',
] ]
for name, fields in typs.items(): for name, (consts, fields) in typs.items():
pyname = name.replace('/', '__') pyname = name.replace('/', '__')
lines += [ lines += [
'@dataclass', '@dataclass',
f'class {pyname}:', f'class {pyname}:',
f' """Class for {name}."""', f' """Class for {name}."""',
'', '',
*[f' {fname[1]}: {get_typehint(desc)}' for desc, fname in fields], *[f' {fname}: {get_typehint(desc)}' for fname, desc in fields],
*[
f' {fname}: ClassVar[{get_typehint((1, ftype))}] = {fvalue}'
for fname, ftype, fvalue in consts
],
f' __msgtype__: ClassVar[str] = {name!r}',
] ]
lines += [ lines += [
@ -93,16 +98,20 @@ def generate_python_code(typs: Typesdict) -> str:
'', '',
] ]
lines += ['FIELDDEFS = {'] def get_ftype(ftype: tuple) -> tuple:
for name, fields in typs.items(): if ftype[0] <= 2:
return int(ftype[0]), ftype[1]
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
lines += ['FIELDDEFS: Typesdict = {']
for name, (consts, fields) in typs.items():
pyname = name.replace('/', '__') pyname = name.replace('/', '__')
lines += [ lines += [
f' \'{name}\': [', f' \'{name}\': ([',
*[ *[f' ({fname!r}, {ftype!r}, {fvalue!r}),' for fname, ftype, fvalue in consts],
f' ({repr(fname[1])}, {json.loads(json.dumps(ftype))}),' ' ], [',
for ftype, fname in fields *[f' ({fname!r}, {get_ftype(ftype)!r}),' for fname, ftype in fields],
], ' ]),',
' ],',
] ]
lines += [ lines += [
'}', '}',
@ -127,15 +136,16 @@ def register_types(typs: Typesdict) -> None:
module = module_from_spec(spec) module = module_from_spec(spec)
sys.modules[name] = module sys.modules[name] = module
exec(code, module.__dict__) # pylint: disable=exec-used exec(code, module.__dict__) # pylint: disable=exec-used
fielddefs = module.FIELDDEFS # type: ignore fielddefs: Typesdict = module.FIELDDEFS # type: ignore
for name, fields in fielddefs.items(): for name, (_, fields) in fielddefs.items():
if name == 'std_msgs/msg/Header': if name == 'std_msgs/msg/Header':
continue continue
if have := types.FIELDDEFS.get(name): if have := types.FIELDDEFS.get(name):
have = [(x[0].lower(), x[1]) for x in have] _, have_fields = have
have_fields = [(x[0].lower(), x[1]) for x in have_fields]
fields = [(x[0].lower(), x[1]) for x in fields] fields = [(x[0].lower(), x[1]) for x in fields]
if have != fields: if have_fields != fields:
raise TypesysError(f'Type {name!r} is already present with different definition.') raise TypesysError(f'Type {name!r} is already present with different definition.')
for name in fielddefs.keys() - types.FIELDDEFS.keys(): for name in fielddefs.keys() - types.FIELDDEFS.keys():

File diff suppressed because it is too large Load Diff

View File

@ -161,12 +161,13 @@ def deserialize_message(rawdata: bytes, bmap: BasetypeMap, pos: int, msgdef: Msg
values.append(num) values.append(num)
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
arr, pos = deserialize_array(rawdata, bmap, pos, *desc.args) subdesc, length = desc.args
arr, pos = deserialize_array(rawdata, bmap, pos, length, subdesc)
values.append(arr) values.append(arr)
elif desc.valtype == Valtype.SEQUENCE: elif desc.valtype == Valtype.SEQUENCE:
size, pos = deserialize_number(rawdata, bmap, pos, 'int32') size, pos = deserialize_number(rawdata, bmap, pos, 'int32')
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args) arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args[0])
values.append(arr) values.append(arr)
return msgdef.cls(*values), pos return msgdef.cls(*values), pos
@ -323,12 +324,12 @@ def serialize_message(
pos = serialize_number(rawdata, bmap, pos, desc.args, val) pos = serialize_number(rawdata, bmap, pos, desc.args, val)
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
pos = serialize_array(rawdata, bmap, pos, desc.args[1], val) pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
elif desc.valtype == Valtype.SEQUENCE: elif desc.valtype == Valtype.SEQUENCE:
size = len(val) size = len(val)
pos = serialize_number(rawdata, bmap, pos, 'int32', size) pos = serialize_number(rawdata, bmap, pos, 'int32', size)
pos = serialize_array(rawdata, bmap, pos, desc.args, val) pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
return pos return pos
@ -397,14 +398,15 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
size += isize size += isize
elif desc.valtype == Valtype.ARRAY: elif desc.valtype == Valtype.ARRAY:
if len(val) != desc.args[0]: subdesc, length = desc.args
raise SerdeError(f'Unexpected array length: {len(val)} != {desc.args[0]}.') if len(val) != length:
size = get_array_size(desc.args[1], val, size) raise SerdeError(f'Unexpected array length: {len(val)} != {length}.')
size = get_array_size(subdesc, val, size)
elif desc.valtype == Valtype.SEQUENCE: elif desc.valtype == Valtype.SEQUENCE:
size = (size + 4 - 1) & -4 size = (size + 4 - 1) & -4
size += 4 size += 4
size = get_array_size(desc.args, val, size) size = get_array_size(desc.args[0], val, size)
return size return size

View File

@ -35,6 +35,7 @@ time time
================================================================================ ================================================================================
MSG: test_msgs/Other MSG: test_msgs/Other
uint64[3] Header uint64[3] Header
uint32 static = 42
""" """
RELSIBLING_MSG = """ RELSIBLING_MSG = """
@ -81,6 +82,11 @@ module test_msgs {
typedef test_msgs::msg::Bar Bar; typedef test_msgs::msg::Bar Bar;
typedef double d4[4]; typedef double d4[4];
module Foo_Constants {
const int32 FOO = 32;
const int64 BAR = 64;
};
@comment(type="text", text="ignore") @comment(type="text", text="ignore")
struct Foo { struct Foo {
std_msgs::msg::Header header; std_msgs::msg::Header header;
@ -102,17 +108,18 @@ def test_parse_msg():
get_types_from_msg('', 'test_msgs/msg/Foo') get_types_from_msg('', 'test_msgs/msg/Foo')
ret = get_types_from_msg(MSG, 'test_msgs/msg/Foo') ret = get_types_from_msg(MSG, 'test_msgs/msg/Foo')
assert 'test_msgs/msg/Foo' in ret assert 'test_msgs/msg/Foo' in ret
fields = ret['test_msgs/msg/Foo'] consts, fields = ret['test_msgs/msg/Foo']
assert fields[0][0][1] == 'std_msgs/msg/Header' assert consts == [('global', 'int32', 42)]
assert fields[0][1][1] == 'header' assert fields[0][0] == 'header'
assert fields[1][0][1] == 'std_msgs/msg/Bool' assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][1][1] == 'bool' assert fields[1][0] == 'bool'
assert fields[2][0][1] == 'test_msgs/msg/Bar' assert fields[1][1][1] == 'std_msgs/msg/Bool'
assert fields[2][1][1] == 'sibling' assert fields[2][0] == 'sibling'
assert fields[3][0][0] == Nodetype.BASE assert fields[2][1][1] == 'test_msgs/msg/Bar'
assert fields[4][0][0] == Nodetype.SEQUENCE assert fields[3][1][0] == Nodetype.BASE
assert fields[5][0][0] == Nodetype.SEQUENCE assert fields[4][1][0] == Nodetype.SEQUENCE
assert fields[6][0][0] == Nodetype.ARRAY assert fields[5][1][0] == Nodetype.SEQUENCE
assert fields[6][1][0] == Nodetype.ARRAY
def test_parse_multi_msg(): def test_parse_multi_msg():
@ -122,20 +129,23 @@ def test_parse_multi_msg():
assert 'test_msgs/msg/Foo' in ret assert 'test_msgs/msg/Foo' in ret
assert 'std_msgs/msg/Header' in ret assert 'std_msgs/msg/Header' in ret
assert 'test_msgs/msg/Other' in ret assert 'test_msgs/msg/Other' in ret
assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header' fields = ret['test_msgs/msg/Foo'][1]
assert ret['test_msgs/msg/Foo'][1][0][1] == 'uint8' assert fields[0][1][1] == 'std_msgs/msg/Header'
assert ret['test_msgs/msg/Foo'][2][0][1] == 'uint8' assert fields[1][1][1] == 'uint8'
assert fields[2][1][1] == 'uint8'
consts = ret['test_msgs/msg/Other'][0]
assert consts == [('static', 'uint32', 42)]
def test_parse_relative_siblings_msg(): def test_parse_relative_siblings_msg():
"""Test relative siblings with msg parser.""" """Test relative siblings with msg parser."""
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo') ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header' assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
assert ret['test_msgs/msg/Foo'][1][0][1] == 'test_msgs/msg/Other' assert ret['test_msgs/msg/Foo'][1][1][1][1] == 'test_msgs/msg/Other'
ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo') ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo')
assert ret['rel_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header' assert ret['rel_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
assert ret['rel_msgs/msg/Foo'][1][0][1] == 'rel_msgs/msg/Other' assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
def test_parse_idl(): def test_parse_idl():
@ -145,28 +155,29 @@ def test_parse_idl():
ret = get_types_from_idl(IDL) ret = get_types_from_idl(IDL)
assert 'test_msgs/msg/Foo' in ret assert 'test_msgs/msg/Foo' in ret
fields = ret['test_msgs/msg/Foo'] consts, fields = ret['test_msgs/msg/Foo']
assert fields[0][0][1] == 'std_msgs/msg/Header' assert consts == [('FOO', 'int32', 32), ('BAR', 'int64', 64)]
assert fields[0][1][1] == 'header' assert fields[0][0] == 'header'
assert fields[1][0][1] == 'std_msgs/msg/Bool' assert fields[0][1][1] == 'std_msgs/msg/Header'
assert fields[1][1][1] == 'bool' assert fields[1][0] == 'bool'
assert fields[2][0][1] == 'test_msgs/msg/Bar' assert fields[1][1][1] == 'std_msgs/msg/Bool'
assert fields[2][1][1] == 'sibling' assert fields[2][0] == 'sibling'
assert fields[3][0][0] == Nodetype.BASE assert fields[2][1][1] == 'test_msgs/msg/Bar'
assert fields[4][0][0] == Nodetype.SEQUENCE assert fields[3][1][0] == Nodetype.BASE
assert fields[5][0][0] == Nodetype.SEQUENCE assert fields[4][1][0] == Nodetype.SEQUENCE
assert fields[6][0][0] == Nodetype.ARRAY assert fields[5][1][0] == Nodetype.SEQUENCE
assert fields[6][1][0] == Nodetype.ARRAY
def test_register_types(): def test_register_types():
"""Test type registeration.""" """Test type registeration."""
assert 'foo' not in FIELDDEFS assert 'foo' not in FIELDDEFS
register_types({}) register_types({})
register_types({'foo': [[(1, 'bool'), (2, 'b')]]}) register_types({'foo': [[], [('b', (1, 'bool'))]]})
assert 'foo' in FIELDDEFS assert 'foo' in FIELDDEFS
register_types({'std_msgs/msg/Header': []}) register_types({'std_msgs/msg/Header': [[], []]})
assert len(FIELDDEFS['std_msgs/msg/Header']) == 2 assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
with pytest.raises(TypesysError, match='different definition'): with pytest.raises(TypesysError, match='different definition'):
register_types({'foo': [[(1, 'bool'), (2, 'x')]]}) register_types({'foo': [[], [('x', (1, 'bool'))]]})