Add const fields to type representations
This commit is contained in:
parent
fa57b16765
commit
03b4d7e5c7
@ -68,19 +68,19 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
|
||||
size += SIZEMAP[desc.args]
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
subdesc = desc.args[1]
|
||||
subdesc, length = desc.args
|
||||
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
lines.append(f' val = message.{fieldname}')
|
||||
for idx in range(desc.args[0]):
|
||||
for idx in range(length):
|
||||
lines.append(' pos = (pos + 4 - 1) & -4')
|
||||
lines.append(f' pos += 4 + len(val[{idx}].encode()) + 1')
|
||||
aligned = 1
|
||||
is_stat = False
|
||||
else:
|
||||
lines.append(f' pos += {desc.args[0] * SIZEMAP[subdesc.args]}')
|
||||
size += desc.args[0] * SIZEMAP[subdesc.args]
|
||||
lines.append(f' pos += {length * SIZEMAP[subdesc.args]}')
|
||||
size += length * SIZEMAP[subdesc.args]
|
||||
|
||||
else:
|
||||
assert subdesc.valtype == Valtype.MESSAGE
|
||||
@ -88,7 +88,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
|
||||
anext_after = align_after(subdesc)
|
||||
|
||||
if subdesc.args.size_cdr:
|
||||
for _ in range(desc.args[0]):
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
size = (size + anext - 1) & -anext
|
||||
@ -97,7 +97,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
|
||||
else:
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").getsize_cdr')
|
||||
lines.append(f' val = message.{fieldname}')
|
||||
for idx in range(desc.args[0]):
|
||||
for idx in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' pos = func(pos, val[{idx}])')
|
||||
@ -107,7 +107,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
|
||||
assert desc.valtype == Valtype.SEQUENCE
|
||||
lines.append(' pos += 4')
|
||||
aligned = 4
|
||||
subdesc = desc.args
|
||||
subdesc = desc.args[0]
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
lines.append(f' for val in message.{fieldname}:')
|
||||
@ -211,13 +211,13 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
aligned = SIZEMAP[desc.args]
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
subdesc = desc.args[1]
|
||||
lines.append(f' if len(val) != {desc.args[0]}:')
|
||||
subdesc, length = desc.args
|
||||
lines.append(f' if len(val) != {length}:')
|
||||
lines.append(' raise SerdeError(\'Unexpected array length\')')
|
||||
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
for idx in range(desc.args[0]):
|
||||
for idx in range(length):
|
||||
lines.append(f' bval = memoryview(val[{idx}].encode())')
|
||||
lines.append(' length = len(bval) + 1')
|
||||
lines.append(' pos = (pos + 4 - 1) & -4')
|
||||
@ -229,7 +229,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
else:
|
||||
if (endianess == 'le') != (sys.byteorder == 'little'):
|
||||
lines.append(' val = val.byteswap()')
|
||||
size = desc.args[0] * SIZEMAP[subdesc.args]
|
||||
size = length * SIZEMAP[subdesc.args]
|
||||
lines.append(f' rawdata[pos:pos + {size}] = val.view(numpy.uint8)')
|
||||
lines.append(f' pos += {size}')
|
||||
|
||||
@ -240,7 +240,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
lines.append(
|
||||
f' func = get_msgdef("{subdesc.args.name}").serialize_cdr_{endianess}',
|
||||
)
|
||||
for idx in range(desc.args[0]):
|
||||
for idx in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' pos = func(rawdata, pos, val[{idx}])')
|
||||
@ -250,7 +250,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
lines.append(f' pack_int32_{endianess}(rawdata, pos, len(val))')
|
||||
lines.append(' pos += 4')
|
||||
aligned = 4
|
||||
subdesc = desc.args
|
||||
subdesc = desc.args[0]
|
||||
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
@ -350,11 +350,11 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
aligned = SIZEMAP[desc.args]
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
subdesc = desc.args[1]
|
||||
subdesc, length = desc.args
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
lines.append(' value = []')
|
||||
for idx in range(desc.args[0]):
|
||||
for idx in range(length):
|
||||
if idx:
|
||||
lines.append(' pos = (pos + 4 - 1) & -4')
|
||||
lines.append(f' length = unpack_int32_{endianess}(rawdata, pos)[0]')
|
||||
@ -365,10 +365,10 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
lines.append(' values.append(value)')
|
||||
aligned = 1
|
||||
else:
|
||||
size = desc.args[0] * SIZEMAP[subdesc.args]
|
||||
size = length * SIZEMAP[subdesc.args]
|
||||
lines.append(
|
||||
f' val = numpy.frombuffer(rawdata, '
|
||||
f'dtype=numpy.{subdesc.args}, count={desc.args[0]}, offset=pos)',
|
||||
f'dtype=numpy.{subdesc.args}, count={length}, offset=pos)',
|
||||
)
|
||||
if (endianess == 'le') != (sys.byteorder == 'little'):
|
||||
lines.append(' val = val.byteswap()')
|
||||
@ -380,7 +380,7 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
anext_after = align_after(subdesc)
|
||||
lines.append(f' msgdef = get_msgdef("{subdesc.args.name}")')
|
||||
lines.append(' value = []')
|
||||
for _ in range(desc.args[0]):
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' pos = (pos + {anext} - 1) & -{anext}')
|
||||
lines.append(f' obj, pos = msgdef.{funcname}(rawdata, pos, msgdef.cls)')
|
||||
@ -393,7 +393,7 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
|
||||
lines.append(f' size = unpack_int32_{endianess}(rawdata, pos)[0]')
|
||||
lines.append(' pos += 4')
|
||||
aligned = 4
|
||||
subdesc = desc.args
|
||||
subdesc = desc.args[0]
|
||||
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
|
||||
@ -10,13 +10,12 @@ from rosbags.typesys import types
|
||||
|
||||
from .cdr import generate_deserialize_cdr, generate_getsize_cdr, generate_serialize_cdr
|
||||
from .ros1 import generate_ros1_to_cdr
|
||||
from .typing import Field, Msgdef
|
||||
from .utils import Descriptor, Valtype
|
||||
from .typing import Descriptor, Field, Msgdef
|
||||
from .utils import Valtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
MSGDEFCACHE: Dict[str, Msgdef] = {}
|
||||
|
||||
|
||||
@ -37,7 +36,7 @@ def get_msgdef(typename: str) -> Msgdef:
|
||||
|
||||
"""
|
||||
if typename not in MSGDEFCACHE:
|
||||
entries = types.FIELDDEFS[typename]
|
||||
entries = types.FIELDDEFS[typename][1]
|
||||
|
||||
def fixup(entry: Any) -> Descriptor:
|
||||
if entry[0] == Valtype.BASE:
|
||||
@ -45,9 +44,9 @@ def get_msgdef(typename: str) -> Msgdef:
|
||||
if entry[0] == Valtype.MESSAGE:
|
||||
return Descriptor(Valtype.MESSAGE, get_msgdef(entry[1]))
|
||||
if entry[0] == Valtype.ARRAY:
|
||||
return Descriptor(Valtype.ARRAY, (entry[1], fixup(entry[2])))
|
||||
return Descriptor(Valtype.ARRAY, (fixup(entry[1][0]), entry[1][1]))
|
||||
if entry[0] == Valtype.SEQUENCE:
|
||||
return Descriptor(Valtype.SEQUENCE, fixup(entry[1]))
|
||||
return Descriptor(Valtype.SEQUENCE, (fixup(entry[1][0]), entry[1][1]))
|
||||
raise SerdeError( # pragma: no cover
|
||||
f'Unknown field type {entry[0]!r} encountered.',
|
||||
)
|
||||
|
||||
@ -22,12 +22,12 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Callable:
|
||||
"""Generate CDR serialization function.
|
||||
"""Generate ROS1 to CDR conversion function.
|
||||
|
||||
Args:
|
||||
fields: Fields of message.
|
||||
typename: Message type name.
|
||||
copy: Generate serialization or sizing function.
|
||||
copy: Generate conversion or sizing function.
|
||||
|
||||
Returns:
|
||||
ROS1 to CDR conversion function.
|
||||
@ -42,17 +42,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
|
||||
'import sys',
|
||||
'import numpy',
|
||||
'from rosbags.serde.messages import SerdeError, get_msgdef',
|
||||
'from rosbags.serde.primitives import pack_bool_le',
|
||||
'from rosbags.serde.primitives import pack_int8_le',
|
||||
'from rosbags.serde.primitives import pack_int16_le',
|
||||
'from rosbags.serde.primitives import pack_int32_le',
|
||||
'from rosbags.serde.primitives import pack_int64_le',
|
||||
'from rosbags.serde.primitives import pack_uint8_le',
|
||||
'from rosbags.serde.primitives import pack_uint16_le',
|
||||
'from rosbags.serde.primitives import pack_uint32_le',
|
||||
'from rosbags.serde.primitives import pack_uint64_le',
|
||||
'from rosbags.serde.primitives import pack_float32_le',
|
||||
'from rosbags.serde.primitives import pack_float64_le',
|
||||
'from rosbags.serde.primitives import unpack_int32_le',
|
||||
f'def {funcname}(input, ipos, output, opos):',
|
||||
]
|
||||
@ -89,11 +79,11 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
|
||||
aligned = size
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
subdesc = desc.args[1]
|
||||
subdesc, length = desc.args
|
||||
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
if subdesc.args == 'string':
|
||||
for _ in range(desc.args[0]):
|
||||
for _ in range(length):
|
||||
lines.append(' opos = (opos + 4 - 1) & -4')
|
||||
lines.append(' length = unpack_int32_le(input, ipos)[0] + 1')
|
||||
if copy:
|
||||
@ -108,7 +98,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
|
||||
lines.append(' opos += length')
|
||||
aligned = 1
|
||||
else:
|
||||
size = desc.args[0] * SIZEMAP[subdesc.args]
|
||||
size = length * SIZEMAP[subdesc.args]
|
||||
if copy:
|
||||
lines.append(f' output[opos:opos + {size}] = input[ipos:ipos + {size}]')
|
||||
lines.append(f' ipos += {size}')
|
||||
@ -120,7 +110,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
|
||||
anext_after = align_after(subdesc)
|
||||
|
||||
lines.append(f' func = get_msgdef("{subdesc.args.name}").{funcname}')
|
||||
for _ in range(desc.args[0]):
|
||||
for _ in range(length):
|
||||
if anext > anext_after:
|
||||
lines.append(f' opos = (opos + {anext} - 1) & -{anext}')
|
||||
lines.append(' ipos, opos = func(input, ipos, output, opos)')
|
||||
@ -132,7 +122,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
|
||||
lines.append(' pack_int32_le(output, opos, size)')
|
||||
lines.append(' ipos += 4')
|
||||
lines.append(' opos += 4')
|
||||
subdesc = desc.args
|
||||
subdesc = desc.args[0]
|
||||
aligned = 4
|
||||
|
||||
if subdesc.valtype == Valtype.BASE:
|
||||
|
||||
@ -7,9 +7,14 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, List # pylint: disable=ungrouped-imports
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from .utils import Descriptor
|
||||
|
||||
class Descriptor(NamedTuple):
|
||||
"""Value type descriptor."""
|
||||
|
||||
valtype: int
|
||||
args: Any
|
||||
|
||||
|
||||
class Field(NamedTuple):
|
||||
|
||||
@ -6,11 +6,13 @@ from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from importlib.util import module_from_spec, spec_from_loader
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List
|
||||
from typing import Dict, List
|
||||
|
||||
from .typing import Descriptor
|
||||
|
||||
|
||||
class Valtype(IntEnum):
|
||||
@ -22,13 +24,6 @@ class Valtype(IntEnum):
|
||||
SEQUENCE = 4
|
||||
|
||||
|
||||
class Descriptor(NamedTuple):
|
||||
"""Value type descriptor."""
|
||||
|
||||
valtype: Valtype
|
||||
args: Any # Union[Descriptor, Msgdef, Tuple[int, Descriptor], str]
|
||||
|
||||
|
||||
SIZEMAP: Dict[str, int] = {
|
||||
'bool': 1,
|
||||
'int8': 1,
|
||||
@ -61,7 +56,7 @@ def align(entry: Descriptor) -> int:
|
||||
if entry.valtype == Valtype.MESSAGE:
|
||||
return align(entry.args.fields[0].descriptor)
|
||||
if entry.valtype == Valtype.ARRAY:
|
||||
return align(entry.args[1])
|
||||
return align(entry.args[0])
|
||||
assert entry.valtype == Valtype.SEQUENCE
|
||||
return 4
|
||||
|
||||
@ -83,9 +78,9 @@ def align_after(entry: Descriptor) -> int:
|
||||
if entry.valtype == Valtype.MESSAGE:
|
||||
return align_after(entry.args.fields[-1].descriptor)
|
||||
if entry.valtype == Valtype.ARRAY:
|
||||
return align_after(entry.args[1])
|
||||
return align_after(entry.args[0])
|
||||
assert entry.valtype == Valtype.SEQUENCE
|
||||
return min([4, align_after(entry.args)])
|
||||
return min([4, align_after(entry.args[0])])
|
||||
|
||||
|
||||
def compile_lines(lines: List[str]) -> ModuleType:
|
||||
|
||||
@ -8,12 +8,14 @@ from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .peg import Visitor
|
||||
|
||||
Fielddefs = List[Tuple[Any, Any]]
|
||||
Typesdict = Dict[str, Fielddefs]
|
||||
Constdefs = List[Tuple[str, str, Any]]
|
||||
Fielddesc = Tuple[int, Union[str, Tuple[Tuple[int, str], Optional[int]]]]
|
||||
Fielddefs = List[Tuple[str, Fielddesc]]
|
||||
Typesdict = Dict[str, Tuple[Constdefs, Fielddefs]]
|
||||
|
||||
|
||||
class TypesysError(Exception):
|
||||
|
||||
@ -261,8 +261,20 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
def visit_specification(self, children: Any) -> Typesdict:
|
||||
"""Process start symbol, return only children of modules."""
|
||||
children = [x[0] for x in children if x is not None]
|
||||
modules = [y for t, x in children if t == Nodetype.MODULE for y in x]
|
||||
return {x[1]: x[2] for x in modules if x[0] == Nodetype.STRUCT}
|
||||
structs = {}
|
||||
consts: dict[str, list[tuple[str, str, Any]]] = {}
|
||||
for item in children:
|
||||
if item[0] != Nodetype.MODULE:
|
||||
continue
|
||||
for subitem in item[1]:
|
||||
if subitem[0] == Nodetype.STRUCT:
|
||||
structs[subitem[1]] = subitem[2]
|
||||
elif subitem[0] == Nodetype.CONST and '_Constants/' in subitem[1][1]:
|
||||
structname, varname = subitem[1][1].split('_Constants/')
|
||||
if structname not in consts:
|
||||
consts[structname] = []
|
||||
consts[structname].append((varname, subitem[1][0], subitem[1][2]))
|
||||
return {k: (consts.get(k, []), v) for k, v in structs.items()}
|
||||
|
||||
def visit_comment(self, children: Any) -> Any:
|
||||
"""Process comment, suppress output."""
|
||||
@ -273,12 +285,6 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
def visit_include(self, children: Any) -> Any:
|
||||
"""Process include, suppress output."""
|
||||
|
||||
def visit_type_dcl(self, children: Any) -> Any:
|
||||
"""Process typedef, pass structs, suppress otherwise."""
|
||||
if children[0] == Nodetype.STRUCT:
|
||||
return children
|
||||
return None
|
||||
|
||||
def visit_module_dcl(self, children: Any) -> Any:
|
||||
"""Process module declaration."""
|
||||
assert len(children) == 6
|
||||
@ -288,7 +294,6 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
children = children[4]
|
||||
consts = []
|
||||
structs = []
|
||||
modules = []
|
||||
for item in children:
|
||||
if not item or item[0] is None:
|
||||
continue
|
||||
@ -299,20 +304,23 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
structs.append(item)
|
||||
else:
|
||||
assert item[0] == Nodetype.MODULE
|
||||
modules.append(item)
|
||||
consts += [x for x in item[1] if x[0] == Nodetype.CONST]
|
||||
structs += [x for x in item[1] if x[0] == Nodetype.STRUCT]
|
||||
|
||||
for _, module in modules:
|
||||
consts += [x for x in module if x[0] == Nodetype.CONST]
|
||||
structs += [x for x in module if x[0] == Nodetype.STRUCT]
|
||||
|
||||
consts = [(x[0], f'{name}/{x[1][0]}', *x[1][1:]) for x in consts]
|
||||
consts = [(x[0], (x[1][0], f'{name}/{x[1][1]}', x[1][2])) for x in consts]
|
||||
structs = [(x[0], f'{name}/{x[1]}', *x[2:]) for x in structs]
|
||||
|
||||
return (Nodetype.MODULE, consts + structs)
|
||||
|
||||
def visit_const_dcl(self, children: Any) -> Any:
|
||||
"""Process const declaration."""
|
||||
return (Nodetype.CONST, (children[1][1], *children[2:]))
|
||||
return (Nodetype.CONST, (children[1][1], children[2][1], children[4][1]))
|
||||
|
||||
def visit_type_dcl(self, children: Any) -> Any:
|
||||
"""Process type, pass structs, suppress otherwise."""
|
||||
if children[0] == Nodetype.STRUCT:
|
||||
return children
|
||||
return None
|
||||
|
||||
def visit_type_declarator(self, children: Any) -> Any:
|
||||
"""Process type declarator, register type mapping in instance typedef dictionary."""
|
||||
@ -323,7 +331,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
declarators = [children[1][0], *[x[1:][0] for x in children[1][1]]]
|
||||
for declarator in declarators:
|
||||
if declarator[0] == Nodetype.ADECLARATOR:
|
||||
value = (Nodetype.ARRAY, declarator[2][1], base)
|
||||
value = (Nodetype.ARRAY, (base, declarator[2][1]))
|
||||
else:
|
||||
value = base
|
||||
self.typedefs[declarator[1][1]] = value
|
||||
@ -333,8 +341,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
assert len(children) in [4, 6]
|
||||
if len(children) == 6:
|
||||
assert children[4][0] == Nodetype.LITERAL_NUMBER
|
||||
return (Nodetype.SEQUENCE, children[2])
|
||||
return (Nodetype.SEQUENCE, children[2])
|
||||
return (Nodetype.SEQUENCE, (children[2], None))
|
||||
|
||||
def create_struct_field(self, parts: Any) -> Any:
|
||||
"""Create struct field and expand typedefs."""
|
||||
@ -346,7 +353,7 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods
|
||||
name = self.typedefs[name[1]]
|
||||
return name
|
||||
|
||||
yield from ((resolve_name(typename), x[1]) for x in params if x)
|
||||
yield from ((x[1][1], resolve_name(typename)) for x in params if x)
|
||||
|
||||
def visit_struct_dcl(self, children: Any) -> Any:
|
||||
"""Process struct declaration."""
|
||||
|
||||
@ -21,7 +21,7 @@ from .peg import Rule, Visitor, parse_grammar
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, List
|
||||
|
||||
from .base import Typesdict
|
||||
from .base import Fielddesc, Typesdict
|
||||
|
||||
GRAMMAR_MSG = r"""
|
||||
specification
|
||||
@ -42,7 +42,7 @@ comment
|
||||
= r'#[^\n]*'
|
||||
|
||||
const_dcl
|
||||
= type_spec identifier '=' r'[^=][^\n]*'
|
||||
= type_spec identifier '=' integer_literal
|
||||
|
||||
field_dcl
|
||||
= type_spec identifier
|
||||
@ -89,7 +89,7 @@ def normalize_msgtype(name: str) -> str:
|
||||
return str(path)
|
||||
|
||||
|
||||
def normalize_fieldtype(typename: str, field: Any, names: List[str]):
|
||||
def normalize_fieldtype(typename: str, field: Fielddesc, names: List[str]) -> Fielddesc:
|
||||
"""Normalize field typename.
|
||||
|
||||
Args:
|
||||
@ -97,18 +97,20 @@ def normalize_fieldtype(typename: str, field: Any, names: List[str]):
|
||||
field: Field definition.
|
||||
names: Valid message names.
|
||||
|
||||
Returns:
|
||||
Normalized fieldtype.
|
||||
|
||||
"""
|
||||
dct = {Path(name).name: name for name in names}
|
||||
namedef = field[0]
|
||||
if namedef[0] == Nodetype.NAME:
|
||||
name = namedef[1]
|
||||
elif namedef[0] == Nodetype.SEQUENCE:
|
||||
name = namedef[1][1]
|
||||
ftype, args = field
|
||||
if ftype == Nodetype.NAME:
|
||||
name = args
|
||||
else:
|
||||
name = namedef[2][1]
|
||||
name = args[0][1]
|
||||
|
||||
assert isinstance(name, str)
|
||||
if name in VisitorMSG.BASETYPES:
|
||||
inamedef = (Nodetype.BASE, name)
|
||||
ifield = (Nodetype.BASE, name)
|
||||
else:
|
||||
if name in dct:
|
||||
name = dct[name]
|
||||
@ -118,16 +120,13 @@ def normalize_fieldtype(typename: str, field: Any, names: List[str]):
|
||||
name = str(Path(typename).parent / name)
|
||||
elif '/msg/' not in name:
|
||||
name = str((path := Path(name)).parent / 'msg' / path.name)
|
||||
inamedef = (Nodetype.NAME, name)
|
||||
ifield = (Nodetype.NAME, name)
|
||||
|
||||
if namedef[0] == Nodetype.NAME:
|
||||
namedef = inamedef
|
||||
elif namedef[0] == Nodetype.SEQUENCE:
|
||||
namedef = (Nodetype.SEQUENCE, inamedef)
|
||||
else:
|
||||
namedef = (Nodetype.ARRAY, namedef[1], inamedef)
|
||||
if ftype == Nodetype.NAME:
|
||||
return ifield
|
||||
|
||||
field[0] = namedef
|
||||
assert not isinstance(args, str)
|
||||
return (ftype, (ifield, args[1]))
|
||||
|
||||
|
||||
class VisitorMSG(Visitor):
|
||||
@ -157,6 +156,7 @@ class VisitorMSG(Visitor):
|
||||
|
||||
def visit_const_dcl(self, children: Any) -> Any:
|
||||
"""Process const declaration, suppress output."""
|
||||
return Nodetype.CONST, (children[0][1], children[1][1], children[3])
|
||||
|
||||
def visit_specification(self, children: Any) -> Typesdict:
|
||||
"""Process start symbol."""
|
||||
@ -164,8 +164,10 @@ class VisitorMSG(Visitor):
|
||||
typedict = dict(typelist)
|
||||
names = list(typedict.keys())
|
||||
for name, fields in typedict.items():
|
||||
for field in fields:
|
||||
normalize_fieldtype(name, field, names)
|
||||
consts = [(x[1][1], x[1][0], x[1][2]) for x in fields if x[0] == Nodetype.CONST]
|
||||
fields = [x for x in fields if x[0] != Nodetype.CONST]
|
||||
fields = [(field[1][1], normalize_fieldtype(name, field[0], names)) for field in fields]
|
||||
typedict[name] = consts, fields
|
||||
return typedict
|
||||
|
||||
def visit_msgdef(self, children: Any) -> Any:
|
||||
@ -180,8 +182,8 @@ class VisitorMSG(Visitor):
|
||||
"""Process array type specifier."""
|
||||
length = children[1][1]
|
||||
if length:
|
||||
return (Nodetype.ARRAY, int(length[0]), children[0])
|
||||
return (Nodetype.SEQUENCE, children[0])
|
||||
return Nodetype.ARRAY, (children[0], length[0])
|
||||
return Nodetype.SEQUENCE, (children[0], None)
|
||||
|
||||
def visit_simple_type_spec(self, children: Any) -> Any:
|
||||
"""Process simple type specifier."""
|
||||
@ -204,6 +206,10 @@ class VisitorMSG(Visitor):
|
||||
"""Process identifier."""
|
||||
return (Nodetype.NAME, children)
|
||||
|
||||
def visit_integer_literal(self, children: Any) -> Any:
|
||||
"""Process integer literal."""
|
||||
return int(children)
|
||||
|
||||
|
||||
def get_types_from_msg(text: str, name: str) -> Typesdict:
|
||||
"""Get type from msg message definition.
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_loader
|
||||
@ -37,7 +36,7 @@ def get_typehint(desc: tuple) -> str:
|
||||
if desc[0] == Nodetype.NAME:
|
||||
return desc[1].replace('/', '__')
|
||||
|
||||
sub = desc[2 if desc[0] == Nodetype.ARRAY else 1]
|
||||
sub = desc[1][0]
|
||||
if INTLIKE.match(sub[1]):
|
||||
typ = 'bool8' if sub[1] == 'bool' else sub[1]
|
||||
return f'numpy.ndarray[Any, numpy.dtype[numpy.{typ}]]'
|
||||
@ -71,21 +70,27 @@ def generate_python_code(typs: Typesdict) -> str:
|
||||
'from typing import TYPE_CHECKING',
|
||||
'',
|
||||
'if TYPE_CHECKING:',
|
||||
' from typing import Any',
|
||||
' from typing import Any, ClassVar',
|
||||
'',
|
||||
' import numpy',
|
||||
'',
|
||||
' from .base import Typesdict',
|
||||
'',
|
||||
]
|
||||
|
||||
for name, fields in typs.items():
|
||||
for name, (consts, fields) in typs.items():
|
||||
pyname = name.replace('/', '__')
|
||||
lines += [
|
||||
'@dataclass',
|
||||
f'class {pyname}:',
|
||||
f' """Class for {name}."""',
|
||||
'',
|
||||
*[f' {fname[1]}: {get_typehint(desc)}' for desc, fname in fields],
|
||||
*[f' {fname}: {get_typehint(desc)}' for fname, desc in fields],
|
||||
*[
|
||||
f' {fname}: ClassVar[{get_typehint((1, ftype))}] = {fvalue}'
|
||||
for fname, ftype, fvalue in consts
|
||||
],
|
||||
f' __msgtype__: ClassVar[str] = {name!r}',
|
||||
]
|
||||
|
||||
lines += [
|
||||
@ -93,16 +98,20 @@ def generate_python_code(typs: Typesdict) -> str:
|
||||
'',
|
||||
]
|
||||
|
||||
lines += ['FIELDDEFS = {']
|
||||
for name, fields in typs.items():
|
||||
def get_ftype(ftype: tuple) -> tuple:
|
||||
if ftype[0] <= 2:
|
||||
return int(ftype[0]), ftype[1]
|
||||
return int(ftype[0]), ((int(ftype[1][0][0]), ftype[1][0][1]), ftype[1][1])
|
||||
|
||||
lines += ['FIELDDEFS: Typesdict = {']
|
||||
for name, (consts, fields) in typs.items():
|
||||
pyname = name.replace('/', '__')
|
||||
lines += [
|
||||
f' \'{name}\': [',
|
||||
*[
|
||||
f' ({repr(fname[1])}, {json.loads(json.dumps(ftype))}),'
|
||||
for ftype, fname in fields
|
||||
],
|
||||
' ],',
|
||||
f' \'{name}\': ([',
|
||||
*[f' ({fname!r}, {ftype!r}, {fvalue!r}),' for fname, ftype, fvalue in consts],
|
||||
' ], [',
|
||||
*[f' ({fname!r}, {get_ftype(ftype)!r}),' for fname, ftype in fields],
|
||||
' ]),',
|
||||
]
|
||||
lines += [
|
||||
'}',
|
||||
@ -127,15 +136,16 @@ def register_types(typs: Typesdict) -> None:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[name] = module
|
||||
exec(code, module.__dict__) # pylint: disable=exec-used
|
||||
fielddefs = module.FIELDDEFS # type: ignore
|
||||
fielddefs: Typesdict = module.FIELDDEFS # type: ignore
|
||||
|
||||
for name, fields in fielddefs.items():
|
||||
for name, (_, fields) in fielddefs.items():
|
||||
if name == 'std_msgs/msg/Header':
|
||||
continue
|
||||
if have := types.FIELDDEFS.get(name):
|
||||
have = [(x[0].lower(), x[1]) for x in have]
|
||||
_, have_fields = have
|
||||
have_fields = [(x[0].lower(), x[1]) for x in have_fields]
|
||||
fields = [(x[0].lower(), x[1]) for x in fields]
|
||||
if have != fields:
|
||||
if have_fields != fields:
|
||||
raise TypesysError(f'Type {name!r} is already present with different definition.')
|
||||
|
||||
for name in fielddefs.keys() - types.FIELDDEFS.keys():
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
18
tests/cdr.py
18
tests/cdr.py
@ -161,12 +161,13 @@ def deserialize_message(rawdata: bytes, bmap: BasetypeMap, pos: int, msgdef: Msg
|
||||
values.append(num)
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
arr, pos = deserialize_array(rawdata, bmap, pos, *desc.args)
|
||||
subdesc, length = desc.args
|
||||
arr, pos = deserialize_array(rawdata, bmap, pos, length, subdesc)
|
||||
values.append(arr)
|
||||
|
||||
elif desc.valtype == Valtype.SEQUENCE:
|
||||
size, pos = deserialize_number(rawdata, bmap, pos, 'int32')
|
||||
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args)
|
||||
arr, pos = deserialize_array(rawdata, bmap, pos, int(size), desc.args[0])
|
||||
values.append(arr)
|
||||
|
||||
return msgdef.cls(*values), pos
|
||||
@ -323,12 +324,12 @@ def serialize_message(
|
||||
pos = serialize_number(rawdata, bmap, pos, desc.args, val)
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
pos = serialize_array(rawdata, bmap, pos, desc.args[1], val)
|
||||
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
|
||||
|
||||
elif desc.valtype == Valtype.SEQUENCE:
|
||||
size = len(val)
|
||||
pos = serialize_number(rawdata, bmap, pos, 'int32', size)
|
||||
pos = serialize_array(rawdata, bmap, pos, desc.args, val)
|
||||
pos = serialize_array(rawdata, bmap, pos, desc.args[0], val)
|
||||
|
||||
return pos
|
||||
|
||||
@ -397,14 +398,15 @@ def get_size(message: Any, msgdef: Msgdef, size: int = 0) -> int:
|
||||
size += isize
|
||||
|
||||
elif desc.valtype == Valtype.ARRAY:
|
||||
if len(val) != desc.args[0]:
|
||||
raise SerdeError(f'Unexpected array length: {len(val)} != {desc.args[0]}.')
|
||||
size = get_array_size(desc.args[1], val, size)
|
||||
subdesc, length = desc.args
|
||||
if len(val) != length:
|
||||
raise SerdeError(f'Unexpected array length: {len(val)} != {length}.')
|
||||
size = get_array_size(subdesc, val, size)
|
||||
|
||||
elif desc.valtype == Valtype.SEQUENCE:
|
||||
size = (size + 4 - 1) & -4
|
||||
size += 4
|
||||
size = get_array_size(desc.args, val, size)
|
||||
size = get_array_size(desc.args[0], val, size)
|
||||
|
||||
return size
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ time time
|
||||
================================================================================
|
||||
MSG: test_msgs/Other
|
||||
uint64[3] Header
|
||||
uint32 static = 42
|
||||
"""
|
||||
|
||||
RELSIBLING_MSG = """
|
||||
@ -81,6 +82,11 @@ module test_msgs {
|
||||
typedef test_msgs::msg::Bar Bar;
|
||||
typedef double d4[4];
|
||||
|
||||
module Foo_Constants {
|
||||
const int32 FOO = 32;
|
||||
const int64 BAR = 64;
|
||||
};
|
||||
|
||||
@comment(type="text", text="ignore")
|
||||
struct Foo {
|
||||
std_msgs::msg::Header header;
|
||||
@ -102,17 +108,18 @@ def test_parse_msg():
|
||||
get_types_from_msg('', 'test_msgs/msg/Foo')
|
||||
ret = get_types_from_msg(MSG, 'test_msgs/msg/Foo')
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
fields = ret['test_msgs/msg/Foo']
|
||||
assert fields[0][0][1] == 'std_msgs/msg/Header'
|
||||
assert fields[0][1][1] == 'header'
|
||||
assert fields[1][0][1] == 'std_msgs/msg/Bool'
|
||||
assert fields[1][1][1] == 'bool'
|
||||
assert fields[2][0][1] == 'test_msgs/msg/Bar'
|
||||
assert fields[2][1][1] == 'sibling'
|
||||
assert fields[3][0][0] == Nodetype.BASE
|
||||
assert fields[4][0][0] == Nodetype.SEQUENCE
|
||||
assert fields[5][0][0] == Nodetype.SEQUENCE
|
||||
assert fields[6][0][0] == Nodetype.ARRAY
|
||||
consts, fields = ret['test_msgs/msg/Foo']
|
||||
assert consts == [('global', 'int32', 42)]
|
||||
assert fields[0][0] == 'header'
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][0] == 'bool'
|
||||
assert fields[1][1][1] == 'std_msgs/msg/Bool'
|
||||
assert fields[2][0] == 'sibling'
|
||||
assert fields[2][1][1] == 'test_msgs/msg/Bar'
|
||||
assert fields[3][1][0] == Nodetype.BASE
|
||||
assert fields[4][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[5][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[6][1][0] == Nodetype.ARRAY
|
||||
|
||||
|
||||
def test_parse_multi_msg():
|
||||
@ -122,20 +129,23 @@ def test_parse_multi_msg():
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
assert 'std_msgs/msg/Header' in ret
|
||||
assert 'test_msgs/msg/Other' in ret
|
||||
assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header'
|
||||
assert ret['test_msgs/msg/Foo'][1][0][1] == 'uint8'
|
||||
assert ret['test_msgs/msg/Foo'][2][0][1] == 'uint8'
|
||||
fields = ret['test_msgs/msg/Foo'][1]
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][1][1] == 'uint8'
|
||||
assert fields[2][1][1] == 'uint8'
|
||||
consts = ret['test_msgs/msg/Other'][0]
|
||||
assert consts == [('static', 'uint32', 42)]
|
||||
|
||||
|
||||
def test_parse_relative_siblings_msg():
|
||||
"""Test relative siblings with msg parser."""
|
||||
ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo')
|
||||
assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header'
|
||||
assert ret['test_msgs/msg/Foo'][1][0][1] == 'test_msgs/msg/Other'
|
||||
assert ret['test_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
||||
assert ret['test_msgs/msg/Foo'][1][1][1][1] == 'test_msgs/msg/Other'
|
||||
|
||||
ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo')
|
||||
assert ret['rel_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header'
|
||||
assert ret['rel_msgs/msg/Foo'][1][0][1] == 'rel_msgs/msg/Other'
|
||||
assert ret['rel_msgs/msg/Foo'][1][0][1][1] == 'std_msgs/msg/Header'
|
||||
assert ret['rel_msgs/msg/Foo'][1][1][1][1] == 'rel_msgs/msg/Other'
|
||||
|
||||
|
||||
def test_parse_idl():
|
||||
@ -145,28 +155,29 @@ def test_parse_idl():
|
||||
|
||||
ret = get_types_from_idl(IDL)
|
||||
assert 'test_msgs/msg/Foo' in ret
|
||||
fields = ret['test_msgs/msg/Foo']
|
||||
assert fields[0][0][1] == 'std_msgs/msg/Header'
|
||||
assert fields[0][1][1] == 'header'
|
||||
assert fields[1][0][1] == 'std_msgs/msg/Bool'
|
||||
assert fields[1][1][1] == 'bool'
|
||||
assert fields[2][0][1] == 'test_msgs/msg/Bar'
|
||||
assert fields[2][1][1] == 'sibling'
|
||||
assert fields[3][0][0] == Nodetype.BASE
|
||||
assert fields[4][0][0] == Nodetype.SEQUENCE
|
||||
assert fields[5][0][0] == Nodetype.SEQUENCE
|
||||
assert fields[6][0][0] == Nodetype.ARRAY
|
||||
consts, fields = ret['test_msgs/msg/Foo']
|
||||
assert consts == [('FOO', 'int32', 32), ('BAR', 'int64', 64)]
|
||||
assert fields[0][0] == 'header'
|
||||
assert fields[0][1][1] == 'std_msgs/msg/Header'
|
||||
assert fields[1][0] == 'bool'
|
||||
assert fields[1][1][1] == 'std_msgs/msg/Bool'
|
||||
assert fields[2][0] == 'sibling'
|
||||
assert fields[2][1][1] == 'test_msgs/msg/Bar'
|
||||
assert fields[3][1][0] == Nodetype.BASE
|
||||
assert fields[4][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[5][1][0] == Nodetype.SEQUENCE
|
||||
assert fields[6][1][0] == Nodetype.ARRAY
|
||||
|
||||
|
||||
def test_register_types():
|
||||
"""Test type registeration."""
|
||||
assert 'foo' not in FIELDDEFS
|
||||
register_types({})
|
||||
register_types({'foo': [[(1, 'bool'), (2, 'b')]]})
|
||||
register_types({'foo': [[], [('b', (1, 'bool'))]]})
|
||||
assert 'foo' in FIELDDEFS
|
||||
|
||||
register_types({'std_msgs/msg/Header': []})
|
||||
assert len(FIELDDEFS['std_msgs/msg/Header']) == 2
|
||||
register_types({'std_msgs/msg/Header': [[], []]})
|
||||
assert len(FIELDDEFS['std_msgs/msg/Header'][1]) == 2
|
||||
|
||||
with pytest.raises(TypesysError, match='different definition'):
|
||||
register_types({'foo': [[(1, 'bool'), (2, 'x')]]})
|
||||
register_types({'foo': [[], [('x', (1, 'bool'))]]})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user