Refactor rosbag2 Reader

This commit is contained in:
Marko Durkovic 2023-01-11 15:26:05 +01:00
parent d7d24c4478
commit eaa64002b8
4 changed files with 219 additions and 98 deletions

View File

@ -7,7 +7,8 @@ in the rosbag2 format.
"""
from .reader import Reader, ReaderError
from .errors import ReaderError
from .reader import Reader
from .writer import Writer, WriterError
__all__ = [

View File

@ -0,0 +1,9 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 errors."""
from __future__ import annotations
class ReaderError(Exception):
"""Reader Error."""

View File

@ -4,11 +4,9 @@
from __future__ import annotations
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol
import zstandard
from ruamel.yaml import YAML
@ -16,40 +14,43 @@ from ruamel.yaml.error import YAMLError
from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo
from .errors import ReaderError
from .storage_sqlite3 import ReaderSqlite3
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
from typing import Generator, Iterable, Literal, Optional, Type, Union
from .metadata import FileInformation, Metadata
class ReaderError(Exception):
"""Reader Error."""
class StorageProtocol(Protocol):
"""Storage Protocol."""
def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]):
"""Initialize."""
raise NotImplementedError
@contextmanager
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
"""Transparent rosbag2 database decompression context.
def open(self) -> None:
"""Open file."""
raise NotImplementedError
This context manager will yield a path to the decompressed file contents.
def close(self) -> None:
"""Close file."""
raise NotImplementedError
Args:
path: Potentially compressed file.
do_decompress: Flag indicating if decompression shall occur.
def get_definitions(self) -> dict[str, tuple[str, str]]:
"""Get message definitions."""
raise NotImplementedError
Yields:
Path of transparently decompressed file.
"""
if do_decompress:
decomp = zstandard.ZstdDecompressor()
with TemporaryDirectory() as tempdir:
dbfile = Path(tempdir, path.stem)
with path.open('rb') as infile, dbfile.open('wb') as outfile:
decomp.copy_stream(infile, outfile)
yield dbfile
else:
yield path
def messages(
self,
connections: Iterable[Connection] = (),
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[tuple[Connection, int, bytes], None, None]:
"""Get messages from file."""
raise NotImplementedError
class Reader:
@ -69,6 +70,12 @@ class Reader:
"""
# pylint: disable=too-many-instance-attributes
STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = {
'sqlite3': ReaderSqlite3,
}
def __init__(self, path: Union[Path, str]):
"""Open rosbag and check metadata.
@ -82,7 +89,6 @@ class Reader:
path = Path(path)
yamlpath = path / 'metadata.yaml'
self.path = path
self.bio = False
try:
yaml = YAML(typ='safe')
dct = yaml.load(yamlpath.read_text())
@ -95,7 +101,7 @@ class Reader:
self.metadata: Metadata = dct['rosbag2_bagfile_information']
if (ver := self.metadata['version']) > 6:
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
if storageid := self.metadata['storage_identifier'] != 'sqlite3':
if (storageid := self.metadata['storage_identifier']) not in self.STORAGE_PLUGINS:
raise ReaderError(
f'Storage plugin {storageid!r} not supported; please report issue.',
)
@ -131,20 +137,12 @@ class Reader:
self.files: list[FileInformation] = self.metadata.get('files', [])[:]
self.custom_data: dict[str, str] = self.metadata.get('custom_data', {})
self.tmpdir: Optional[TemporaryDirectory] = None
self.storage: Optional[StorageProtocol] = None
except KeyError as exc:
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
def open(self) -> None:
"""Open rosbag2."""
# Future storage formats will require file handles.
self.bio = True
def close(self) -> None:
"""Close rosbag2."""
# Future storage formats will require file handles.
assert self.bio
self.bio = False
@property
def duration(self) -> int:
"""Duration in nanoseconds between earliest and latest messages."""
@ -183,7 +181,50 @@ class Reader:
"""Topic information."""
return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections}
def messages( # pylint: disable=too-many-locals
def open(self) -> None:
"""Open rosbag2."""
storage_paths = []
if self.compression_mode == 'file':
self.tmpdir = TemporaryDirectory() # pylint: disable=consider-using-with
tmpdir = self.tmpdir.name
decomp = zstandard.ZstdDecompressor()
for path in self.paths:
storage_file = Path(tmpdir, path.stem)
with path.open('rb') as infile, storage_file.open('wb') as outfile:
decomp.copy_stream(infile, outfile)
storage_paths.append(storage_file)
else:
storage_paths = self.paths[:]
self.storage = self.STORAGE_PLUGINS[self.metadata['storage_identifier']](
storage_paths,
self.connections,
)
self.storage.open()
definitions = self.storage.get_definitions()
for idx, conn in enumerate(self.connections):
if desc := definitions.get(conn.msgtype):
self.connections[idx] = Connection(
id=conn.id,
topic=conn.topic,
msgtype=conn.msgtype,
msgdef=desc[1],
md5sum=desc[0],
msgcount=conn.msgcount,
ext=conn.ext,
owner=conn.owner,
)
def close(self) -> None:
"""Close rosbag2."""
assert self.storage
self.storage.close()
self.storage = None
if self.tmpdir:
self.tmpdir.cleanup()
self.tmpdir = None
def messages(
self,
connections: Iterable[Connection] = (),
start: Optional[int] = None,
@ -201,67 +242,18 @@ class Reader:
tuples of connection, timestamp (ns), and rawdata.
Raises:
ReaderError: Bag not open.
ReaderError: If reader was not opened.
"""
if not self.bio:
if not self.storage:
raise ReaderError('Rosbag is not open.')
query = [
'SELECT topics.id,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: list[Any] = []
clause = 'WHERE'
if connections:
topics = {x.topic for x in connections}
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
args += topics
clause = 'AND'
if start is not None:
query.append(f'{clause} messages.timestamp >= ?')
args.append(start)
clause = 'AND'
if stop is not None:
query.append(f'{clause} messages.timestamp < ?')
args.append(stop)
clause = 'AND'
query.append('ORDER BY timestamp')
querystr = ' '.join(query)
for filepath in self.paths:
with decompress(filepath, self.compression_mode == 'file') as path:
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
conn.row_factory = lambda _, x: x
cur = conn.cursor()
cur.execute(
'SELECT count(*) FROM sqlite_master '
'WHERE type="table" AND name IN ("messages", "topics")',
)
if cur.fetchone()[0] != 2:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
cur.execute('SELECT name,id FROM topics')
connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections if x.topic == row[0]),
None) # type: ignore
for row in cur
}
cur.execute(querystr, args)
if self.compression_mode == 'message':
decomp = zstandard.ZstdDecompressor().decompress
for row in cur:
cid, timestamp, data = row
yield connmap[cid], timestamp, decomp(data)
else:
for cid, timestamp, data in cur:
yield connmap[cid], timestamp, data
if self.compression_mode == 'message':
decomp = zstandard.ZstdDecompressor().decompress
for connection, timestamp, data in self.storage.messages(connections, start, stop):
yield connection, timestamp, decomp(data)
else:
yield from self.storage.messages(connections, start, stop)
def __enter__(self) -> Reader:
"""Open rosbag2 when entering contextmanager."""

View File

@ -0,0 +1,119 @@
# Copyright 2020-2023 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Sqlite3 storage."""
from __future__ import annotations
import sqlite3
from typing import TYPE_CHECKING
from .errors import ReaderError
if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Generator, Iterable, Optional
from rosbags.interfaces import Connection
class ReaderSqlite3:
"""Sqlite3 storage reader."""
def __init__(
self,
paths: Iterable[Path],
connections: Iterable[Connection],
):
"""Set up storage reader.
Args:
paths: Paths of storage files.
connections: List of connections.
"""
self.opened = False
self.paths = paths
self.connections = connections
def open(self) -> None:
"""Open rosbag2."""
self.opened = True
def close(self) -> None:
"""Close rosbag2."""
assert self.opened
self.opened = False
def get_definitions(self) -> dict[str, tuple[str, str]]:
"""Get message definitions."""
return {}
def messages( # pylint: disable=too-many-locals
self,
connections: Iterable[Connection] = (),
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[tuple[Connection, int, bytes], None, None]:
"""Read messages from bag.
Args:
connections: Iterable with connections to filter for. An empty
iterable disables filtering on connections.
start: Yield only messages at or after this timestamp (ns).
stop: Yield only messages before this timestamp (ns).
Yields:
tuples of connection, timestamp (ns), and rawdata.
Raises:
ReaderError: Bag not open.
"""
query = [
'SELECT topics.id,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: list[Any] = []
clause = 'WHERE'
if connections:
topics = {x.topic for x in connections}
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
args += topics
clause = 'AND'
if start is not None:
query.append(f'{clause} messages.timestamp >= ?')
args.append(start)
clause = 'AND'
if stop is not None:
query.append(f'{clause} messages.timestamp < ?')
args.append(stop)
clause = 'AND'
query.append('ORDER BY timestamp')
querystr = ' '.join(query)
for path in self.paths:
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
conn.row_factory = lambda _, x: x
cur = conn.cursor()
cur.execute(
'SELECT count(*) FROM sqlite_master '
'WHERE type="table" AND name IN ("messages", "topics")',
)
if cur.fetchone()[0] != 2:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
cur.execute('SELECT name,id FROM topics')
connmap: dict[int, Connection] = {
row[1]: next((x for x in self.connections if x.topic == row[0]),
None) # type: ignore
for row in cur
}
cur.execute(querystr, args)
for cid, timestamp, data in cur:
yield connmap[cid], timestamp, data