Refactor rosbag2 Reader
This commit is contained in:
parent
d7d24c4478
commit
eaa64002b8
@ -7,7 +7,8 @@ in the rosbag2 format.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .reader import Reader, ReaderError
|
from .errors import ReaderError
|
||||||
|
from .reader import Reader
|
||||||
from .writer import Writer, WriterError
|
from .writer import Writer, WriterError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
9
src/rosbags/rosbag2/errors.py
Normal file
9
src/rosbags/rosbag2/errors.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# Copyright 2020-2023 Ternaris.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Rosbag2 errors."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class ReaderError(Exception):
|
||||||
|
"""Reader Error."""
|
||||||
@ -4,11 +4,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
import zstandard
|
import zstandard
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
@ -16,40 +14,43 @@ from ruamel.yaml.error import YAMLError
|
|||||||
|
|
||||||
from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo
|
from rosbags.interfaces import Connection, ConnectionExtRosbag2, TopicInfo
|
||||||
|
|
||||||
|
from .errors import ReaderError
|
||||||
|
from .storage_sqlite3 import ReaderSqlite3
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Generator, Iterable, Literal, Optional, Type, Union
|
from typing import Generator, Iterable, Literal, Optional, Type, Union
|
||||||
|
|
||||||
from .metadata import FileInformation, Metadata
|
from .metadata import FileInformation, Metadata
|
||||||
|
|
||||||
|
|
||||||
class ReaderError(Exception):
|
class StorageProtocol(Protocol):
|
||||||
"""Reader Error."""
|
"""Storage Protocol."""
|
||||||
|
|
||||||
|
def __init__(self, paths: Iterable[Path], connections: Iterable[Connection]):
|
||||||
|
"""Initialize."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@contextmanager
|
def open(self) -> None:
|
||||||
def decompress(path: Path, do_decompress: bool) -> Generator[Path, None, None]:
|
"""Open file."""
|
||||||
"""Transparent rosbag2 database decompression context.
|
raise NotImplementedError
|
||||||
|
|
||||||
This context manager will yield a path to the decompressed file contents.
|
def close(self) -> None:
|
||||||
|
"""Close file."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
Args:
|
def get_definitions(self) -> dict[str, tuple[str, str]]:
|
||||||
path: Potentially compressed file.
|
"""Get message definitions."""
|
||||||
do_decompress: Flag indicating if decompression shall occur.
|
raise NotImplementedError
|
||||||
|
|
||||||
Yields:
|
def messages(
|
||||||
Path of transparently decompressed file.
|
self,
|
||||||
|
connections: Iterable[Connection] = (),
|
||||||
"""
|
start: Optional[int] = None,
|
||||||
if do_decompress:
|
stop: Optional[int] = None,
|
||||||
decomp = zstandard.ZstdDecompressor()
|
) -> Generator[tuple[Connection, int, bytes], None, None]:
|
||||||
with TemporaryDirectory() as tempdir:
|
"""Get messages from file."""
|
||||||
dbfile = Path(tempdir, path.stem)
|
raise NotImplementedError
|
||||||
with path.open('rb') as infile, dbfile.open('wb') as outfile:
|
|
||||||
decomp.copy_stream(infile, outfile)
|
|
||||||
yield dbfile
|
|
||||||
else:
|
|
||||||
yield path
|
|
||||||
|
|
||||||
|
|
||||||
class Reader:
|
class Reader:
|
||||||
@ -69,6 +70,12 @@ class Reader:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=too-many-instance-attributes
|
||||||
|
|
||||||
|
STORAGE_PLUGINS: dict[str, Type[StorageProtocol]] = {
|
||||||
|
'sqlite3': ReaderSqlite3,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, path: Union[Path, str]):
|
def __init__(self, path: Union[Path, str]):
|
||||||
"""Open rosbag and check metadata.
|
"""Open rosbag and check metadata.
|
||||||
|
|
||||||
@ -82,7 +89,6 @@ class Reader:
|
|||||||
path = Path(path)
|
path = Path(path)
|
||||||
yamlpath = path / 'metadata.yaml'
|
yamlpath = path / 'metadata.yaml'
|
||||||
self.path = path
|
self.path = path
|
||||||
self.bio = False
|
|
||||||
try:
|
try:
|
||||||
yaml = YAML(typ='safe')
|
yaml = YAML(typ='safe')
|
||||||
dct = yaml.load(yamlpath.read_text())
|
dct = yaml.load(yamlpath.read_text())
|
||||||
@ -95,7 +101,7 @@ class Reader:
|
|||||||
self.metadata: Metadata = dct['rosbag2_bagfile_information']
|
self.metadata: Metadata = dct['rosbag2_bagfile_information']
|
||||||
if (ver := self.metadata['version']) > 6:
|
if (ver := self.metadata['version']) > 6:
|
||||||
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
|
raise ReaderError(f'Rosbag2 version {ver} not supported; please report issue.')
|
||||||
if storageid := self.metadata['storage_identifier'] != 'sqlite3':
|
if (storageid := self.metadata['storage_identifier']) not in self.STORAGE_PLUGINS:
|
||||||
raise ReaderError(
|
raise ReaderError(
|
||||||
f'Storage plugin {storageid!r} not supported; please report issue.',
|
f'Storage plugin {storageid!r} not supported; please report issue.',
|
||||||
)
|
)
|
||||||
@ -131,20 +137,12 @@ class Reader:
|
|||||||
|
|
||||||
self.files: list[FileInformation] = self.metadata.get('files', [])[:]
|
self.files: list[FileInformation] = self.metadata.get('files', [])[:]
|
||||||
self.custom_data: dict[str, str] = self.metadata.get('custom_data', {})
|
self.custom_data: dict[str, str] = self.metadata.get('custom_data', {})
|
||||||
|
|
||||||
|
self.tmpdir: Optional[TemporaryDirectory] = None
|
||||||
|
self.storage: Optional[StorageProtocol] = None
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
raise ReaderError(f'A metadata key is missing {exc!r}.') from None
|
||||||
|
|
||||||
def open(self) -> None:
|
|
||||||
"""Open rosbag2."""
|
|
||||||
# Future storage formats will require file handles.
|
|
||||||
self.bio = True
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""Close rosbag2."""
|
|
||||||
# Future storage formats will require file handles.
|
|
||||||
assert self.bio
|
|
||||||
self.bio = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def duration(self) -> int:
|
def duration(self) -> int:
|
||||||
"""Duration in nanoseconds between earliest and latest messages."""
|
"""Duration in nanoseconds between earliest and latest messages."""
|
||||||
@ -183,7 +181,50 @@ class Reader:
|
|||||||
"""Topic information."""
|
"""Topic information."""
|
||||||
return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections}
|
return {x.topic: TopicInfo(x.msgtype, x.msgdef, x.msgcount, [x]) for x in self.connections}
|
||||||
|
|
||||||
def messages( # pylint: disable=too-many-locals
|
def open(self) -> None:
|
||||||
|
"""Open rosbag2."""
|
||||||
|
storage_paths = []
|
||||||
|
if self.compression_mode == 'file':
|
||||||
|
self.tmpdir = TemporaryDirectory() # pylint: disable=consider-using-with
|
||||||
|
tmpdir = self.tmpdir.name
|
||||||
|
decomp = zstandard.ZstdDecompressor()
|
||||||
|
for path in self.paths:
|
||||||
|
storage_file = Path(tmpdir, path.stem)
|
||||||
|
with path.open('rb') as infile, storage_file.open('wb') as outfile:
|
||||||
|
decomp.copy_stream(infile, outfile)
|
||||||
|
storage_paths.append(storage_file)
|
||||||
|
else:
|
||||||
|
storage_paths = self.paths[:]
|
||||||
|
|
||||||
|
self.storage = self.STORAGE_PLUGINS[self.metadata['storage_identifier']](
|
||||||
|
storage_paths,
|
||||||
|
self.connections,
|
||||||
|
)
|
||||||
|
self.storage.open()
|
||||||
|
definitions = self.storage.get_definitions()
|
||||||
|
for idx, conn in enumerate(self.connections):
|
||||||
|
if desc := definitions.get(conn.msgtype):
|
||||||
|
self.connections[idx] = Connection(
|
||||||
|
id=conn.id,
|
||||||
|
topic=conn.topic,
|
||||||
|
msgtype=conn.msgtype,
|
||||||
|
msgdef=desc[1],
|
||||||
|
md5sum=desc[0],
|
||||||
|
msgcount=conn.msgcount,
|
||||||
|
ext=conn.ext,
|
||||||
|
owner=conn.owner,
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close rosbag2."""
|
||||||
|
assert self.storage
|
||||||
|
self.storage.close()
|
||||||
|
self.storage = None
|
||||||
|
if self.tmpdir:
|
||||||
|
self.tmpdir.cleanup()
|
||||||
|
self.tmpdir = None
|
||||||
|
|
||||||
|
def messages(
|
||||||
self,
|
self,
|
||||||
connections: Iterable[Connection] = (),
|
connections: Iterable[Connection] = (),
|
||||||
start: Optional[int] = None,
|
start: Optional[int] = None,
|
||||||
@ -201,67 +242,18 @@ class Reader:
|
|||||||
tuples of connection, timestamp (ns), and rawdata.
|
tuples of connection, timestamp (ns), and rawdata.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ReaderError: Bag not open.
|
ReaderError: If reader was not opened.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not self.bio:
|
if not self.storage:
|
||||||
raise ReaderError('Rosbag is not open.')
|
raise ReaderError('Rosbag is not open.')
|
||||||
|
|
||||||
query = [
|
|
||||||
'SELECT topics.id,messages.timestamp,messages.data',
|
|
||||||
'FROM messages JOIN topics ON messages.topic_id=topics.id',
|
|
||||||
]
|
|
||||||
args: list[Any] = []
|
|
||||||
clause = 'WHERE'
|
|
||||||
|
|
||||||
if connections:
|
|
||||||
topics = {x.topic for x in connections}
|
|
||||||
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
|
|
||||||
args += topics
|
|
||||||
clause = 'AND'
|
|
||||||
|
|
||||||
if start is not None:
|
|
||||||
query.append(f'{clause} messages.timestamp >= ?')
|
|
||||||
args.append(start)
|
|
||||||
clause = 'AND'
|
|
||||||
|
|
||||||
if stop is not None:
|
|
||||||
query.append(f'{clause} messages.timestamp < ?')
|
|
||||||
args.append(stop)
|
|
||||||
clause = 'AND'
|
|
||||||
|
|
||||||
query.append('ORDER BY timestamp')
|
|
||||||
querystr = ' '.join(query)
|
|
||||||
|
|
||||||
for filepath in self.paths:
|
|
||||||
with decompress(filepath, self.compression_mode == 'file') as path:
|
|
||||||
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
|
|
||||||
conn.row_factory = lambda _, x: x
|
|
||||||
cur = conn.cursor()
|
|
||||||
cur.execute(
|
|
||||||
'SELECT count(*) FROM sqlite_master '
|
|
||||||
'WHERE type="table" AND name IN ("messages", "topics")',
|
|
||||||
)
|
|
||||||
if cur.fetchone()[0] != 2:
|
|
||||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
|
||||||
|
|
||||||
cur.execute('SELECT name,id FROM topics')
|
|
||||||
connmap: dict[int, Connection] = {
|
|
||||||
row[1]: next((x for x in self.connections if x.topic == row[0]),
|
|
||||||
None) # type: ignore
|
|
||||||
for row in cur
|
|
||||||
}
|
|
||||||
|
|
||||||
cur.execute(querystr, args)
|
|
||||||
|
|
||||||
if self.compression_mode == 'message':
|
if self.compression_mode == 'message':
|
||||||
decomp = zstandard.ZstdDecompressor().decompress
|
decomp = zstandard.ZstdDecompressor().decompress
|
||||||
for row in cur:
|
for connection, timestamp, data in self.storage.messages(connections, start, stop):
|
||||||
cid, timestamp, data = row
|
yield connection, timestamp, decomp(data)
|
||||||
yield connmap[cid], timestamp, decomp(data)
|
|
||||||
else:
|
else:
|
||||||
for cid, timestamp, data in cur:
|
yield from self.storage.messages(connections, start, stop)
|
||||||
yield connmap[cid], timestamp, data
|
|
||||||
|
|
||||||
def __enter__(self) -> Reader:
|
def __enter__(self) -> Reader:
|
||||||
"""Open rosbag2 when entering contextmanager."""
|
"""Open rosbag2 when entering contextmanager."""
|
||||||
|
|||||||
119
src/rosbags/rosbag2/storage_sqlite3.py
Normal file
119
src/rosbags/rosbag2/storage_sqlite3.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# Copyright 2020-2023 Ternaris.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Sqlite3 storage."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from .errors import ReaderError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Generator, Iterable, Optional
|
||||||
|
|
||||||
|
from rosbags.interfaces import Connection
|
||||||
|
|
||||||
|
|
||||||
|
class ReaderSqlite3:
|
||||||
|
"""Sqlite3 storage reader."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
paths: Iterable[Path],
|
||||||
|
connections: Iterable[Connection],
|
||||||
|
):
|
||||||
|
"""Set up storage reader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: Paths of storage files.
|
||||||
|
connections: List of connections.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.opened = False
|
||||||
|
self.paths = paths
|
||||||
|
self.connections = connections
|
||||||
|
|
||||||
|
def open(self) -> None:
|
||||||
|
"""Open rosbag2."""
|
||||||
|
self.opened = True
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close rosbag2."""
|
||||||
|
assert self.opened
|
||||||
|
self.opened = False
|
||||||
|
|
||||||
|
def get_definitions(self) -> dict[str, tuple[str, str]]:
|
||||||
|
"""Get message definitions."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def messages( # pylint: disable=too-many-locals
|
||||||
|
self,
|
||||||
|
connections: Iterable[Connection] = (),
|
||||||
|
start: Optional[int] = None,
|
||||||
|
stop: Optional[int] = None,
|
||||||
|
) -> Generator[tuple[Connection, int, bytes], None, None]:
|
||||||
|
"""Read messages from bag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connections: Iterable with connections to filter for. An empty
|
||||||
|
iterable disables filtering on connections.
|
||||||
|
start: Yield only messages at or after this timestamp (ns).
|
||||||
|
stop: Yield only messages before this timestamp (ns).
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
tuples of connection, timestamp (ns), and rawdata.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ReaderError: Bag not open.
|
||||||
|
|
||||||
|
"""
|
||||||
|
query = [
|
||||||
|
'SELECT topics.id,messages.timestamp,messages.data',
|
||||||
|
'FROM messages JOIN topics ON messages.topic_id=topics.id',
|
||||||
|
]
|
||||||
|
args: list[Any] = []
|
||||||
|
clause = 'WHERE'
|
||||||
|
|
||||||
|
if connections:
|
||||||
|
topics = {x.topic for x in connections}
|
||||||
|
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
|
||||||
|
args += topics
|
||||||
|
clause = 'AND'
|
||||||
|
|
||||||
|
if start is not None:
|
||||||
|
query.append(f'{clause} messages.timestamp >= ?')
|
||||||
|
args.append(start)
|
||||||
|
clause = 'AND'
|
||||||
|
|
||||||
|
if stop is not None:
|
||||||
|
query.append(f'{clause} messages.timestamp < ?')
|
||||||
|
args.append(stop)
|
||||||
|
clause = 'AND'
|
||||||
|
|
||||||
|
query.append('ORDER BY timestamp')
|
||||||
|
querystr = ' '.join(query)
|
||||||
|
|
||||||
|
for path in self.paths:
|
||||||
|
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
|
||||||
|
conn.row_factory = lambda _, x: x
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
'SELECT count(*) FROM sqlite_master '
|
||||||
|
'WHERE type="table" AND name IN ("messages", "topics")',
|
||||||
|
)
|
||||||
|
if cur.fetchone()[0] != 2:
|
||||||
|
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||||
|
|
||||||
|
cur.execute('SELECT name,id FROM topics')
|
||||||
|
connmap: dict[int, Connection] = {
|
||||||
|
row[1]: next((x for x in self.connections if x.topic == row[0]),
|
||||||
|
None) # type: ignore
|
||||||
|
for row in cur
|
||||||
|
}
|
||||||
|
|
||||||
|
cur.execute(querystr, args)
|
||||||
|
|
||||||
|
for cid, timestamp, data in cur:
|
||||||
|
yield connmap[cid], timestamp, data
|
||||||
Loading…
x
Reference in New Issue
Block a user