diff --git a/src/rosbags/typesys/idl.py b/src/rosbags/typesys/idl.py index b9542272..259f7ce5 100644 --- a/src/rosbags/typesys/idl.py +++ b/src/rosbags/typesys/idl.py @@ -253,6 +253,11 @@ class VisitorIDL(Visitor): # pylint: disable=too-many-public-methods RULES = parse_grammar(GRAMMAR_IDL) + def __init__(self): + """Initialize.""" + super().__init__() + self.typedefs = {} + def visit_specification(self, children: Any) -> Typesdict: """Process start symbol, return only children of modules.""" children = [x[0] for x in children if x is not None] diff --git a/src/rosbags/typesys/msg.py b/src/rosbags/typesys/msg.py index 7242d12a..59739c0f 100644 --- a/src/rosbags/typesys/msg.py +++ b/src/rosbags/typesys/msg.py @@ -89,10 +89,11 @@ def normalize_msgtype(name: str) -> str: return str(path) -def normalize_fieldtype(field: Any, names: List[str]): +def normalize_fieldtype(typename: str, field: Any, names: List[str]): """Normalize field typename. Args: + typename: Type name of field owner. field: Field definition. names: Valid message names. @@ -111,9 +112,12 @@ def normalize_fieldtype(field: Any, names: List[str]): else: if name in dct: name = dct[name] + elif name == 'Header': + name = 'std_msgs/msg/Header' + elif '/' not in name: + name = str(Path(typename).parent / name) elif '/msg/' not in name: - ptype = Path(name) - name = str(ptype.parent / 'msg' / ptype.name) + name = str((path := Path(name)).parent / 'msg' / path.name) inamedef = (Nodetype.NAME, name) if namedef[0] == Nodetype.NAME: @@ -159,9 +163,9 @@ class VisitorMSG(Visitor): typelist = [children[0], *[x[1] for x in children[1]]] typedict = dict(typelist) names = list(typedict.keys()) - for _, fields in typedict.items(): + for name, fields in typedict.items(): for field in fields: - normalize_fieldtype(field, names) + normalize_fieldtype(name, field, names) return typedict def visit_msgdef(self, children: Any) -> Any: diff --git a/src/rosbags/typesys/peg.py b/src/rosbags/typesys/peg.py index 95d1e731..92cc229b 100644 --- a/src/rosbags/typesys/peg.py +++ b/src/rosbags/typesys/peg.py @@ -58,10 +58,22 @@ class Rule: class RuleLiteral(Rule): """Rule to match string literal.""" + def __init__(self, value: Any, rules: Dict[str, Rule], name: Optional[str] = None): + """Initialize. + + Args: + value: Value of this rule. + rules: Grammar containing all rules. + name: Name of this rule. + + """ + super().__init__(value, rules, name) + self.value = value[1:-1].replace('\\\'', '\'') + def parse(self, text: str, pos: int) -> Tuple[int, Any]: """Apply rule at position.""" - value: str = self.value[1:-1].replace('\\\'', '\'') - if text[pos:].startswith(value): + value = self.value + if text[pos:pos + len(value)] == value: npos = pos + len(value) npos = self.skip_ws(text, npos) return npos, (self.LIT, value) @@ -71,10 +83,21 @@ class RuleLiteral(Rule): class RuleRegex(Rule): """Rule to match regular expression.""" + def __init__(self, value: Any, rules: Dict[str, Rule], name: Optional[str] = None): + """Initialize. + + Args: + value: Value of this rule. + rules: Grammar containing all rules. + name: Name of this rule. + + """ + super().__init__(value, rules, name) + self.value = re.compile(value[2:-1], re.M | re.S) + def parse(self, text: str, pos: int) -> Tuple[int, Any]: """Apply rule at position.""" - pattern = re.compile(self.value[2:-1], re.M | re.S) - match = pattern.match(text, pos) + match = self.value.match(text, pos) if not match: return -1, [] npos = self.skip_ws(text, match.span()[1]) @@ -171,7 +194,6 @@ class Visitor: # pylint: disable=too-few-public-methods def __init__(self): """Initialize.""" - self.typedefs = {} def visit(self, tree: Any) -> Any: """Visit all nodes in parse tree.""" diff --git a/tests/test_parse.py b/tests/test_parse.py index 014149ce..d5ea00fe 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -37,6 +37,11 @@ MSG: test_msgs/Other uint64[3] Header """ +RELSIBLING_MSG = """ +Header header +Other other +""" + IDL_LANG = """ // assign different literals and expressions @@ -122,6 +127,17 @@ def test_parse_multi_msg(): assert ret['test_msgs/msg/Foo'][2][0][1] == 'uint8' +def test_parse_relative_siblings_msg(): + """Test relative siblings with msg parser.""" + ret = get_types_from_msg(RELSIBLING_MSG, 'test_msgs/msg/Foo') + assert ret['test_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header' + assert ret['test_msgs/msg/Foo'][1][0][1] == 'test_msgs/msg/Other' + + ret = get_types_from_msg(RELSIBLING_MSG, 'rel_msgs/msg/Foo') + assert ret['rel_msgs/msg/Foo'][0][0][1] == 'std_msgs/msg/Header' + assert ret['rel_msgs/msg/Foo'][1][0][1] == 'rel_msgs/msg/Other' + + def test_parse_idl(): """Test idl parser.""" ret = get_types_from_idl(IDL_LANG)