Refactor rosbag2 Reader

This commit is contained in:
Marko Durkovic 2023-01-11 15:26:05 +01:00
parent d7d24c4478
commit eaa64002b8
4 changed files with 219 additions and 98 deletions

View File

@ -7,7 +7,8 @@ in the rosbag2 format.
""" """
from .reader import Reader, ReaderError from .errors import ReaderError
from .reader import Reader
from .writer import Writer, WriterError from .writer import Writer, WriterError
__all__ = [ __all__ = [

View File

@ -0,0 +1,9 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 errors."""
from __future__ import annotations
class ReaderError(Exception):
    """Error raised when a rosbag2 cannot be read.

    Used for problems such as an unsupported bag version or storage plugin,
    reading from a bag that is not open, or a storage file that lacks the
    expected tables.
    """

View File

@ -4,11 +4,9 @@
from __future__ import annotations from __future__ import annotations
import sqlite3
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Protocol
import zstandard import zstandard
from ruamel.yaml import YAML from ruamel.yaml import YAML
@ -16,40 +14,43 @@ from ruamel.yaml.error import YAMLError
from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo
from .errors import ReaderError
from .storage_sqlite3 import ReaderSqlite3
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union from typing import Generator, Iterable, Literal, Optional, Type, Union
from .metadata import FileInformation, Metadata from .metadata import FileInformation, Metadata
class ReaderError(Exception): class StorageProtocol(Protocol):
"""Reader Error.""" """Storage Protocol."""
def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]):
"""Initialize."""
raise NotImplementedError
@contextmanager def open(self) -> None:
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]: """Open file."""
"""Transparent rosbag2 database decompression context. raise NotImplementedError
This context manager will yield a path to the decompressed file contents. def close(self) -> None:
"""Close file."""
raise NotImplementedError
Args: def get_definitions(self) -> dict[str, tuple[str, str]]:
path: Potentially compressed file. """Get message definitions."""
do_decompress: Flag indicating if decompression shall occur. raise NotImplementedError
Yields: def messages(
Path of transparently decompressed file. self,
connections: Iterable[Connection] = (),
""" start: Optional[int] = None,
if do_decompress: stop: Optional[int] = None,
decomp = zstandard.ZstdDecompressor() ) -> Generator[tuple[Connection, int, bytes], None, None]:
with TemporaryDirectory() as tempdir: """Get messages from file."""
dbfile = Path(tempdir, path.stem) raise NotImplementedError
with path.open('rb') as infile, dbfile.open('wb') as outfile:
decomp.copy_stream(infile, outfile)
yield dbfile
else:
yield path
class Reader: class Reader:
@ -69,6 +70,12 @@ class Reader:
""" """
# pylint: disable=too-many-instance-attributes
STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = {
'sqlite3': ReaderSqlite3,
}
def __init__(self, path: Union[Path, str]): def __init__(self, path: Union[Path, str]):
"""Open rosbag and check metadata. """Open rosbag and check metadata.
@ -82,7 +89,6 @@ class Reader:
path = Path(path) path = Path(path)
yamlpath = path / 'metadata.yaml' yamlpath = path / 'metadata.yaml'
self.path = path self.path = path
self.bio = False
try: try:
yaml = YAML(typ='safe') yaml = YAML(typ='safe')
dct = yaml.load(yamlpath.read_text()) dct = yaml.load(yamlpath.read_text())
@ -95,7 +101,7 @@ class Reader:
self.metadata: Metadata = dct['rosbag2_bagfile_information'] self.metadata: Metadata = dct['rosbag2_bagfile_information']
if (ver := self.metadata['version']) > 6: if (ver := self.metadata['version']) > 6:
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.') raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
if storageid := self.metadata['storage_identifier'] != 'sqlite3': if (storageid := self.metadata['storage_identifier']) not in self.STORAGE_PLUGINS:
raise ReaderError( raise ReaderError(
f'Storage plugin {storageid!r} not supported; please report issue.', f'Storage plugin {storageid!r} not supported; please report issue.',
) )
@ -131,20 +137,12 @@ class Reader:
self.files: list[FileInformation] = self.metadata.get('files', [])[:] self.files: list[FileInformation] = self.metadata.get('files', [])[:]
self.custom_data: dict[str, str] = self.metadata.get('custom_data', {}) self.custom_data: dict[str, str] = self.metadata.get('custom_data', {})
self.tmpdir: Optional[TemporaryDirectory] = None
self.storage: Optional[StorageProtocol] = None
except KeyError as exc: except KeyError as exc:
raise ReaderError(f'A metadata key is missing {exc!r}.') from None raise ReaderError(f'A metadata key is missing {exc!r}.') from None
def open(self) -> None:
"""Open rosbag2."""
# Future storage formats will require file handles.
self.bio = True
def close(self) -> None:
"""Close rosbag2."""
# Future storage formats will require file handles.
assert self.bio
self.bio = False
@property @property
def duration(self) -> int: def duration(self) -> int:
"""Duration in nanoseconds between earliest and latest messages.""" """Duration in nanoseconds between earliest and latest messages."""
@ -183,7 +181,50 @@ class Reader:
"""Topic information.""" """Topic information."""
return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections} return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections}
def messages( # pylint: disable=too-many-locals def open(self) -> None:
"""Open rosbag2."""
storage_paths = []
if self.compression_mode == 'file':
self.tmpdir = TemporaryDirectory() # pylint: disable=consider-using-with
tmpdir = self.tmpdir.name
decomp = zstandard.ZstdDecompressor()
for path in self.paths:
storage_file = Path(tmpdir, path.stem)
with path.open('rb') as infile, storage_file.open('wb') as outfile:
decomp.copy_stream(infile, outfile)
storage_paths.append(storage_file)
else:
storage_paths = self.paths[:]
self.storage = self.STORAGE_PLUGINS[self.metadata['storage_identifier']](
storage_paths,
self.connections,
)
self.storage.open()
definitions = self.storage.get_definitions()
for idx, conn in enumerate(self.connections):
if desc := definitions.get(conn.msgtype):
self.connections[idx] = Connection(
id=conn.id,
topic=conn.topic,
msgtype=conn.msgtype,
msgdef=desc[1],
md5sum=desc[0],
msgcount=conn.msgcount,
ext=conn.ext,
owner=conn.owner,
)
def close(self) -> None:
"""Close rosbag2."""
assert self.storage
self.storage.close()
self.storage = None
if self.tmpdir:
self.tmpdir.cleanup()
self.tmpdir = None
def messages(
self, self,
connections: Iterable[Connection] = (), connections: Iterable[Connection] = (),
start: Optional[int] = None, start: Optional[int] = None,
@ -201,67 +242,18 @@ class Reader:
tuples of connection, timestamp (ns), and rawdata. tuples of connection, timestamp (ns), and rawdata.
Raises: Raises:
ReaderError: Bag not open. ReaderError: If reader was not opened.
""" """
if not self.bio: if not self.storage:
raise ReaderError('Rosbag is not open.') raise ReaderError('Rosbag is not open.')
query = [ if self.compression_mode == 'message':
'SELECT topics.id,messages.timestamp,messages.data', decomp = zstandard.ZstdDecompressor().decompress
'FROM messages JOIN topics ON messages.topic_id=topics.id', for connection, timestamp, data in self.storage.messages(connections, start, stop):
] yield connection, timestamp, decomp(data)
args: list[Any] = [] else:
clause = 'WHERE' yield from self.storage.messages(connections, start, stop)
if connections:
topics = {x.topic for x in connections}
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
args += topics
clause = 'AND'
if start is not None:
query.append(f'{clause} messages.timestamp >= ?')
args.append(start)
clause = 'AND'
if stop is not None:
query.append(f'{clause} messages.timestamp < ?')
args.append(stop)
clause = 'AND'
query.append('ORDER BY timestamp')
querystr = ' '.join(query)
for filepath in self.paths:
with decompress(filepath, self.compression_mode == 'file') as path:
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
conn.row_factory = lambda _, x: x
cur = conn.cursor()
cur.execute(
'SELECT count(*) FROM sqlite_master '
'WHERE type="table" AND name IN ("messages", "topics")',
)
if cur.fetchone()[0] != 2:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
cur.execute('SELECT name,id FROM topics')
connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections if x.topic == row[0]),
None) # type: ignore
for row in cur
}
cur.execute(querystr, args)
if self.compression_mode == 'message':
decomp = zstandard.ZstdDecompressor().decompress
for row in cur:
cid, timestamp, data = row
yield connmap[cid], timestamp, decomp(data)
else:
for cid, timestamp, data in cur:
yield connmap[cid], timestamp, data
def __enter__(self) -> Reader: def __enter__(self) -> Reader:
"""Open rosbag2 when entering contextmanager.""" """Open rosbag2 when entering contextmanager."""

View File

@ -0,0 +1,119 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Sqlite3 storage."""
from __future__ import annotations
import sqlite3
from typing import TYPE_CHECKING
from .errors import ReaderError
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Generator, Iterable, Optional
from rosbags.interfaces import Connection
class ReaderSqlite3:
    """Sqlite3 storage reader.

    Reads messages from one or more sqlite3 storage files. Each file is
    expected to contain a ``topics`` and a ``messages`` table.
    """

    def __init__(
        self,
        paths: Iterable[Path],
        connections: Iterable[Connection],
    ):
        """Set up storage reader.

        Args:
            paths: Paths of storage files.
            connections: List of connections.

        """
        self.opened = False
        # Both arguments are stored by reference (not copied), so changes the
        # caller makes to these containers afterwards are seen by messages().
        self.paths = paths
        self.connections = connections

    def open(self) -> None:
        """Open rosbag2.

        Only flips the ``opened`` flag; database files are opened lazily, one
        at a time, while iterating ``messages()``.
        """
        self.opened = True

    def close(self) -> None:
        """Close rosbag2.

        Expects a preceding ``open()`` call (checked with an assertion).
        """
        assert self.opened
        self.opened = False

    def get_definitions(self) -> dict[str, tuple[str, str]]:
        """Get message definitions.

        Returns:
            Always an empty dict; this reader does not supply any message
            definitions.

        """
        return {}

    def messages(  # pylint: disable=too-many-locals
        self,
        connections: Iterable[Connection] = (),
        start: Optional[int] = None,
        stop: Optional[int] = None,
    ) -> Generator[tuple[Connection, int, bytes], None, None]:
        """Read messages from bag.

        Storage files are processed in the order given by ``self.paths``;
        within each file rows are ordered by timestamp.

        Args:
            connections: Iterable with connections to filter for. An empty
                iterable disables filtering on connections.
            start: Yield only messages at or after this timestamp (ns).
            stop: Yield only messages before this timestamp (ns).

        Yields:
            tuples of connection, timestamp (ns), and rawdata.

        Raises:
            ReaderError: A storage file lacks the messages or topics table.

        """
        query = [
            'SELECT topics.id,messages.timestamp,messages.data',
            'FROM messages JOIN topics ON messages.topic_id=topics.id',
        ]
        args: list[Any] = []
        # First filter is introduced with WHERE, every following one with AND.
        clause = 'WHERE'

        if connections:
            topics = {x.topic for x in connections}
            query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
            args += topics
            clause = 'AND'

        if start is not None:
            query.append(f'{clause} messages.timestamp >= ?')
            args.append(start)
            clause = 'AND'

        if stop is not None:
            query.append(f'{clause} messages.timestamp < ?')
            args.append(stop)
            clause = 'AND'

        query.append('ORDER BY timestamp')
        querystr = ' '.join(query)

        for path in self.paths:
            conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
            # Release the database handle on every exit path: exhaustion,
            # ReaderError, or the consumer abandoning the generator early.
            try:
                # Hand rows back as the plain tuples produced by sqlite3.
                conn.row_factory = lambda _, x: x
                cur = conn.cursor()
                # Single-quoted literals: double-quoted strings are a legacy
                # SQLite quirk that can be disabled at build time.
                cur.execute(
                    'SELECT count(*) FROM sqlite_master '
                    "WHERE type='table' AND name IN ('messages', 'topics')",
                )
                if cur.fetchone()[0] != 2:
                    raise ReaderError(f'Cannot open database {path} or database missing tables.')

                # Map the per-file topic ids onto the known connections; a
                # topic without matching connection maps to None.
                cur.execute('SELECT name,id FROM topics')
                connmap: dict[int, Connection] = {
                    row[1]: next((x for x in self.connections if x.topic == row[0]),
                                 None)  # type: ignore
                    for row in cur
                }

                cur.execute(querystr, args)
                for cid, timestamp, data in cur:
                    yield connmap[cid], timestamp, data
            finally:
                conn.close()