Refactor rosbag2 Reader
This commit is contained in:
parent
d7d24c4478
commit
eaa64002b8
@ -7,7 +7,8 @@ in the rosbag2 format.
|
||||
|
||||
"""
|
||||
|
||||
from .reader import Reader, ReaderError
|
||||
from .errors import ReaderError
|
||||
from .reader import Reader
|
||||
from .writer import Writer, WriterError
|
||||
|
||||
__all__ = [
|
||||
|
||||
9
src/rosbags/rosbag2/errors.py
Normal file
9
src/rosbags/rosbag2/errors.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag2 errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
"""Reader Error."""
|
||||
@ -4,11 +4,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import zstandard
|
||||
from ruamel.yaml import YAML
|
||||
@ -16,40 +14,43 @@ from ruamel.yaml.error import YAMLError
|
||||
|
||||
from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo
|
||||
|
||||
from .errors import ReaderError
|
||||
from .storage_sqlite3 import ReaderSqlite3
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
|
||||
from typing import Generator, Iterable, Literal, Optional, Type, Union
|
||||
|
||||
from .metadata import FileInformation, Metadata
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
"""Reader Error."""
|
||||
class StorageProtocol(Protocol):
|
||||
"""Storage Protocol."""
|
||||
|
||||
def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]):
|
||||
"""Initialize."""
|
||||
raise NotImplementedError
|
||||
|
||||
@contextmanager
|
||||
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
|
||||
"""Transparent rosbag2 database decompression context.
|
||||
def open(self) -> None:
|
||||
"""Open file."""
|
||||
raise NotImplementedError
|
||||
|
||||
This context manager will yield a path to the decompressed file contents.
|
||||
def close(self) -> None:
|
||||
"""Close file."""
|
||||
raise NotImplementedError
|
||||
|
||||
Args:
|
||||
path: Potentially compressed file.
|
||||
do_decompress: Flag indicating if decompression shall occur.
|
||||
def get_definitions(self) -> dict[str, tuple[str, str]]:
|
||||
"""Get message definitions."""
|
||||
raise NotImplementedError
|
||||
|
||||
Yields:
|
||||
Path of transparently decompressed file.
|
||||
|
||||
"""
|
||||
if do_decompress:
|
||||
decomp = zstandard.ZstdDecompressor()
|
||||
with TemporaryDirectory() as tempdir:
|
||||
dbfile = Path(tempdir, path.stem)
|
||||
with path.open('rb') as infile, dbfile.open('wb') as outfile:
|
||||
decomp.copy_stream(infile, outfile)
|
||||
yield dbfile
|
||||
else:
|
||||
yield path
|
||||
def messages(
|
||||
self,
|
||||
connections: Iterable[Connection] = (),
|
||||
start: Optional[int] = None,
|
||||
stop: Optional[int] = None,
|
||||
) -> Generator[tuple[Connection, int, bytes], None, None]:
|
||||
"""Get messages from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Reader:
|
||||
@ -69,6 +70,12 @@ class Reader:
|
||||
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = {
|
||||
'sqlite3': ReaderSqlite3,
|
||||
}
|
||||
|
||||
def __init__(self, path: Union[Path, str]):
|
||||
"""Open rosbag and check metadata.
|
||||
|
||||
@ -82,7 +89,6 @@ class Reader:
|
||||
path = Path(path)
|
||||
yamlpath = path / 'metadata.yaml'
|
||||
self.path = path
|
||||
self.bio = False
|
||||
try:
|
||||
yaml = YAML(typ='safe')
|
||||
dct = yaml.load(yamlpath.read_text())
|
||||
@ -95,7 +101,7 @@ class Reader:
|
||||
self.metadata: Metadata = dct['rosbag2_bagfile_information']
|
||||
if (ver := self.metadata['version']) > 6:
|
||||
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
|
||||
if storageid := self.metadata['storage_identifier'] != 'sqlite3':
|
||||
if (storageid := self.metadata['storage_identifier']) not in self.STORAGE_PLUGINS:
|
||||
raise ReaderError(
|
||||
f'Storage plugin {storageid!r} not supported; please report issue.',
|
||||
)
|
||||
@ -131,20 +137,12 @@ class Reader:
|
||||
|
||||
self.files: list[FileInformation] = self.metadata.get('files', [])[:]
|
||||
self.custom_data: dict[str, str] = self.metadata.get('custom_data', {})
|
||||
|
||||
self.tmpdir: Optional[TemporaryDirectory] = None
|
||||
self.storage: Optional[StorageProtocol] = None
|
||||
except KeyError as exc:
|
||||
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
||||
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2."""
|
||||
# Future storage formats will require file handles.
|
||||
self.bio = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2."""
|
||||
# Future storage formats will require file handles.
|
||||
assert self.bio
|
||||
self.bio = False
|
||||
|
||||
@property
|
||||
def duration(self) -> int:
|
||||
"""Duration in nanoseconds between earliest and latest messages."""
|
||||
@ -183,7 +181,50 @@ class Reader:
|
||||
"""Topic information."""
|
||||
return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections}
|
||||
|
||||
def messages( # pylint: disable=too-many-locals
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2."""
|
||||
storage_paths = []
|
||||
if self.compression_mode == 'file':
|
||||
self.tmpdir = TemporaryDirectory() # pylint: disable=consider-using-with
|
||||
tmpdir = self.tmpdir.name
|
||||
decomp = zstandard.ZstdDecompressor()
|
||||
for path in self.paths:
|
||||
storage_file = Path(tmpdir, path.stem)
|
||||
with path.open('rb') as infile, storage_file.open('wb') as outfile:
|
||||
decomp.copy_stream(infile, outfile)
|
||||
storage_paths.append(storage_file)
|
||||
else:
|
||||
storage_paths = self.paths[:]
|
||||
|
||||
self.storage = self.STORAGE_PLUGINS[self.metadata['storage_identifier']](
|
||||
storage_paths,
|
||||
self.connections,
|
||||
)
|
||||
self.storage.open()
|
||||
definitions = self.storage.get_definitions()
|
||||
for idx, conn in enumerate(self.connections):
|
||||
if desc := definitions.get(conn.msgtype):
|
||||
self.connections[idx] = Connection(
|
||||
id=conn.id,
|
||||
topic=conn.topic,
|
||||
msgtype=conn.msgtype,
|
||||
msgdef=desc[1],
|
||||
md5sum=desc[0],
|
||||
msgcount=conn.msgcount,
|
||||
ext=conn.ext,
|
||||
owner=conn.owner,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2."""
|
||||
assert self.storage
|
||||
self.storage.close()
|
||||
self.storage = None
|
||||
if self.tmpdir:
|
||||
self.tmpdir.cleanup()
|
||||
self.tmpdir = None
|
||||
|
||||
def messages(
|
||||
self,
|
||||
connections: Iterable[Connection] = (),
|
||||
start: Optional[int] = None,
|
||||
@ -201,67 +242,18 @@ class Reader:
|
||||
tuples of connection, timestamp (ns), and rawdata.
|
||||
|
||||
Raises:
|
||||
ReaderError: Bag not open.
|
||||
ReaderError: If reader was not opened.
|
||||
|
||||
"""
|
||||
if not self.bio:
|
||||
if not self.storage:
|
||||
raise ReaderError('Rosbag is not open.')
|
||||
|
||||
query = [
|
||||
'SELECT topics.id,messages.timestamp,messages.data',
|
||||
'FROM messages JOIN topics ON messages.topic_id=topics.id',
|
||||
]
|
||||
args: list[Any] = []
|
||||
clause = 'WHERE'
|
||||
|
||||
if connections:
|
||||
topics = {x.topic for x in connections}
|
||||
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
|
||||
args += topics
|
||||
clause = 'AND'
|
||||
|
||||
if start is not None:
|
||||
query.append(f'{clause} messages.timestamp >= ?')
|
||||
args.append(start)
|
||||
clause = 'AND'
|
||||
|
||||
if stop is not None:
|
||||
query.append(f'{clause} messages.timestamp < ?')
|
||||
args.append(stop)
|
||||
clause = 'AND'
|
||||
|
||||
query.append('ORDER BY timestamp')
|
||||
querystr = ' '.join(query)
|
||||
|
||||
for filepath in self.paths:
|
||||
with decompress(filepath, self.compression_mode == 'file') as path:
|
||||
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
|
||||
conn.row_factory = lambda _, x: x
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
'SELECT count(*) FROM sqlite_master '
|
||||
'WHERE type="table" AND name IN ("messages", "topics")',
|
||||
)
|
||||
if cur.fetchone()[0] != 2:
|
||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||
|
||||
cur.execute('SELECT name,id FROM topics')
|
||||
connmap: dict[int, Connection] = {
|
||||
row[1]: next((x for x in self.connections if x.topic == row[0]),
|
||||
None) # type: ignore
|
||||
for row in cur
|
||||
}
|
||||
|
||||
cur.execute(querystr, args)
|
||||
|
||||
if self.compression_mode == 'message':
|
||||
decomp = zstandard.ZstdDecompressor().decompress
|
||||
for row in cur:
|
||||
cid, timestamp, data = row
|
||||
yield connmap[cid], timestamp, decomp(data)
|
||||
else:
|
||||
for cid, timestamp, data in cur:
|
||||
yield connmap[cid], timestamp, data
|
||||
if self.compression_mode == 'message':
|
||||
decomp = zstandard.ZstdDecompressor().decompress
|
||||
for connection, timestamp, data in self.storage.messages(connections, start, stop):
|
||||
yield connection, timestamp, decomp(data)
|
||||
else:
|
||||
yield from self.storage.messages(connections, start, stop)
|
||||
|
||||
def __enter__(self) -> Reader:
|
||||
"""Open rosbag2 when entering contextmanager."""
|
||||
|
||||
119
src/rosbags/rosbag2/storage_sqlite3.py
Normal file
119
src/rosbags/rosbag2/storage_sqlite3.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2020-2023 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Sqlite3 storage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .errors import ReaderError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Iterable, Optional
|
||||
|
||||
from rosbags.interfaces import Connection
|
||||
|
||||
|
||||
class ReaderSqlite3:
|
||||
"""Sqlite3 storage reader."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Iterable[Path],
|
||||
connections: Iterable[Connection],
|
||||
):
|
||||
"""Set up storage reader.
|
||||
|
||||
Args:
|
||||
paths: Paths of storage files.
|
||||
connections: List of connections.
|
||||
|
||||
"""
|
||||
self.opened = False
|
||||
self.paths = paths
|
||||
self.connections = connections
|
||||
|
||||
def open(self) -> None:
|
||||
"""Open rosbag2."""
|
||||
self.opened = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close rosbag2."""
|
||||
assert self.opened
|
||||
self.opened = False
|
||||
|
||||
def get_definitions(self) -> dict[str, tuple[str, str]]:
|
||||
"""Get message definitions."""
|
||||
return {}
|
||||
|
||||
def messages( # pylint: disable=too-many-locals
|
||||
self,
|
||||
connections: Iterable[Connection] = (),
|
||||
start: Optional[int] = None,
|
||||
stop: Optional[int] = None,
|
||||
) -> Generator[tuple[Connection, int, bytes], None, None]:
|
||||
"""Read messages from bag.
|
||||
|
||||
Args:
|
||||
connections: Iterable with connections to filter for. An empty
|
||||
iterable disables filtering on connections.
|
||||
start: Yield only messages at or after this timestamp (ns).
|
||||
stop: Yield only messages before this timestamp (ns).
|
||||
|
||||
Yields:
|
||||
tuples of connection, timestamp (ns), and rawdata.
|
||||
|
||||
Raises:
|
||||
ReaderError: Bag not open.
|
||||
|
||||
"""
|
||||
query = [
|
||||
'SELECT topics.id,messages.timestamp,messages.data',
|
||||
'FROM messages JOIN topics ON messages.topic_id=topics.id',
|
||||
]
|
||||
args: list[Any] = []
|
||||
clause = 'WHERE'
|
||||
|
||||
if connections:
|
||||
topics = {x.topic for x in connections}
|
||||
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
|
||||
args += topics
|
||||
clause = 'AND'
|
||||
|
||||
if start is not None:
|
||||
query.append(f'{clause} messages.timestamp >= ?')
|
||||
args.append(start)
|
||||
clause = 'AND'
|
||||
|
||||
if stop is not None:
|
||||
query.append(f'{clause} messages.timestamp < ?')
|
||||
args.append(stop)
|
||||
clause = 'AND'
|
||||
|
||||
query.append('ORDER BY timestamp')
|
||||
querystr = ' '.join(query)
|
||||
|
||||
for path in self.paths:
|
||||
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
|
||||
conn.row_factory = lambda _, x: x
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
'SELECT count(*) FROM sqlite_master '
|
||||
'WHERE type="table" AND name IN ("messages", "topics")',
|
||||
)
|
||||
if cur.fetchone()[0] != 2:
|
||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||
|
||||
cur.execute('SELECT name,id FROM topics')
|
||||
connmap: dict[int, Connection] = {
|
||||
row[1]: next((x for x in self.connections if x.topic == row[0]),
|
||||
None) # type: ignore
|
||||
for row in cur
|
||||
}
|
||||
|
||||
cur.execute(querystr, args)
|
||||
|
||||
for cid, timestamp, data in cur:
|
||||
yield connmap[cid], timestamp, data
|
||||
Loading…
x
Reference in New Issue
Block a user