From e88241074e334a5ccdfedacb233ee95dda9de4b9 Mon Sep 17 00:00:00 2001 From: Marko Durkovic Date: Mon, 11 Apr 2022 10:46:12 +0200 Subject: [PATCH] Type message definition parsers --- src/rosbags/typesys/idl.py | 279 +++++++++++++++++++++++++++---------- src/rosbags/typesys/msg.py | 115 +++++++++------ src/rosbags/typesys/peg.py | 70 ++++++---- tests/test_parse.py | 11 ++ 4 files changed, 335 insertions(+), 140 deletions(-) diff --git a/src/rosbags/typesys/idl.py b/src/rosbags/typesys/idl.py index 8dc9c6d4..28cfcd85 100644 --- a/src/rosbags/typesys/idl.py +++ b/src/rosbags/typesys/idl.py @@ -14,12 +14,17 @@ from __future__ import annotations from typing import TYPE_CHECKING from .base import Nodetype, parse_message_definition -from .peg import Rule, Visitor, parse_grammar +from .peg import Visitor, parse_grammar if TYPE_CHECKING: - from typing import Any + from typing import Any, Generator, Optional, Tuple, Union - from .base import Typesdict + from .base import Fielddefs, Fielddesc, Typesdict + + StringNode = Tuple[Nodetype, str] + ConstValue = Union[str, bool, int, float] + LiteralMatch = Tuple[str, str] + LiteralNode = Tuple[Nodetype, ConstValue] GRAMMAR_IDL = r""" specification @@ -256,47 +261,79 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods def __init__(self) -> None: """Initialize.""" super().__init__() - self.typedefs: dict[str, tuple[Nodetype, tuple[Any, Any]]] = {} + self.typedefs: dict[str, Fielddesc] = {} - def visit_specification(self, children: Any) -> Typesdict: + # yapf: disable + def visit_specification( + self, + children: tuple[ + Optional[ + tuple[ + tuple[ + Nodetype, + list[tuple[Nodetype, tuple[str, str, ConstValue]]], + list[tuple[Nodetype, str, Fielddefs]], + ], + LiteralMatch, + ], + ], + ], + ) -> Typesdict: """Process start symbol, return only children of modules.""" - children = [x[0] for x in children if x is not None] - structs = {} - consts: dict[str, list[tuple[str, str, Any]]] = {} + structs: dict[str, Fielddefs] = {} + consts: dict[str, list[tuple[str, str, ConstValue]]] = {} for item in children: - if item[0] != Nodetype.MODULE: + if item is None or item[0][0] != Nodetype.MODULE: continue - for subitem in item[1]: - if subitem[0] == Nodetype.STRUCT: - structs[subitem[1]] = subitem[2] - elif subitem[0] == Nodetype.CONST and '_Constants/' in subitem[1][1]: - structname, varname = subitem[1][1].split('_Constants/') + for csubitem in item[0][1]: + assert csubitem[0] == Nodetype.CONST + if '_Constants/' in csubitem[1][1]: + structname, varname = csubitem[1][1].split('_Constants/') if structname not in consts: consts[structname] = [] - consts[structname].append((varname, subitem[1][0], subitem[1][2])) - return {k: (consts.get(k, []), v) for k, v in structs.items()} + consts[structname].append((varname, csubitem[1][0], csubitem[1][2])) - def visit_comment(self, children: Any) -> Any: + for ssubitem in item[0][2]: + assert ssubitem[0] == Nodetype.STRUCT + structs[ssubitem[1]] = ssubitem[2] + if ssubitem[1] not in consts: + consts[ssubitem[1]] = [] + return {k: (consts[k], v) for k, v in structs.items()} + # yapf: enable + + def visit_comment(self, _: str) -> None: """Process comment, suppress output.""" - def visit_macro(self, children: Any) -> Any: + def visit_macro(self, _: Union[LiteralMatch, tuple[LiteralMatch, str]]) -> None: """Process macro, suppress output.""" - def visit_include(self, children: Any) -> Any: + def visit_include( + self, + _: tuple[LiteralMatch, tuple[LiteralMatch, str, LiteralMatch]], + ) -> None: """Process include, suppress output.""" - def visit_module_dcl(self, children: Any) -> Any: + # yapf: disable + def visit_module_dcl( + self, + children: tuple[tuple[()], LiteralMatch, StringNode, LiteralMatch, Any, LiteralMatch], + ) -> tuple[ + Nodetype, + list[tuple[Nodetype, tuple[str, str, ConstValue]]], + list[tuple[Nodetype, str, Fielddefs]], + ]: """Process module declaration.""" assert len(children) == 6 assert children[2][0] == Nodetype.NAME name = children[2][1] - children = children[4] + definitions = children[4] consts = [] structs = [] - for item in children: - if not item or item[0] is None: + for item in definitions: + if item is None or item[0] is None: continue + assert item[1] == ('LITERAL', ';') item = item[0] if item[0] == Nodetype.CONST: consts.append(item) @@ -304,58 +341,98 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods structs.append(item) else: assert item[0] == Nodetype.MODULE - consts += [x for x in item[1] if x[0] == Nodetype.CONST] - structs += [x for x in item[1] if x[0] == Nodetype.STRUCT] + consts += item[1] + structs += item[2] - consts = [(x[0], (x[1][0], f'{name}/{x[1][1]}', x[1][2])) for x in consts] - structs = [(x[0], f'{name}/{x[1]}', *x[2:]) for x in structs] + consts = [(ityp, (typ, f'{name}/{subname}', val)) for ityp, (typ, subname, val) in consts] + structs = [(typ, f'{name}/{subname}', *rest) for typ, subname, *rest in structs] - return (Nodetype.MODULE, consts + structs) + return (Nodetype.MODULE, consts, structs) + # yapf: enable - def visit_const_dcl(self, children: Any) -> Any: + def visit_const_dcl( + self, + children: tuple[LiteralMatch, StringNode, StringNode, LiteralMatch, LiteralNode], + ) -> tuple[Nodetype, tuple[str, str, ConstValue]]: """Process const declaration.""" return (Nodetype.CONST, (children[1][1], children[2][1], children[4][1])) - def visit_type_dcl(self, children: Any) -> Any: + def visit_type_dcl( + self, + children: Optional[tuple[Nodetype, str, Fielddefs]], + ) -> Optional[tuple[Nodetype, str, Fielddefs]]: """Process type, pass structs, suppress otherwise.""" - if children[0] == Nodetype.STRUCT: - return children - return None + return children if children and children[0] == Nodetype.STRUCT else None - def visit_type_declarator(self, children: Any) -> Any: + def visit_typedef_dcl( + self, + children: tuple[LiteralMatch, tuple[StringNode, tuple[Any, ...]]], + ) -> None: """Process type declarator, register type mapping in instance typedef dictionary.""" assert len(children) == 2 - base, declarators = children - if base[1] in self.typedefs: - base = self.typedefs[base[1]] - declarators = [children[1][0], *[x[1:][0] for x in children[1][1]]] - for declarator in declarators: + dclchildren = children[1] + assert len(dclchildren) == 2 + base: Fielddesc + value: Fielddesc + base = typedef if (typedef := self.typedefs.get(dclchildren[0][1])) else dclchildren[0] + flat = [dclchildren[1][0], *[x[1:][0] for x in dclchildren[1][1]]] + for declarator in flat: if declarator[0] == Nodetype.ADECLARATOR: - value = (Nodetype.ARRAY, (base, declarator[2][1])) + typ, name = base + assert isinstance(typ, Nodetype) + assert isinstance(name, str) + assert isinstance(declarator[2][1], int) + value = (Nodetype.ARRAY, ((typ, name), declarator[2][1])) else: value = base self.typedefs[declarator[1][1]] = value - def visit_sequence_type(self, children: Any) -> Any: + def visit_sequence_type( + self, + children: Union[tuple[LiteralMatch, LiteralMatch, StringNode, LiteralMatch], + tuple[LiteralMatch, LiteralMatch, StringNode, LiteralMatch, LiteralNode, + LiteralMatch]], + ) -> tuple[Nodetype, tuple[StringNode, None]]: """Process sequence type specification.""" - assert len(children) in [4, 6] + assert len(children) in {4, 6} if len(children) == 6: - assert children[4][0] == Nodetype.LITERAL_NUMBER + idx = len(children) - 2 + assert children[idx][0] == Nodetype.LITERAL_NUMBER return (Nodetype.SEQUENCE, (children[2], None)) - def create_struct_field(self, parts: Any) -> Any: + # yapf: disable + def create_struct_field( + self, + parts: tuple[ + tuple[()], + Fielddesc, + tuple[ + tuple[Nodetype, StringNode], + tuple[ + tuple[str, tuple[Nodetype, StringNode]], + ..., + ], + ], + LiteralMatch, + ], + ) -> Generator[tuple[str, Fielddesc], None, None]: """Create struct field and expand typedefs.""" typename, params = parts[1:3] - params = [params[0], *[x[1:][0] for x in params[1]]] + flat = [params[0], *[x[1:][0] for x in params[1]]] - def resolve_name(name: Any) -> Any: + def resolve_name(name: Fielddesc) -> Fielddesc: while name[0] == Nodetype.NAME and name[1] in self.typedefs: + assert isinstance(name[1], str) name = self.typedefs[name[1]] return name - yield from ((x[1][1], resolve_name(typename)) for x in params if x) + yield from ((x[1][1], resolve_name(typename)) for x in flat if x) + # yapf: enable - def visit_struct_dcl(self, children: Any) -> Any: + def visit_struct_dcl( + self, + children: tuple[tuple[()], LiteralMatch, StringNode, LiteralMatch, Any, LiteralMatch], + ) -> tuple[Nodetype, str, Any]: """Process struct declaration.""" assert len(children) == 6 assert children[2][0] == Nodetype.NAME @@ -363,27 +440,51 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods fields = [y for x in children[4] for y in self.create_struct_field(x)] return (Nodetype.STRUCT, children[2][1], fields) - def visit_simple_declarator(self, children: Any) -> Any: + def visit_simple_declarator(self, children: StringNode) -> tuple[Nodetype, StringNode]: """Process simple declarator.""" assert len(children) == 2 return (Nodetype.SDECLARATOR, children) - def visit_array_declarator(self, children: Any) -> Any: + def visit_array_declarator( + self, + children: tuple[StringNode, tuple[tuple[LiteralMatch, LiteralNode, LiteralMatch]]], + ) -> tuple[Nodetype, StringNode, LiteralNode]: """Process array declarator.""" assert len(children) == 2 return (Nodetype.ADECLARATOR, children[0], children[1][0][1]) - def visit_annotation(self, children: Any) -> Any: + # yapf: disable + def visit_annotation( + self, + children: tuple[ + LiteralMatch, + StringNode, + tuple[ + tuple[ + LiteralMatch, + tuple[ + tuple[StringNode, LiteralMatch, LiteralNode], + tuple[ + tuple[LiteralMatch, tuple[StringNode, LiteralMatch, LiteralNode]], + ..., + ], + ], + LiteralMatch, + ], + ], + ], + ) -> tuple[Nodetype, str, list[tuple[StringNode, LiteralNode]]]: """Process annotation.""" assert len(children) == 3 assert children[1][0] == Nodetype.NAME params = children[2][0][1] - params = [ - [z for z in y if z[0] != Rule.LIT] for y in [params[0], *[x[1:][0] for x in params[1]]] - ] - return (Nodetype.ANNOTATION, children[1][1], params) + flat = [params[0], *[x[1:][0] for x in params[1]]] + assert all(len(x) == 3 for x in flat) + retparams = [(x[0], x[2]) for x in flat] + return (Nodetype.ANNOTATION, children[1][1], retparams) + # yapf: enable - def visit_base_type_spec(self, children: Any) -> Any: + def visit_base_type_spec(self, children: str) -> StringNode: """Process base type specifier.""" oname = children name = { @@ -394,26 +495,40 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods }.get(oname, oname) return (Nodetype.BASE, name) - def visit_string_type(self, children: Any) -> Any: + def visit_string_type( + self, + children: Union[StringNode, tuple[LiteralMatch, LiteralMatch, LiteralNode, LiteralMatch]], + ) -> Union[StringNode, tuple[Nodetype, str, LiteralNode]]: """Prrocess string type specifier.""" - assert len(children) in [2, 4] - if len(children) == 4: - return (Nodetype.BASE, 'string', children[2]) - return (Nodetype.BASE, 'string') + if len(children) == 2: + return (Nodetype.BASE, 'string') - def visit_scoped_name(self, children: Any) -> Any: + assert len(children) == 4 + assert isinstance(children[0], tuple) + return (Nodetype.BASE, 'string', children[2]) + + def visit_scoped_name( + self, + children: Union[StringNode, tuple[StringNode, LiteralMatch, StringNode]], + ) -> StringNode: """Process scoped name.""" if len(children) == 2: + assert isinstance(children[1], str) return (Nodetype.NAME, children[1]) assert len(children) == 3 + assert isinstance(children[0], tuple) assert children[1][1] == '::' return (Nodetype.NAME, f'{children[0][1]}/{children[2][1]}') - def visit_identifier(self, children: Any) -> Any: + def visit_identifier(self, children: str) -> StringNode: """Process identifier.""" return (Nodetype.NAME, children) - def visit_expression(self, children: Any) -> Any: + def visit_expression( + self, + children: Union[LiteralNode, tuple[LiteralMatch, LiteralNode], + tuple[LiteralNode, LiteralMatch, LiteralNode]], + ) -> Union[LiteralNode, tuple[Nodetype, str, int], tuple[Nodetype, str, int, int]]: """Process expression, literals are assumed to be integers only.""" if children[0] in [ Nodetype.LITERAL_STRING, @@ -422,46 +537,56 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods Nodetype.LITERAL_CHAR, Nodetype.NAME, ]: - return children + assert isinstance(children[1], (str, bool, int, float)) + return (children[0], children[1]) - assert len(children) in [2, 3] + assert isinstance(children[0], tuple) if len(children) == 3: assert isinstance(children[0][1], int) + assert isinstance(children[1][1], str) assert isinstance(children[2][1], int) - return (Nodetype.EXPRESSION_BINARY, children[1], children[0][1], children[2]) + return (Nodetype.EXPRESSION_BINARY, children[1][1], children[0][1], children[2][1]) assert len(children) == 2 - assert isinstance(children[1][1], int), children - return (Nodetype.EXPRESSION_UNARY, children[0][1], children[1]) + assert isinstance(children[0][1], str) + assert isinstance(children[1], tuple) + assert isinstance(children[1][1], int) + return (Nodetype.EXPRESSION_UNARY, children[0][1], children[1][1]) - def visit_boolean_literal(self, children: Any) -> Any: + def visit_boolean_literal(self, children: str) -> LiteralNode: """Process boolean literal.""" return (Nodetype.LITERAL_BOOLEAN, children[1] == 'TRUE') - def visit_float_literal(self, children: Any) -> Any: + def visit_float_literal(self, children: str) -> LiteralNode: """Process float literal.""" return (Nodetype.LITERAL_NUMBER, float(children)) - def visit_decimal_literal(self, children: Any) -> Any: + def visit_decimal_literal(self, children: str) -> LiteralNode: """Process decimal integer literal.""" return (Nodetype.LITERAL_NUMBER, int(children)) - def visit_octal_literal(self, children: Any) -> Any: + def visit_octal_literal(self, children: str) -> LiteralNode: """Process octal integer literal.""" return (Nodetype.LITERAL_NUMBER, int(children, 8)) - def visit_hexadecimal_literal(self, children: Any) -> Any: + def visit_hexadecimal_literal(self, children: str) -> LiteralNode: """Process hexadecimal integer literal.""" return (Nodetype.LITERAL_NUMBER, int(children, 16)) - def visit_character_literal(self, children: Any) -> Any: + def visit_character_literal( + self, + children: tuple[LiteralMatch, str, LiteralMatch], + ) -> StringNode: """Process char literal.""" return (Nodetype.LITERAL_CHAR, children[1]) - def visit_string_literals(self, children: Any) -> Any: + def visit_string_literals( + self, + children: tuple[tuple[LiteralMatch, str, LiteralMatch], ...], + ) -> StringNode: """Process string literal.""" return ( Nodetype.LITERAL_STRING, - ''.join(y for x in children for y in x if y and y[0] != Rule.LIT), + ''.join(x[1] for x in children), ) diff --git a/src/rosbags/typesys/msg.py b/src/rosbags/typesys/msg.py index 42b583a9..3b9a110f 100644 --- a/src/rosbags/typesys/msg.py +++ b/src/rosbags/typesys/msg.py @@ -21,9 +21,16 @@ from .peg import Rule, Visitor, parse_grammar from .types import FIELDDEFS if TYPE_CHECKING: - from typing import Any + from typing import Optional, Tuple, TypeVar, Union - from .base import Fielddesc, Typesdict + from .base import Constdefs, Fielddefs, Fielddesc, Typesdict + + T = TypeVar('T') + + StringNode = Tuple[Nodetype, str] + ConstValue = Union[str, bool, int, float] + Msgdesc = Tuple[Tuple[StringNode, Tuple[str, str, int], str], ...] + LiteralMatch = Tuple[str, str] GRAMMAR_MSG = r""" specification @@ -44,7 +51,7 @@ comment = r'#[^\n]*' const_dcl - = 'string' identifier r'=(?!={79}\n)' r'[^\n]+' + = 'string' identifier '=' r'(?!={79}\n)[^\n]+' / type_spec identifier '=' float_literal / type_spec identifier '=' integer_literal / type_spec identifier '=' boolean_literal @@ -158,10 +165,7 @@ def normalize_fieldtype(typename: str, field: Fielddesc, names: list[str]) -> Fi """ dct = {Path(name).name: name for name in names} ftype, args = field - if ftype == Nodetype.NAME: - name = args - else: - name = args[0][1] + name = args if ftype == Nodetype.NAME else args[0][1] assert isinstance(name, str) if name in VisitorMSG.BASETYPES: @@ -220,52 +224,82 @@ class VisitorMSG(Visitor): 'string', } - def visit_comment(self, children: Any) -> Any: + def visit_comment(self, _: str) -> None: """Process comment, suppress output.""" - def visit_const_dcl(self, children: Any) -> Any: + def visit_const_dcl( + self, + children: tuple[StringNode, StringNode, LiteralMatch, ConstValue], + ) -> tuple[StringNode, tuple[str, str, ConstValue]]: """Process const declaration, suppress output.""" - typ = children[0][1] - if typ == 'string': + value: Union[str, bool, int, float] + if (typ := children[0][1]) == 'string': + assert isinstance(children[3], str) value = children[3].strip() else: value = children[3] - return Nodetype.CONST, (typ, children[1][1], value) + return (Nodetype.CONST, ''), (typ, children[1][1], value) - def visit_specification(self, children: Any) -> Typesdict: + def visit_specification( + self, + children: tuple[tuple[str, Msgdesc], tuple[tuple[str, tuple[str, Msgdesc]], ...]], + ) -> Typesdict: """Process start symbol.""" typelist = [children[0], *[x[1] for x in children[1]]] typedict = dict(typelist) names = list(typedict.keys()) - for name, fields in typedict.items(): - consts = [(x[1][1], x[1][0], x[1][2]) for x in fields if x[0] == Nodetype.CONST] - fields = [x for x in fields if x[0] != Nodetype.CONST] - fields = [(field[1][1], normalize_fieldtype(name, field[0], names)) for field in fields] - typedict[name] = consts, fields - return typedict + res: Typesdict = {} + for name, items in typedict.items(): + consts: Constdefs = [ + (x[1][1], x[1][0], x[1][2]) for x in items if x[0] == (Nodetype.CONST, '') + ] + fields: Fielddefs = [ + (field[1][1], normalize_fieldtype(name, field[0], names)) + for field in items + if field[0] != (Nodetype.CONST, '') + ] + res[name] = consts, fields + return res - def visit_msgdef(self, children: Any) -> Any: + def visit_msgdef( + self, + children: tuple[str, StringNode, tuple[Optional[T]]], + ) -> tuple[str, tuple[T, ...]]: """Process single message definition.""" assert len(children) == 3 - return normalize_msgtype(children[1][1]), [x for x in children[2] if x is not None] + return normalize_msgtype(children[1][1]), tuple(x for x in children[2] if x is not None) - def visit_msgsep(self, children: Any) -> Any: + def visit_msgsep(self, _: str) -> None: """Process message separator, suppress output.""" - def visit_array_type_spec(self, children: Any) -> Any: + def visit_array_type_spec( + self, + children: tuple[StringNode, tuple[LiteralMatch, tuple[int, ...], LiteralMatch]], + ) -> tuple[Nodetype, tuple[StringNode, Optional[int]]]: """Process array type specifier.""" - length = children[1][1] - if length: + if length := children[1][1]: return Nodetype.ARRAY, (children[0], length[0]) return Nodetype.SEQUENCE, (children[0], None) - def visit_bounded_array_type_spec(self, children: Any) -> Any: + def visit_bounded_array_type_spec( + self, + children: list[StringNode], + ) -> tuple[Nodetype, tuple[StringNode, None]]: """Process bounded array type specifier.""" return Nodetype.SEQUENCE, (children[0], None) - def visit_simple_type_spec(self, children: Any) -> Any: + def visit_simple_type_spec( + self, + children: Union[StringNode, tuple[LiteralMatch, LiteralMatch, int]], + ) -> StringNode: """Process simple type specifier.""" - typespec = children[0][1] if ('LITERAL', '<=') in children else children[1] + if len(children) > 2: + assert (Rule.LIT, '<=') in children + assert isinstance(children[0], tuple) + typespec = children[0][1] + else: + assert isinstance(children[1], str) + typespec = children[1] dct = { 'time': 'builtin_interfaces/msg/Time', 'duration': 'builtin_interfaces/msg/Duration', @@ -274,38 +308,41 @@ class VisitorMSG(Visitor): } return Nodetype.NAME, dct.get(typespec, typespec) - def visit_scoped_name(self, children: Any) -> Any: + def visit_scoped_name( + self, + children: Union[StringNode, tuple[StringNode, LiteralMatch, StringNode]], + ) -> StringNode: """Process scoped name.""" if len(children) == 2: - return children + return children # type: ignore assert len(children) == 3 - return (Nodetype.NAME, '/'.join(x[1] for x in children if x[0] != Rule.LIT)) + return (Nodetype.NAME, '/'.join(x[1] for x in children if x[0] != Rule.LIT)) # type: ignore - def visit_identifier(self, children: Any) -> Any: + def visit_identifier(self, children: str) -> StringNode: """Process identifier.""" return (Nodetype.NAME, children) - def visit_boolean_literal(self, children: Any) -> Any: + def visit_boolean_literal(self, children: str) -> bool: """Process boolean literal.""" - return children.lower() in ['true', '1'] + return children.lower() in {'true', '1'} - def visit_float_literal(self, children: Any) -> Any: + def visit_float_literal(self, children: str) -> float: """Process float literal.""" return float(children) - def visit_decimal_literal(self, children: Any) -> Any: + def visit_decimal_literal(self, children: str) -> int: """Process decimal integer literal.""" return int(children) - def visit_octal_literal(self, children: Any) -> Any: + def visit_octal_literal(self, children: str) -> int: """Process octal integer literal.""" return int(children, 8) - def visit_hexadecimal_literal(self, children: Any) -> Any: + def visit_hexadecimal_literal(self, children: str) -> int: """Process hexadecimal integer literal.""" return int(children, 16) - def visit_string_literal(self, children: Any) -> Any: + def visit_string_literal(self, children: str) -> str: """Process integer literal.""" return children[1] diff --git a/src/rosbags/typesys/peg.py b/src/rosbags/typesys/peg.py index 27d7d5ff..bb2b9dba 100644 --- a/src/rosbags/typesys/peg.py +++ b/src/rosbags/typesys/peg.py @@ -14,7 +14,10 @@ import re from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any, Optional, Pattern, TypeVar, Union + + Tree = Any + T = TypeVar('T') class Rule: @@ -23,7 +26,12 @@ class Rule: LIT = 'LITERAL' WS = re.compile(r'\s+', re.M | re.S) - def __init__(self, value: Any, rules: dict[str, Rule], name: Optional[str] = None): + def __init__( + self, + value: Union[str, Pattern[str], Rule, list[Rule]], + rules: dict[str, Rule], + name: Optional[str] = None, + ): """Initialize. Args: @@ -41,14 +49,9 @@ class Rule: match = self.WS.match(text, pos) return match.span()[1] if match else pos - def make_node(self, data: Any) -> Any: + def make_node(self, data: T) -> Union[T, dict[str, Union[str, T]]]: """Make node for parse tree.""" - if self.name: - return { - 'node': self.name, - 'data': data, - } - return data + return {'node': self.name, 'data': data} if self.name else data def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" @@ -58,7 +61,7 @@ class Rule: class RuleLiteral(Rule): """Rule to match string literal.""" - def __init__(self, value: Any, rules: dict[str, Rule], name: Optional[str] = None): + def __init__(self, value: str, rules: dict[str, Rule], name: Optional[str] = None): """Initialize. Args: @@ -73,6 +76,7 @@ class RuleLiteral(Rule): def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" value = self.value + assert isinstance(value, str) if text[pos:pos + len(value)] == value: npos = pos + len(value) npos = self.skip_ws(text, npos) @@ -83,7 +87,9 @@ class RuleLiteral(Rule): class RuleRegex(Rule): """Rule to match regular expression.""" - def __init__(self, value: Any, rules: dict[str, Rule], name: Optional[str] = None): + value: Pattern[str] + + def __init__(self, value: str, rules: dict[str, Rule], name: Optional[str] = None): """Initialize. Args: @@ -99,7 +105,7 @@ class RuleRegex(Rule): """Apply rule at position.""" match = self.value.match(text, pos) if not match: - return -1, [] + return -1, () npos = self.skip_ws(text, match.span()[1]) return npos, self.make_node(match.group()) @@ -107,6 +113,8 @@ class RuleRegex(Rule): class RuleToken(Rule): """Rule to match token.""" + value: str + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" token = self.rules[self.value] @@ -119,18 +127,22 @@ class RuleToken(Rule): class RuleOneof(Rule): """Rule to match first matching subrule.""" + value: list[Rule] + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" for value in self.value: npos, data = value.parse(text, pos) if npos != -1: return npos, self.make_node(data) - return -1, [] + return -1, () class RuleSequence(Rule): """Rule to match a sequence of subrules.""" + value: list[Rule] + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" data = [] @@ -138,14 +150,16 @@ class RuleSequence(Rule): for value in self.value: npos, node = value.parse(text, npos) if npos == -1: - return -1, [] + return -1, () data.append(node) - return npos, self.make_node(data) + return npos, self.make_node(tuple(data)) class RuleZeroPlus(Rule): """Rule to match zero or more occurences of subrule.""" + value: Rule + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" data: list[Any] = [] @@ -153,7 +167,7 @@ class RuleZeroPlus(Rule): while True: npos, node = self.value.parse(text, lpos) if npos == -1: - return lpos, self.make_node(data) + return lpos, self.make_node(tuple(data)) data.append(node) lpos = npos @@ -161,17 +175,19 @@ class RuleZeroPlus(Rule): class RuleOnePlus(Rule): """Rule to match one or more occurences of subrule.""" + value: Rule + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" npos, node = self.value.parse(text, pos) if npos == -1: - return -1, [] + return -1, () data = [node] lpos = npos while True: npos, node = self.value.parse(text, lpos) if npos == -1: - return lpos, self.make_node(data) + return lpos, self.make_node(tuple(data)) data.append(node) lpos = npos @@ -179,12 +195,14 @@ class RuleOnePlus(Rule): class RuleZeroOne(Rule): """Rule to match zero or one occurence of subrule.""" + value: Rule + def parse(self, text: str, pos: int) -> tuple[int, Any]: """Apply rule at position.""" npos, node = self.value.parse(text, pos) if npos == -1: - return pos, self.make_node([]) - return npos, self.make_node([node]) + return pos, self.make_node(()) + return npos, self.make_node((node,)) class Visitor: # pylint: disable=too-few-public-methods @@ -195,14 +213,17 @@ class Visitor: # pylint: disable=too-few-public-methods def __init__(self) -> None: """Initialize.""" - def visit(self, tree: Any) -> Any: + def visit(self, tree: Tree) -> Tree: """Visit all nodes in parse tree.""" - if isinstance(tree, list): - return [self.visit(x) for x in tree] + if isinstance(tree, tuple): + return tuple(self.visit(x) for x in tree) - if not isinstance(tree, dict): + if isinstance(tree, str): return tree + assert isinstance(tree, dict), tree + assert list(tree.keys()) == ['node', 'data'], tree.keys() + tree['data'] = self.visit(tree['data']) func = getattr(self, f'visit_{tree["node"]}', lambda x: x) return func(tree['data']) @@ -242,6 +263,7 @@ def parse_grammar(grammar: str) -> dict[str, Rule]: while items: tok = items.pop(0) if tok in ['*', '+', '?']: + assert isinstance(stack[-1], Rule) stack[-1] = { '*': RuleZeroPlus, '+': RuleOnePlus, diff --git a/tests/test_parse.py b/tests/test_parse.py index e3453bd6..662aecfd 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -140,6 +140,10 @@ module test_msgs { d4 array; }; }; + + struct Bar { + int i; + }; }; """ @@ -273,6 +277,13 @@ def test_parse_idl() -> None: assert fields[5][1][0] == Nodetype.SEQUENCE assert fields[6][1][0] == Nodetype.ARRAY + assert 'test_msgs/Bar' in ret + consts, fields = ret['test_msgs/Bar'] + assert consts == [] + assert len(fields) == 1 + assert fields[0][0] == 'i' + assert fields[0][1][1] == 'int' + def test_register_types() -> None: """Test type registeration."""