diff --git a/src/rosbags/rosbag2/__init__.py b/src/rosbags/rosbag2/__init__.py index 8c8ae4c3..d487f0c3 100644 --- a/src/rosbags/rosbag2/__init__.py +++ b/src/rosbags/rosbag2/__init__.py @@ -7,7 +7,8 @@ in the rosbag2 format. """ -from .reader import Reader, ReaderError +from .errors import ReaderError +from .reader import Reader from .writer import Writer, WriterError __all__ = [ diff --git a/src/rosbags/rosbag2/errors.py b/src/rosbags/rosbag2/errors.py new file mode 100644 index 00000000..698f1f9d --- /dev/null +++ b/src/rosbags/rosbag2/errors.py @@ -0,0 +1,9 @@ +# Copyright 2020-2023 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Rosbag2 errors.""" + +from __future__ import annotations + + +class ReaderError(Exception): + """Reader Error.""" diff --git a/src/rosbags/rosbag2/reader.py b/src/rosbags/rosbag2/reader.py index f345e0eb..27de18f8 100644 --- a/src/rosbags/rosbag2/reader.py +++ b/src/rosbags/rosbag2/reader.py @@ -4,11 +4,9 @@ from __future__ import annotations -import sqlite3 -from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol import zstandard from ruamel.yaml import YAML @@ -16,40 +14,43 @@ from ruamel.yaml.error import YAMLError from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo +from .errors import ReaderError +from .storage_sqlite3 import ReaderSqlite3 + if TYPE_CHECKING: from types import TracebackType - from typing import Any, Generator, Iterable, Literal, Optional, Type, Union + from typing import Generator, Iterable, Literal, Optional, Type, Union from .metadata import FileInformation, Metadata -class ReaderError(Exception): - """Reader Error.""" +class StorageProtocol(Protocol): + """Storage Protocol.""" + def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]): + """Initialize.""" + raise NotImplementedError -@contextmanager -def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]: - """Transparent rosbag2 database decompression context. + def open(self) -> None: + """Open file.""" + raise NotImplementedError - This context manager will yield a path to the decompressed file contents. + def close(self) -> None: + """Close file.""" + raise NotImplementedError - Args: - path: Potentially compressed file. - do_decompress: Flag indicating if decompression shall occur. + def get_definitions(self) -> dict[str, tuple[str, str]]: + """Get message definitions.""" + raise NotImplementedError - Yields: - Path of transparently decompressed file. - - """ - if do_decompress: - decomp = zstandard.ZstdDecompressor() - with TemporaryDirectory() as tempdir: - dbfile = Path(tempdir, path.stem) - with path.open('rb') as infile, dbfile.open('wb') as outfile: - decomp.copy_stream(infile, outfile) - yield dbfile - else: - yield path + def messages( + self, + connections: Iterable[Connection] = (), + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> Generator[tuple[Connection, int, bytes], None, None]: + """Get messages from file.""" + raise NotImplementedError class Reader: @@ -69,6 +70,12 @@ class Reader: """ + # pylint: disable=too-many-instance-attributes + + STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = { + 'sqlite3': ReaderSqlite3, + } + def __init__(self, path: Union[Path, str]): """Open rosbag and check metadata. @@ -82,7 +89,6 @@ class Reader: path = Path(path) yamlpath = path / 'metadata.yaml' self.path = path - self.bio = False try: yaml = YAML(typ='safe') dct = yaml.load(yamlpath.read_text()) @@ -95,7 +101,7 @@ class Reader: self.metadata: Metadata = dct['rosbag2_bagfile_information'] if (ver := self.metadata['version']) > 6: raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.') - if storageid := self.metadata['storage_identifier'] != 'sqlite3': + if (storageid := self.metadata['storage_identifier']) not in self.STORAGE_PLUGINS: raise ReaderError( f'Storage plugin {storageid!r} not supported; please report issue.', ) @@ -131,20 +137,12 @@ class Reader: self.files: list[FileInformation] = self.metadata.get('files', [])[:] self.custom_data: dict[str, str] = self.metadata.get('custom_data', {}) + + self.tmpdir: Optional[TemporaryDirectory] = None + self.storage: Optional[StorageProtocol] = None except KeyError as exc: raise ReaderError(f'A metadata key is missing {exc!r}.') from None - def open(self) -> None: - """Open rosbag2.""" - # Future storage formats will require file handles. - self.bio = True - - def close(self) -> None: - """Close rosbag2.""" - # Future storage formats will require file handles. - assert self.bio - self.bio = False - @property def duration(self) -> int: """Duration in nanoseconds between earliest and latest messages.""" @@ -183,7 +181,50 @@ class Reader: """Topic information.""" return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections} - def messages( # pylint: disable=too-many-locals + def open(self) -> None: + """Open rosbag2.""" + storage_paths = [] + if self.compression_mode == 'file': + self.tmpdir = TemporaryDirectory() # pylint: disable=consider-using-with + tmpdir = self.tmpdir.name + decomp = zstandard.ZstdDecompressor() + for path in self.paths: + storage_file = Path(tmpdir, path.stem) + with path.open('rb') as infile, storage_file.open('wb') as outfile: + decomp.copy_stream(infile, outfile) + storage_paths.append(storage_file) + else: + storage_paths = self.paths[:] + + self.storage = self.STORAGE_PLUGINS[self.metadata['storage_identifier']]( + storage_paths, + self.connections, + ) + self.storage.open() + definitions = self.storage.get_definitions() + for idx, conn in enumerate(self.connections): + if desc := definitions.get(conn.msgtype): + self.connections[idx] = Connection( + id=conn.id, + topic=conn.topic, + msgtype=conn.msgtype, + msgdef=desc[1], + md5sum=desc[0], + msgcount=conn.msgcount, + ext=conn.ext, + owner=conn.owner, + ) + + def close(self) -> None: + """Close rosbag2.""" + assert self.storage + self.storage.close() + self.storage = None + if self.tmpdir: + self.tmpdir.cleanup() + self.tmpdir = None + + def messages( self, connections: Iterable[Connection] = (), start: Optional[int] = None, @@ -201,67 +242,18 @@ class Reader: tuples of connection, timestamp (ns), and rawdata. Raises: - ReaderError: Bag not open. + ReaderError: If reader was not opened. """ - if not self.bio: + if not self.storage: raise ReaderError('Rosbag is not open.') - query = [ - 'SELECT topics.id,messages.timestamp,messages.data', - 'FROM messages JOIN topics ON messages.topic_id=topics.id', - ] - args: list[Any] = [] - clause = 'WHERE' - - if connections: - topics = {x.topic for x in connections} - query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})') - args += topics - clause = 'AND' - - if start is not None: - query.append(f'{clause} messages.timestamp >= ?') - args.append(start) - clause = 'AND' - - if stop is not None: - query.append(f'{clause} messages.timestamp < ?') - args.append(stop) - clause = 'AND' - - query.append('ORDER BY timestamp') - querystr = ' '.join(query) - - for filepath in self.paths: - with decompress(filepath, self.compression_mode == 'file') as path: - conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True) - conn.row_factory = lambda _, x: x - cur = conn.cursor() - cur.execute( - 'SELECT count(*) FROM sqlite_master ' - 'WHERE type="table" AND name IN ("messages", "topics")', - ) - if cur.fetchone()[0] != 2: - raise ReaderError(f'Cannot open database {path} or database missing tables.') - - cur.execute('SELECT name,id FROM topics') - connmap: dict[int, Connection] = { - row[1]: next((x for x in self.connections if x.topic == row[0]), - None) # type: ignore - for row in cur - } - - cur.execute(querystr, args) - - if self.compression_mode == 'message': - decomp = zstandard.ZstdDecompressor().decompress - for row in cur: - cid, timestamp, data = row - yield connmap[cid], timestamp, decomp(data) - else: - for cid, timestamp, data in cur: - yield connmap[cid], timestamp, data + if self.compression_mode == 'message': + decomp = zstandard.ZstdDecompressor().decompress + for connection, timestamp, data in self.storage.messages(connections, start, stop): + yield connection, timestamp, decomp(data) + else: + yield from self.storage.messages(connections, start, stop) def __enter__(self) -> Reader: """Open rosbag2 when entering contextmanager.""" diff --git a/src/rosbags/rosbag2/storage_sqlite3.py b/src/rosbags/rosbag2/storage_sqlite3.py new file mode 100644 index 00000000..bddfc038 --- /dev/null +++ b/src/rosbags/rosbag2/storage_sqlite3.py @@ -0,0 +1,119 @@ +# Copyright 2020-2023 Ternaris. +# SPDX-License-Identifier: Apache-2.0 +"""Sqlite3 storage.""" + +from __future__ import annotations + +import sqlite3 +from typing import TYPE_CHECKING + +from .errors import ReaderError + +if TYPE_CHECKING: + from pathlib import Path + from typing import Any, Generator, Iterable, Optional + + from rosbags.interfaces import Connection + + +class ReaderSqlite3: + """Sqlite3 storage reader.""" + + def __init__( + self, + paths: Iterable[Path], + connections: Iterable[Connection], + ): + """Set up storage reader. + + Args: + paths: Paths of storage files. + connections: List of connections. + + """ + self.opened = False + self.paths = paths + self.connections = connections + + def open(self) -> None: + """Open rosbag2.""" + self.opened = True + + def close(self) -> None: + """Close rosbag2.""" + assert self.opened + self.opened = False + + def get_definitions(self) -> dict[str, tuple[str, str]]: + """Get message definitions.""" + return {} + + def messages( # pylint: disable=too-many-locals + self, + connections: Iterable[Connection] = (), + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> Generator[tuple[Connection, int, bytes], None, None]: + """Read messages from bag. + + Args: + connections: Iterable with connections to filter for. An empty + iterable disables filtering on connections. + start: Yield only messages at or after this timestamp (ns). + stop: Yield only messages before this timestamp (ns). + + Yields: + tuples of connection, timestamp (ns), and rawdata. + + Raises: + ReaderError: Bag not open. + + """ + query = [ + 'SELECT topics.id,messages.timestamp,messages.data', + 'FROM messages JOIN topics ON messages.topic_id=topics.id', + ] + args: list[Any] = [] + clause = 'WHERE' + + if connections: + topics = {x.topic for x in connections} + query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})') + args += topics + clause = 'AND' + + if start is not None: + query.append(f'{clause} messages.timestamp >= ?') + args.append(start) + clause = 'AND' + + if stop is not None: + query.append(f'{clause} messages.timestamp < ?') + args.append(stop) + clause = 'AND' + + query.append('ORDER BY timestamp') + querystr = ' '.join(query) + + for path in self.paths: + conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True) + conn.row_factory = lambda _, x: x + cur = conn.cursor() + cur.execute( + 'SELECT count(*) FROM sqlite_master ' + 'WHERE type="table" AND name IN ("messages", "topics")', + ) + if cur.fetchone()[0] != 2: + raise ReaderError(f'Cannot open database {path} or database missing tables.') + + cur.execute('SELECT name,id FROM topics') + connmap: dict[int, Connection] = { + row[1]: next((x for x in self.connections if x.topic == row[0]), + None) # type: ignore + for row in cur + } + + cur.execute(querystr, args) + + for cid, timestamp, data in cur: + yield connmap[cid], timestamp, data