Use built-in collections as generic types

Marko Durkovic 2021-08-06 12:03:29 +02:00 committed by Florian Friesdorf
parent f33e65b14a
commit 5bd1bcbd83
10 changed files with 78 additions and 81 deletions
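
The commit swaps `typing.Dict`, `typing.List`, and `typing.Tuple` for the built-in `dict`, `list`, and `tuple` generics (PEP 585). Below is a minimal sketch of the pattern, not taken from the rosbags sources (`read_index` and `connections` are illustrative names only): with `from __future__ import annotations` (PEP 563), which is visible in the context of several hunks further down, annotations are kept as strings and never evaluated at runtime, so the built-in generic syntax also works on interpreters older than Python 3.9.

```python
# Illustrative module, not part of rosbags: built-in collections as generic types.
from __future__ import annotations  # postpone annotation evaluation (PEP 563)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Optional/Union and friends still come from typing; only the
    # container generics move to the built-ins.
    from typing import Optional


def read_index(pos: int) -> tuple[int, list[int]]:
    """Return position and an index list, annotated with built-in generics."""
    return pos, []


# Variable annotations are postponed as well, so dict[...] is fine here
# even on interpreters that predate PEP 585 (Python 3.9).
connections: dict[int, Optional[str]] = {}
```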

@@ -15,7 +15,7 @@ from rosbags.typesys import get_types_from_msg, register_types
 if TYPE_CHECKING:
     from pathlib import Path
-    from typing import Any, Dict, Optional
+    from typing import Any, Optional
     from rosbags.rosbag1.reader import Connection as RConnection
@@ -80,8 +80,8 @@ def convert(src: Path, dst: Optional[Path]) -> None:
     try:
         with Reader(src) as reader, Writer(dst) as writer:
-            typs: Dict[str, Any] = {}
-            connmap: Dict[int, WConnection] = {}
+            typs: dict[str, Any] = {}
+            connmap: dict[int, WConnection] = {}
             for rconn in reader.connections.values():
                 candidate = convert_connection(rconn)

@@ -23,19 +23,7 @@ from rosbags.typesys.msg import normalize_msgtype
 if TYPE_CHECKING:
     from types import TracebackType
-    from typing import (
-        BinaryIO,
-        Callable,
-        Dict,
-        Generator,
-        Iterable,
-        List,
-        Literal,
-        Optional,
-        Tuple,
-        Type,
-        Union,
-    )
+    from typing import BinaryIO, Callable, Generator, Iterable, Literal, Optional, Type, Union
 class ReaderError(Exception):
@@ -71,7 +59,7 @@ class Connection(NamedTuple):
     msgdef: str
     callerid: Optional[str]
     latching: Optional[int]
-    indexes: List
+    indexes: list
 class ChunkInfo(NamedTuple):
@@ -80,7 +68,7 @@ class ChunkInfo(NamedTuple):
     pos: int
     start_time: int
     end_time: int
-    connection_counts: Dict[int, int]
+    connection_counts: dict[int, int]
 class Chunk(NamedTuple):
@@ -107,11 +95,11 @@ class IndexData(NamedTuple):
     chunk_pos: int
     offset: int
-    def __lt__(self, other: Tuple[int, ...]) -> bool:
+    def __lt__(self, other: tuple[int, ...]) -> bool:
         """Compare by time only."""
         return self.time < other[0]
-    def __le__(self, other: Tuple[int, ...]) -> bool:
+    def __le__(self, other: tuple[int, ...]) -> bool:
         """Compare by time only."""
         return self.time <= other[0]
@@ -121,11 +109,11 @@ class IndexData(NamedTuple):
             return NotImplemented
         return self.time == other[0]
-    def __ge__(self, other: Tuple[int, ...]) -> bool:
+    def __ge__(self, other: tuple[int, ...]) -> bool:
         """Compare by time only."""
         return self.time >= other[0]
-    def __gt__(self, other: Tuple[int, ...]) -> bool:
+    def __gt__(self, other: tuple[int, ...]) -> bool:
         """Compare by time only."""
         return self.time > other[0]
@@ -371,11 +359,11 @@ class Reader:
             raise ReaderError(f'File {str(self.path)!r} does not exist.')
         self.bio: Optional[BinaryIO] = None
-        self.connections: Dict[int, Connection] = {}
-        self.chunk_infos: List[ChunkInfo] = []
-        self.chunks: Dict[int, Chunk] = {}
+        self.connections: dict[int, Connection] = {}
+        self.chunk_infos: list[ChunkInfo] = []
+        self.chunks: dict[int, Chunk] = {}
         self.current_chunk = (-1, BytesIO())
-        self.topics: Dict[str, TopicInfo] = {}
+        self.topics: dict[str, TopicInfo] = {}
     def open(self):  # pylint: disable=too-many-branches,too-many-locals
         """Open rosbag and read metadata."""
@@ -480,7 +468,7 @@ class Reader:
         """Total message count."""
         return reduce(lambda x, y: x + y, (x.msgcount for x in self.topics.values()), 0)
-    def read_connection(self) -> Tuple[int, Connection]:
+    def read_connection(self) -> tuple[int, Connection]:
         """Read connection record from current position."""
         assert self.bio
         header = Header.read(self.bio, RecordType.CONNECTION)
@@ -552,7 +540,7 @@ class Reader:
             decompressor,
         )
-    def read_index_data(self, pos: int) -> Tuple[int, List[IndexData]]:
+    def read_index_data(self, pos: int) -> tuple[int, list[IndexData]]:
         """Read index data from position.
         Args:
@@ -576,7 +564,7 @@ class Reader:
         self.bio.seek(4, os.SEEK_CUR)
-        index: List[IndexData] = []
+        index: list[IndexData] = []
         for _ in range(count):
             time = deserialize_time(self.bio.read(8))
             offset = read_uint32(self.bio)
@@ -588,7 +576,7 @@ class Reader:
         topics: Optional[Iterable[str]] = None,
         start: Optional[int] = None,
         stop: Optional[int] = None,
-    ) -> Generator[Tuple[Connection, int, bytes], None, None]:
+    ) -> Generator[tuple[Connection, int, bytes], None, None]:
         """Read messages from bag.
         Args:

@@ -17,7 +17,7 @@ from .connection import Connection
 if TYPE_CHECKING:
     from types import TracebackType
-    from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
+    from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
 class ReaderError(Exception):
@@ -162,7 +162,7 @@ class Reader:
         return mode if mode != 'none' else None
     @property
-    def topics(self) -> Dict[str, Connection]:
+    def topics(self) -> dict[str, Connection]:
         """Topic information.
         For the moment this a dictionary mapping topic names to connections.
@@ -175,7 +175,7 @@ class Reader:
         connections: Iterable[Connection] = (),
         start: Optional[int] = None,
         stop: Optional[int] = None,
-    ) -> Generator[Tuple[Connection, int, bytes], None, None]:
+    ) -> Generator[tuple[Connection, int, bytes], None, None]:
         """Read messages from bag.
         Args:
@@ -185,7 +185,7 @@ class Reader:
             stop: Yield only messages before this timestamp (ns).
         Yields:
-            Tuples of connection, timestamp (ns), and rawdata.
+            tuples of connection, timestamp (ns), and rawdata.
         Raises:
             ReaderError: Bag not open.
@@ -198,7 +198,7 @@ class Reader:
             'SELECT topics.id,messages.timestamp,messages.data',
             'FROM messages JOIN topics ON messages.topic_id=topics.id',
         ]
-        args: List[Any] = []
+        args: list[Any] = []
         clause = 'WHERE'
         if connections:

@@ -16,7 +16,7 @@ from .connection import Connection
 if TYPE_CHECKING:
     from types import TracebackType
-    from typing import Any, Dict, Literal, Optional, Type, Union
+    from typing import Any, Literal, Optional, Type, Union
 class WriterError(Exception):
@@ -79,7 +79,7 @@ class Writer:  # pylint: disable=too-many-instance-attributes
         self.compression_mode = ''
         self.compression_format = ''
         self.compressor: Optional[zstandard.ZstdCompressor] = None
-        self.connections: Dict[int, Connection] = {}
+        self.connections: dict[int, Connection] = {}
         self.conn = None
         self.cursor: Optional[sqlite3.Cursor] = None

@@ -13,16 +13,16 @@ from __future__ import annotations
 import sys
 from itertools import tee
-from typing import TYPE_CHECKING, Iterator, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterator, cast
 from .typing import Field
 from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
 if TYPE_CHECKING:
-    from typing import Callable, List
+    from typing import Callable
-def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
+def generate_getsize_cdr(fields: list[Field]) -> tuple[Callable, int]:
     """Generate cdr size calculation function.
     Args:
@@ -37,7 +37,9 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
     is_stat = True
     aligned = 8
-    icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None]))
+    iterators = tee([*fields, None])
+    icurr = cast(Iterator[Field], iterators[0])
+    inext = iterators[1]
     next(inext)
     lines = [
         'import sys',
@@ -155,7 +157,7 @@ def generate_getsize_cdr(fields: List[Field]) -> Tuple[Callable, int]:
     return compile_lines(lines).getsize_cdr, is_stat * size  # type: ignore
-def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
+def generate_serialize_cdr(fields: list[Field], endianess: str) -> Callable:
     """Generate cdr serialization function.
     Args:
@@ -168,7 +170,9 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
     """
     # pylint: disable=too-many-branches,too-many-locals,too-many-statements
     aligned = 8
-    icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None]))
+    iterators = tee([*fields, None])
+    icurr = cast(Iterator[Field], iterators[0])
+    inext = iterators[1]
     next(inext)
     lines = [
         'import sys',
@@ -292,7 +296,7 @@ def generate_serialize_cdr(fields: List[Field], endianess: str) -> Callable:
     return compile_lines(lines).serialize_cdr  # type: ignore
-def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
+def generate_deserialize_cdr(fields: list[Field], endianess: str) -> Callable:
     """Generate cdr deserialization function.
     Args:
@@ -305,7 +309,9 @@ def generate_deserialize_cdr(fields: List[Field], endianess: str) -> Callable:
     """
     # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
     aligned = 8
-    icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None]))
+    iterators = tee([*fields, None])
+    icurr = cast(Iterator[Field], iterators[0])
+    inext = iterators[1]
     next(inext)
     lines = [
         'import sys',
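
The three-line replacement that recurs above in generate_getsize_cdr, generate_serialize_cdr, and generate_deserialize_cdr (and again in the ROS1/CDR converters further down) unpacks the tee() result explicitly, so only the current-item iterator needs a cast now that typing.Tuple is no longer imported for the old tuple-typed cast. A standalone sketch of that lookahead iteration with a trailing None sentinel; pairwise_with_sentinel is a hypothetical helper, not part of rosbags:

```python
# Hypothetical helper illustrating the (current, next) iteration pattern;
# not part of rosbags.serde.
from __future__ import annotations

from itertools import tee
from typing import Iterator, Optional, cast


def pairwise_with_sentinel(items: list[str]) -> Iterator[tuple[str, Optional[str]]]:
    """Yield (current, next) pairs; the last item is paired with None."""
    iterators = tee([*items, None])
    icurr = cast(Iterator[str], iterators[0])  # zip stops before icurr reaches the sentinel
    inext = iterators[1]                       # lookahead iterator, advanced by one
    next(inext)
    return zip(icurr, inext)


for curr, nxt in pairwise_with_sentinel(['uint8', 'float64', 'string']):
    print(curr, '->', nxt)
```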

@@ -14,9 +14,9 @@ from .typing import Descriptor, Field, Msgdef
 from .utils import Valtype
 if TYPE_CHECKING:
-    from typing import Any, Dict
+    from typing import Any
-MSGDEFCACHE: Dict[str, Msgdef] = {}
+MSGDEFCACHE: dict[str, Msgdef] = {}
 class SerdeError(Exception):

@@ -12,16 +12,16 @@ conversion of ROS1 to CDR.
 from __future__ import annotations
 from itertools import tee
-from typing import TYPE_CHECKING, Iterator, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterator, cast
 from .typing import Field
 from .utils import SIZEMAP, Valtype, align, align_after, compile_lines
 if TYPE_CHECKING:
-    from typing import Callable, List  # pylint: disable=ungrouped-imports
+    from typing import Callable  # pylint: disable=ungrouped-imports
-def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Callable:
+def generate_ros1_to_cdr(fields: list[Field], typename: str, copy: bool) -> Callable:
     """Generate ROS1 to CDR conversion function.
     Args:
@@ -35,7 +35,9 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
     """
     # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
     aligned = 8
-    icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None]))
+    iterators = tee([*fields, None])
+    icurr = cast(Iterator[Field], iterators[0])
+    inext = iterators[1]
     next(inext)
     funcname = 'ros1_to_cdr' if copy else 'getsize_ros1_to_cdr'
     lines = [
@@ -170,7 +172,7 @@ def generate_ros1_to_cdr(fields: List[Field], typename: str, copy: bool) -> Call
     return getattr(compile_lines(lines), funcname)
-def generate_cdr_to_ros1(fields: List[Field], typename: str, copy: bool) -> Callable:
+def generate_cdr_to_ros1(fields: list[Field], typename: str, copy: bool) -> Callable:
     """Generate CDR to ROS1 conversion function.
     Args:
@@ -184,7 +186,9 @@ def generate_cdr_to_ros1(fields: List[Field], typename: str, copy: bool) -> Call
     """
     # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
     aligned = 8
-    icurr, inext = cast(Tuple[Iterator[Field], Iterator[Optional[Field]]], tee([*fields, None]))
+    iterators = tee([*fields, None])
+    icurr = cast(Iterator[Field], iterators[0])
+    inext = iterators[1]
     next(inext)
     funcname = 'cdr_to_ros1' if copy else 'getsize_cdr_to_ros1'
     lines = [

@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from types import ModuleType
-    from typing import Dict, List
     from .typing import Descriptor
@@ -24,7 +23,7 @@ class Valtype(IntEnum):
     SEQUENCE = 4
-SIZEMAP: Dict[str, int] = {
+SIZEMAP: dict[str, int] = {
     'bool': 1,
     'int8': 1,
     'int16': 2,
@@ -83,7 +82,7 @@ def align_after(entry: Descriptor) -> int:
         return min([4, align_after(entry.args[0])])
-def compile_lines(lines: List[str]) -> ModuleType:
+def compile_lines(lines: list[str]) -> ModuleType:
     """Compile lines of code to module.
     Args:

@@ -21,7 +21,7 @@ from .peg import Rule, Visitor, parse_grammar
 from .types import FIELDDEFS
 if TYPE_CHECKING:
-    from typing import Any, List
+    from typing import Any
     from .base import Fielddesc, Typesdict
@@ -91,7 +91,7 @@ def normalize_msgtype(name: str) -> str:
     return str(path)
-def normalize_fieldtype(typename: str, field: Fielddesc, names: List[str]) -> Fielddesc:
+def normalize_fieldtype(typename: str, field: Fielddesc, names: list[str]) -> Fielddesc:
     """Normalize field typename.
     Args:
@@ -235,7 +235,7 @@ def get_types_from_msg(text: str, name: str) -> Typesdict:
         name: Message typename.
     Returns:
-        List with single message name and parsetree.
+        list with single message name and parsetree.
     """
     return parse_message_definition(VisitorMSG(), f'MSG: {name}\n{text}')

@@ -14,7 +14,7 @@ import re
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    from typing import Any, Dict, List, Optional, Tuple
+    from typing import Any, Optional
 class Rule:
@@ -23,7 +23,7 @@ class Rule:
     LIT = 'LITERAL'
     WS = re.compile(r'\s+', re.M | re.S)
-    def __init__(self, value: Any, rules: Dict[str, Rule], name: Optional[str] = None):
+    def __init__(self, value: Any, rules: dict[str, Rule], name: Optional[str] = None):
         """Initialize.
         Args:
@@ -58,7 +58,7 @@ class Rule:
 class RuleLiteral(Rule):
     """Rule to match string literal."""
-    def __init__(self, value: Any, rules: Dict[str, Rule], name: Optional[str] = None):
+    def __init__(self, value: Any, rules: dict[str, Rule], name: Optional[str] = None):
         """Initialize.
         Args:
@@ -70,7 +70,7 @@ class RuleLiteral(Rule):
         super().__init__(value, rules, name)
         self.value = value[1:-1].replace('\\\'', '\'')
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
         value = self.value
         if text[pos:pos + len(value)] == value:
@@ -83,7 +83,7 @@ class RuleLiteral(Rule):
 class RuleRegex(Rule):
     """Rule to match regular expression."""
-    def __init__(self, value: Any, rules: Dict[str, Rule], name: Optional[str] = None):
+    def __init__(self, value: Any, rules: dict[str, Rule], name: Optional[str] = None):
         """Initialize.
         Args:
@@ -95,7 +95,7 @@ class RuleRegex(Rule):
         super().__init__(value, rules, name)
         self.value = re.compile(value[2:-1], re.M | re.S)
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
        match = self.value.match(text, pos)
        if not match:
@@ -107,7 +107,7 @@ class RuleRegex(Rule):
 class RuleToken(Rule):
     """Rule to match token."""
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
         token = self.rules[self.value]
         npos, data = token.parse(text, pos)
@@ -119,7 +119,7 @@ class RuleToken(Rule):
 class RuleOneof(Rule):
     """Rule to match first matching subrule."""
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
         for value in self.value:
             npos, data = value.parse(text, pos)
@@ -131,7 +131,7 @@ class RuleOneof(Rule):
 class RuleSequence(Rule):
     """Rule to match a sequence of subrules."""
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
         data = []
         npos = pos
@@ -146,9 +146,9 @@ class RuleSequence(Rule):
 class RuleZeroPlus(Rule):
     """Rule to match zero or more occurences of subrule."""
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
-        data: List[Any] = []
+        data: list[Any] = []
         lpos = pos
         while True:
             npos, node = self.value.parse(text, lpos)
@@ -161,7 +161,7 @@ class RuleZeroPlus(Rule):
 class RuleOnePlus(Rule):
     """Rule to match one or more occurences of subrule."""
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
         npos, node = self.value.parse(text, pos)
         if npos == -1:
@@ -179,7 +179,7 @@ class RuleOnePlus(Rule):
 class RuleZeroOne(Rule):
     """Rule to match zero or one occurence of subrule."""
-    def parse(self, text: str, pos: int) -> Tuple[int, Any]:
+    def parse(self, text: str, pos: int) -> tuple[int, Any]:
         """Apply rule at position."""
         npos, node = self.value.parse(text, pos)
         if npos == -1:
@@ -190,7 +190,7 @@ class RuleZeroOne(Rule):
 class Visitor:  # pylint: disable=too-few-public-methods
     """Visitor transforming parse trees."""
-    RULES: Dict[str, Rule] = {}
+    RULES: dict[str, Rule] = {}
     def __init__(self):
         """Initialize."""
@@ -208,15 +208,15 @@ class Visitor:  # pylint: disable=too-few-public-methods
         return func(tree['data'])
-def split_token(tok: str) -> List[str]:
+def split_token(tok: str) -> list[str]:
     """Split repetition and grouping tokens."""
     return list(filter(None, re.split(r'(^\()|(\)(?=[*+?]?$))|([*+?]$)', tok)))
-def collapse_tokens(toks: List[Optional[Rule]], rules: Dict[str, Rule]) -> Rule:
+def collapse_tokens(toks: list[Optional[Rule]], rules: dict[str, Rule]) -> Rule:
     """Collapse linear list of tokens to oneof of sequences."""
-    value: List[Rule] = []
-    seq: List[Rule] = []
+    value: list[Rule] = []
+    seq: list[Rule] = []
     for tok in toks:
         if tok:
             seq.append(tok)
@@ -227,9 +227,9 @@ def collapse_tokens(toks: List[Optional[Rule]], rules: Dict[str, Rule]) -> Rule:
     return RuleOneof(value, rules) if len(value) > 1 else value[0]
-def parse_grammar(grammar: str) -> Dict[str, Rule]:
+def parse_grammar(grammar: str) -> dict[str, Rule]:
     """Parse grammar into rule dictionary."""
-    rules: Dict[str, Rule] = {}
+    rules: dict[str, Rule] = {}
     for token in grammar.split('\n\n'):
         lines = token.strip().split('\n')
         name, *defs = lines
@@ -237,8 +237,8 @@ def parse_grammar(grammar: str) -> Dict[str, Rule]:
         assert items
         assert items[0] == '='
         items.pop(0)
-        stack: List[Optional[Rule]] = []
-        parens: List[int] = []
+        stack: list[Optional[Rule]] = []
+        parens: list[int] = []
         while items:
             tok = items.pop(0)
             if tok in ['*', '+', '?']: