Change to connection-oriented reader API

This commit is contained in:
Marko Durkovic
2021-08-01 18:22:36 +02:00
committed by Florian Friesdorf
parent ebf357a0c6
commit f33e65b14a
13 changed files with 290 additions and 172 deletions
+35 -13
View File
@@ -4,10 +4,12 @@
from __future__ import annotations
from dataclasses import asdict
from typing import TYPE_CHECKING
from rosbags.rosbag1 import Reader, ReaderError
from rosbags.rosbag2 import Writer, WriterError
from rosbags.rosbag2.connection import Connection as WConnection
from rosbags.serde import ros1_to_cdr
from rosbags.typesys import get_types_from_msg, register_types
@@ -15,6 +17,8 @@ if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Dict, Optional
from rosbags.rosbag1.reader import Connection as RConnection
LATCH = """
- history: 3
depth: 0
@@ -38,6 +42,26 @@ class ConverterError(Exception):
"""Converter Error."""
def convert_connection(rconn: RConnection) -> WConnection:
    """Build the rosbag2 counterpart of a rosbag1 connection.

    Topic and message type are copied from the source connection. The
    serialization format is always ``'cdr'``. The offered QoS profiles are
    set to ``LATCH`` when the source connection is latching and to an empty
    string otherwise. ``id`` and ``count`` are filled with ``-1`` and ``0``.

    Args:
        rconn: Rosbag1 connection.

    Returns:
        Rosbag2 connection.

    """
    qos_profiles = LATCH if rconn.latching else ''
    return WConnection(
        id=-1,
        count=0,
        topic=rconn.topic,
        msgtype=rconn.msgtype,
        serialization_format='cdr',
        offered_qos_profiles=qos_profiles,
    )
def convert(src: Path, dst: Optional[Path]) -> None:
"""Convert Rosbag1 to Rosbag2.
@@ -57,21 +81,19 @@ def convert(src: Path, dst: Optional[Path]) -> None:
try:
with Reader(src) as reader, Writer(dst) as writer:
typs: Dict[str, Any] = {}
for name, topic in reader.topics.items():
connection = next( # pragma: no branch
x for x in reader.connections.values() if x.topic == name
)
writer.add_topic(
name,
topic.msgtype,
offered_qos_profiles=LATCH if connection.latching else '',
)
typs.update(get_types_from_msg(topic.msgdef, topic.msgtype))
connmap: Dict[int, WConnection] = {}
for rconn in reader.connections.values():
candidate = convert_connection(rconn)
existing = next((x for x in writer.connections.values() if x == candidate), None)
wconn = existing if existing else writer.add_connection(**asdict(candidate))
connmap[rconn.cid] = wconn
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
register_types(typs)
for topic, msgtype, timestamp, data in reader.messages():
data = ros1_to_cdr(data, msgtype)
writer.write(topic, timestamp, data)
for rconn, timestamp, data in reader.messages():
data = ros1_to_cdr(data, rconn.msgtype)
writer.write(connmap[rconn.cid], timestamp, data)
except ReaderError as err:
raise ConverterError(f'Reading source bag: {err}') from err
except WriterError as err:
+6 -9
View File
@@ -249,7 +249,7 @@ class Header(dict):
raise ReaderError(f'Could not read time field {name!r}.') from err
@classmethod
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> 'Header':
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
"""Read header from file handle.
Args:
@@ -588,7 +588,7 @@ class Reader:
topics: Optional[Iterable[str]] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[Tuple[str, str, int, bytes], None, None]:
) -> Generator[Tuple[Connection, int, bytes], None, None]:
"""Read messages from bag.
Args:
@@ -598,7 +598,7 @@ class Reader:
stop: Yield only messages before this timestamp (ns).
Yields:
Tuples of topic name, type, timestamp (ns), and rawdata.
Tuples of connection, timestamp (ns), and rawdata.
Raises:
ReaderError: Bag not open or data corrupt.
@@ -635,13 +635,10 @@ class Reader:
if have != RecordType.MSGDATA:
raise ReaderError('Expected to find message data.')
connection = self.connections[header.get_uint32('conn')]
time = header.get_time('time')
data = read_bytes(chunk, read_uint32(chunk))
assert entry.time == time
yield connection.topic, connection.msgtype, time, data
connection = self.connections[header.get_uint32('conn')]
assert entry.time == header.get_time('time')
yield connection, entry.time, data
def __enter__(self) -> Reader:
"""Open rosbag1 when entering contextmanager."""
+18
View File
@@ -0,0 +1,18 @@
# Copyright 2020-2021 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 connection."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class Connection:
    """Connection metadata.

    Two connections compare equal when ``topic``, ``msgtype``,
    ``serialization_format`` and ``offered_qos_profiles`` all match.
    ``id`` and ``count`` are excluded from comparison.
    """

    # Fields excluded from equality comparison.
    id: int = field(compare=False)  # pylint: disable=invalid-name
    count: int = field(compare=False)

    # Fields that take part in equality comparison.
    topic: str
    msgtype: str
    serialization_format: str
    offered_qos_profiles: str
+61 -33
View File
@@ -13,9 +13,11 @@ from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML, YAMLError
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
class ReaderError(Exception):
@@ -96,11 +98,21 @@ class Reader:
if missing:
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
topics = [x['topic_metadata'] for x in self.metadata['topics_with_message_count']]
noncdr = {y for x in topics if (y := x['serialization_format']) != 'cdr'}
self.connections = {
idx + 1: Connection(
id=idx + 1,
count=x['message_count'],
topic=x['topic_metadata']['name'],
msgtype=x['topic_metadata']['type'],
serialization_format=x['topic_metadata']['serialization_format'],
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
}
noncdr = {
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr'
}
if noncdr:
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
self.topics = {x['name']: x['type'] for x in topics}
if self.compression_mode and (cfmt := self.compression_format) != 'zstd':
raise ReaderError(f'Compression format {cfmt!r} is not supported.')
@@ -149,22 +161,31 @@ class Reader:
mode = self.metadata.get('compression_mode', '').lower()
return mode if mode != 'none' else None
@property
def topics(self) -> Dict[str, Connection]:
    """Topic information.

    For the moment this is a dictionary mapping topic names to connections.
    It is rebuilt from ``self.connections`` on every access. If several
    connections share a topic name, the one that comes last in
    ``self.connections`` is the one kept.
    """
    by_topic = {}
    for connection in self.connections.values():
        by_topic[connection.topic] = connection
    return by_topic
def messages( # pylint: disable=too-many-locals
self,
topics: Iterable[str] = (),
connections: Iterable[Connection] = (),
start: Optional[int] = None,
stop: Optional[int] = None,
) -> Generator[Tuple[str, str, int, bytes], None, None]:
) -> Generator[Tuple[Connection, int, bytes], None, None]:
"""Read messages from bag.
Args:
topics: Iterable with topic names to filter for. An empty iterable
yields all messages.
connections: Iterable with connections to filter for. An empty
iterable disables filtering on connections.
start: Yield only messages at or after this timestamp (ns).
stop: Yield only messages before this timestamp (ns).
Yields:
Tuples of topic name, type, timestamp (ns), and rawdata.
Tuples of connection, timestamp (ns), and rawdata.
Raises:
ReaderError: Bag not open.
@@ -173,7 +194,32 @@ class Reader:
if not self.bio:
raise ReaderError('Rosbag is not open.')
topics = tuple(topics)
query = [
'SELECT topics.id,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: List[Any] = []
clause = 'WHERE'
if connections:
topics = {x.topic for x in connections}
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
args += topics
clause = 'AND'
if start is not None:
query.append(f'{clause} messages.timestamp >= ?')
args.append(start)
clause = 'AND'
if stop is not None:
query.append(f'{clause} messages.timestamp < ?')
args.append(stop)
clause = 'AND'
query.append('ORDER BY timestamp')
querystr = ' '.join(query)
for filepath in self.paths:
with decompress(filepath, self.compression_mode == 'file') as path:
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
@@ -186,34 +232,16 @@ class Reader:
if cur.fetchone()[0] != 2:
raise ReaderError(f'Cannot open database {path} or database missing tables.')
query = [
'SELECT topics.name,topics.type,messages.timestamp,messages.data',
'FROM messages JOIN topics ON messages.topic_id=topics.id',
]
args: List[Any] = []
if topics:
query.append(f'WHERE topics.name IN ({",".join("?" for _ in topics)})')
args += topics
if start is not None:
query.append(f'{"AND" if args else "WHERE"} messages.timestamp >= ?')
args.append(start)
if stop is not None:
query.append(f'{"AND" if args else "WHERE"} messages.timestamp < ?')
args.append(stop)
query.append('ORDER BY timestamp')
cur.execute(' '.join(query), args)
cur.execute(querystr, args)
if self.compression_mode == 'message':
decomp = zstandard.ZstdDecompressor().decompress
for row in cur:
topic, msgtype, timestamp, data = row
yield topic, msgtype, timestamp, decomp(data)
cid, timestamp, data = row
yield self.connections[cid], timestamp, decomp(data)
else:
yield from cur
for cid, timestamp, data in cur:
yield self.connections[cid], timestamp, data
def __enter__(self) -> Reader:
"""Open rosbag2 when entering contextmanager."""
+44 -28
View File
@@ -1,6 +1,6 @@
# Copyright 2020-2021 Ternaris.
# SPDX-License-Identifier: Apache-2.0
"""Rosbag2 reader."""
"""Rosbag2 writer."""
from __future__ import annotations
@@ -12,6 +12,8 @@ from typing import TYPE_CHECKING
import zstandard
from ruamel.yaml import YAML
from .connection import Connection
if TYPE_CHECKING:
from types import TracebackType
from typing import Any, Dict, Literal, Optional, Type, Union
@@ -77,10 +79,9 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.compression_mode = ''
self.compression_format = ''
self.compressor: Optional[zstandard.ZstdCompressor] = None
self.topics: Dict[str, Any] = {}
self.connections: Dict[int, Connection] = {}
self.conn = None
self.cursor: Optional[sqlite3.Cursor] = None
self.topics = {}
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
"""Enable compression on bag.
@@ -118,22 +119,27 @@ class Writer: # pylint: disable=too-many-instance-attributes
self.conn.executescript(self.SQLITE_SCHEMA)
self.cursor = self.conn.cursor()
def add_topic(
def add_connection(
self,
name: str,
typ: str,
topic: str,
msgtype: str,
serialization_format: str = 'cdr',
offered_qos_profiles: str = '',
):
"""Add a topic.
**_kw: Any,
) -> Connection:
"""Add a connection.
This function can only be called after opening a bag.
Args:
name: Topic name.
typ: Message type.
topic: Topic name.
msgtype: Message type.
serialization_format: Serialization format.
offered_qos_profiles: QOS Profile.
_kw: Ignored to allow consuming dicts from connection objects.
Returns:
Connection object.
Raises:
WriterError: Bag not open or topic previously registered.
@@ -141,17 +147,28 @@ class Writer: # pylint: disable=too-many-instance-attributes
"""
if not self.cursor:
raise WriterError('Bag was not opened.')
if name in self.topics:
raise WriterError(f'Topics can only be added once: {name!r}.')
meta = (len(self.topics) + 1, name, typ, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
self.topics[name] = [*meta, 0]
def write(self, topic: str, timestamp: int, data: bytes):
connection = Connection(
id=len(self.connections.values()) + 1,
count=0,
topic=topic,
msgtype=msgtype,
serialization_format=serialization_format,
offered_qos_profiles=offered_qos_profiles,
)
if connection in self.connections.values():
raise WriterError(f'Connection can only be added once: {connection!r}.')
self.connections[connection.id] = connection
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
return connection
def write(self, connection: Connection, timestamp: int, data: bytes):
"""Write message to rosbag2.
Args:
topic: Topic message belongs to.
connection: Connection to write message to.
timestamp: Message timestamp (ns).
data: Serialized message data.
@@ -161,19 +178,18 @@ class Writer: # pylint: disable=too-many-instance-attributes
"""
if not self.cursor:
raise WriterError('Bag was not opened.')
if topic not in self.topics:
raise WriterError(f'Tried to write to unknown topic {topic!r}.')
if connection not in self.connections.values():
raise WriterError(f'Tried to write to unknown connection {connection!r}.')
if self.compression_mode == 'message':
assert self.compressor
data = self.compressor.compress(data)
tmeta = self.topics[topic]
self.cursor.execute(
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
(tmeta[0], timestamp, data),
(connection.id, timestamp, data),
)
tmeta[-1] += 1
connection.count += 1
def close(self):
"""Close rosbag2 after writing.
@@ -214,13 +230,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
'topics_with_message_count': [
{
'topic_metadata': {
'name': x[1],
'type': x[2],
'serialization_format': x[3],
'offered_qos_profiles': x[4],
'name': x.topic,
'type': x.msgtype,
'serialization_format': x.serialization_format,
'offered_qos_profiles': x.offered_qos_profiles,
},
'message_count': x[5],
} for x in self.topics.values()
'message_count': x.count,
} for x in self.connections.values()
],
'compression_format': self.compression_format,
'compression_mode': self.compression_mode,