Change to connection oriented reader API
This commit is contained in:
committed by
Florian Friesdorf
parent
ebf357a0c6
commit
f33e65b14a
@@ -4,10 +4,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from rosbags.rosbag1 import Reader, ReaderError
|
||||
from rosbags.rosbag2 import Writer, WriterError
|
||||
from rosbags.rosbag2.connection import Connection as WConnection
|
||||
from rosbags.serde import ros1_to_cdr
|
||||
from rosbags.typesys import get_types_from_msg, register_types
|
||||
|
||||
@@ -15,6 +17,8 @@ if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from rosbags.rosbag1.reader import Connection as RConnection
|
||||
|
||||
LATCH = """
|
||||
- history: 3
|
||||
depth: 0
|
||||
@@ -38,6 +42,26 @@ class ConverterError(Exception):
|
||||
"""Converter Error."""
|
||||
|
||||
|
||||
def convert_connection(rconn: RConnection) -> WConnection:
|
||||
"""Convert rosbag1 connection to rosbag2 connection.
|
||||
|
||||
Args:
|
||||
rconn: Rosbag1 connection.
|
||||
|
||||
Returns:
|
||||
Rosbag2 connection.
|
||||
|
||||
"""
|
||||
return WConnection(
|
||||
-1,
|
||||
0,
|
||||
rconn.topic,
|
||||
rconn.msgtype,
|
||||
'cdr',
|
||||
LATCH if rconn.latching else '',
|
||||
)
|
||||
|
||||
|
||||
def convert(src: Path, dst: Optional[Path]) -> None:
|
||||
"""Convert Rosbag1 to Rosbag2.
|
||||
|
||||
@@ -57,21 +81,19 @@ def convert(src: Path, dst: Optional[Path]) -> None:
|
||||
try:
|
||||
with Reader(src) as reader, Writer(dst) as writer:
|
||||
typs: Dict[str, Any] = {}
|
||||
for name, topic in reader.topics.items():
|
||||
connection = next( # pragma: no branch
|
||||
x for x in reader.connections.values() if x.topic == name
|
||||
)
|
||||
writer.add_topic(
|
||||
name,
|
||||
topic.msgtype,
|
||||
offered_qos_profiles=LATCH if connection.latching else '',
|
||||
)
|
||||
typs.update(get_types_from_msg(topic.msgdef, topic.msgtype))
|
||||
connmap: Dict[int, WConnection] = {}
|
||||
|
||||
for rconn in reader.connections.values():
|
||||
candidate = convert_connection(rconn)
|
||||
existing = next((x for x in writer.connections.values() if x == candidate), None)
|
||||
wconn = existing if existing else writer.add_connection(**asdict(candidate))
|
||||
connmap[rconn.cid] = wconn
|
||||
typs.update(get_types_from_msg(rconn.msgdef, rconn.msgtype))
|
||||
register_types(typs)
|
||||
|
||||
for topic, msgtype, timestamp, data in reader.messages():
|
||||
data = ros1_to_cdr(data, msgtype)
|
||||
writer.write(topic, timestamp, data)
|
||||
for rconn, timestamp, data in reader.messages():
|
||||
data = ros1_to_cdr(data, rconn.msgtype)
|
||||
writer.write(connmap[rconn.cid], timestamp, data)
|
||||
except ReaderError as err:
|
||||
raise ConverterError(f'Reading source bag: {err}') from err
|
||||
except WriterError as err:
|
||||
|
||||
@@ -249,7 +249,7 @@ class Header(dict):
|
||||
raise ReaderError(f'Could not read time field {name!r}.') from err
|
||||
|
||||
@classmethod
|
||||
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> 'Header':
|
||||
def read(cls: type, src: BinaryIO, expect: Optional[RecordType] = None) -> Header:
|
||||
"""Read header from file handle.
|
||||
|
||||
Args:
|
||||
@@ -588,7 +588,7 @@ class Reader:
|
||||
topics: Optional[Iterable[str]] = None,
|
||||
start: Optional[int] = None,
|
||||
stop: Optional[int] = None,
|
||||
) -> Generator[Tuple[str, str, int, bytes], None, None]:
|
||||
) -> Generator[Tuple[Connection, int, bytes], None, None]:
|
||||
"""Read messages from bag.
|
||||
|
||||
Args:
|
||||
@@ -598,7 +598,7 @@ class Reader:
|
||||
stop: Yield only messages before this timestamp (ns).
|
||||
|
||||
Yields:
|
||||
Tuples of topic name, type, timestamp (ns), and rawdata.
|
||||
Tuples of connection, timestamp (ns), and rawdata.
|
||||
|
||||
Raises:
|
||||
ReaderError: Bag not open or data corrupt.
|
||||
@@ -635,13 +635,10 @@ class Reader:
|
||||
if have != RecordType.MSGDATA:
|
||||
raise ReaderError('Expected to find message data.')
|
||||
|
||||
connection = self.connections[header.get_uint32('conn')]
|
||||
time = header.get_time('time')
|
||||
|
||||
data = read_bytes(chunk, read_uint32(chunk))
|
||||
|
||||
assert entry.time == time
|
||||
yield connection.topic, connection.msgtype, time, data
|
||||
connection = self.connections[header.get_uint32('conn')]
|
||||
assert entry.time == header.get_time('time')
|
||||
yield connection, entry.time, data
|
||||
|
||||
def __enter__(self) -> Reader:
|
||||
"""Open rosbag1 when entering contextmanager."""
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag2 connection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class Connection:
|
||||
"""Connection metadata."""
|
||||
id: int = field(compare=False) # pylint: disable=invalid-name
|
||||
count: int = field(compare=False)
|
||||
topic: str
|
||||
msgtype: str
|
||||
serialization_format: str
|
||||
offered_qos_profiles: str
|
||||
@@ -13,9 +13,11 @@ from typing import TYPE_CHECKING
|
||||
import zstandard
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
|
||||
class ReaderError(Exception):
|
||||
@@ -96,11 +98,21 @@ class Reader:
|
||||
if missing:
|
||||
raise ReaderError(f'Some database files are missing: {[str(x) for x in missing]!r}')
|
||||
|
||||
topics = [x['topic_metadata'] for x in self.metadata['topics_with_message_count']]
|
||||
noncdr = {y for x in topics if (y := x['serialization_format']) != 'cdr'}
|
||||
self.connections = {
|
||||
idx + 1: Connection(
|
||||
id=idx + 1,
|
||||
count=x['message_count'],
|
||||
topic=x['topic_metadata']['name'],
|
||||
msgtype=x['topic_metadata']['type'],
|
||||
serialization_format=x['topic_metadata']['serialization_format'],
|
||||
offered_qos_profiles=x['topic_metadata'].get('offered_qos_profiles', ''),
|
||||
) for idx, x in enumerate(self.metadata['topics_with_message_count'])
|
||||
}
|
||||
noncdr = {
|
||||
y for x in self.connections.values() if (y := x.serialization_format) != 'cdr'
|
||||
}
|
||||
if noncdr:
|
||||
raise ReaderError(f'Serialization format {noncdr!r} is not supported.')
|
||||
self.topics = {x['name']: x['type'] for x in topics}
|
||||
|
||||
if self.compression_mode and (cfmt := self.compression_format) != 'zstd':
|
||||
raise ReaderError(f'Compression format {cfmt!r} is not supported.')
|
||||
@@ -149,22 +161,31 @@ class Reader:
|
||||
mode = self.metadata.get('compression_mode', '').lower()
|
||||
return mode if mode != 'none' else None
|
||||
|
||||
@property
|
||||
def topics(self) -> Dict[str, Connection]:
|
||||
"""Topic information.
|
||||
|
||||
For the moment this a dictionary mapping topic names to connections.
|
||||
|
||||
"""
|
||||
return {x.topic: x for x in self.connections.values()}
|
||||
|
||||
def messages( # pylint: disable=too-many-locals
|
||||
self,
|
||||
topics: Iterable[str] = (),
|
||||
connections: Iterable[Connection] = (),
|
||||
start: Optional[int] = None,
|
||||
stop: Optional[int] = None,
|
||||
) -> Generator[Tuple[str, str, int, bytes], None, None]:
|
||||
) -> Generator[Tuple[Connection, int, bytes], None, None]:
|
||||
"""Read messages from bag.
|
||||
|
||||
Args:
|
||||
topics: Iterable with topic names to filter for. An empty iterable
|
||||
yields all messages.
|
||||
connections: Iterable with connections to filter for. An empty
|
||||
iterable disables filtering on connections.
|
||||
start: Yield only messages at or after this timestamp (ns).
|
||||
stop: Yield only messages before this timestamp (ns).
|
||||
|
||||
Yields:
|
||||
Tuples of topic name, type, timestamp (ns), and rawdata.
|
||||
Tuples of connection, timestamp (ns), and rawdata.
|
||||
|
||||
Raises:
|
||||
ReaderError: Bag not open.
|
||||
@@ -173,7 +194,32 @@ class Reader:
|
||||
if not self.bio:
|
||||
raise ReaderError('Rosbag is not open.')
|
||||
|
||||
topics = tuple(topics)
|
||||
query = [
|
||||
'SELECT topics.id,messages.timestamp,messages.data',
|
||||
'FROM messages JOIN topics ON messages.topic_id=topics.id',
|
||||
]
|
||||
args: List[Any] = []
|
||||
clause = 'WHERE'
|
||||
|
||||
if connections:
|
||||
topics = {x.topic for x in connections}
|
||||
query.append(f'{clause} topics.name IN ({",".join("?" for _ in topics)})')
|
||||
args += topics
|
||||
clause = 'AND'
|
||||
|
||||
if start is not None:
|
||||
query.append(f'{clause} messages.timestamp >= ?')
|
||||
args.append(start)
|
||||
clause = 'AND'
|
||||
|
||||
if stop is not None:
|
||||
query.append(f'{clause} messages.timestamp < ?')
|
||||
args.append(stop)
|
||||
clause = 'AND'
|
||||
|
||||
query.append('ORDER BY timestamp')
|
||||
querystr = ' '.join(query)
|
||||
|
||||
for filepath in self.paths:
|
||||
with decompress(filepath, self.compression_mode == 'file') as path:
|
||||
conn = sqlite3.connect(f'file:{path}?immutable=1', uri=True)
|
||||
@@ -186,34 +232,16 @@ class Reader:
|
||||
if cur.fetchone()[0] != 2:
|
||||
raise ReaderError(f'Cannot open database {path} or database missing tables.')
|
||||
|
||||
query = [
|
||||
'SELECT topics.name,topics.type,messages.timestamp,messages.data',
|
||||
'FROM messages JOIN topics ON messages.topic_id=topics.id',
|
||||
]
|
||||
args: List[Any] = []
|
||||
|
||||
if topics:
|
||||
query.append(f'WHERE topics.name IN ({",".join("?" for _ in topics)})')
|
||||
args += topics
|
||||
|
||||
if start is not None:
|
||||
query.append(f'{"AND" if args else "WHERE"} messages.timestamp >= ?')
|
||||
args.append(start)
|
||||
|
||||
if stop is not None:
|
||||
query.append(f'{"AND" if args else "WHERE"} messages.timestamp < ?')
|
||||
args.append(stop)
|
||||
|
||||
query.append('ORDER BY timestamp')
|
||||
cur.execute(' '.join(query), args)
|
||||
cur.execute(querystr, args)
|
||||
|
||||
if self.compression_mode == 'message':
|
||||
decomp = zstandard.ZstdDecompressor().decompress
|
||||
for row in cur:
|
||||
topic, msgtype, timestamp, data = row
|
||||
yield topic, msgtype, timestamp, decomp(data)
|
||||
cid, timestamp, data = row
|
||||
yield self.connections[cid], timestamp, decomp(data)
|
||||
else:
|
||||
yield from cur
|
||||
for cid, timestamp, data in cur:
|
||||
yield self.connections[cid], timestamp, data
|
||||
|
||||
def __enter__(self) -> Reader:
|
||||
"""Open rosbag2 when entering contextmanager."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright 2020-2021 Ternaris.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Rosbag2 reader."""
|
||||
"""Rosbag2 writer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import TYPE_CHECKING
|
||||
import zstandard
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Literal, Optional, Type, Union
|
||||
@@ -77,10 +79,9 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.compression_mode = ''
|
||||
self.compression_format = ''
|
||||
self.compressor: Optional[zstandard.ZstdCompressor] = None
|
||||
self.topics: Dict[str, Any] = {}
|
||||
self.connections: Dict[int, Connection] = {}
|
||||
self.conn = None
|
||||
self.cursor: Optional[sqlite3.Cursor] = None
|
||||
self.topics = {}
|
||||
|
||||
def set_compression(self, mode: CompressionMode, fmt: CompressionFormat):
|
||||
"""Enable compression on bag.
|
||||
@@ -118,22 +119,27 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
self.conn.executescript(self.SQLITE_SCHEMA)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def add_topic(
|
||||
def add_connection(
|
||||
self,
|
||||
name: str,
|
||||
typ: str,
|
||||
topic: str,
|
||||
msgtype: str,
|
||||
serialization_format: str = 'cdr',
|
||||
offered_qos_profiles: str = '',
|
||||
):
|
||||
"""Add a topic.
|
||||
**_kw: Any,
|
||||
) -> Connection:
|
||||
"""Add a connection.
|
||||
|
||||
This function can only be called after opening a bag.
|
||||
|
||||
Args:
|
||||
name: Topic name.
|
||||
typ: Message type.
|
||||
topic: Topic name.
|
||||
msgtype: Message type.
|
||||
serialization_format: Serialization format.
|
||||
offered_qos_profiles: QOS Profile.
|
||||
_kw: Ignored to allow consuming dicts from connection objects.
|
||||
|
||||
Returns:
|
||||
Connection object.
|
||||
|
||||
Raises:
|
||||
WriterError: Bag not open or topic previously registered.
|
||||
@@ -141,17 +147,28 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
if not self.cursor:
|
||||
raise WriterError('Bag was not opened.')
|
||||
if name in self.topics:
|
||||
raise WriterError(f'Topics can only be added once: {name!r}.')
|
||||
meta = (len(self.topics) + 1, name, typ, serialization_format, offered_qos_profiles)
|
||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||
self.topics[name] = [*meta, 0]
|
||||
|
||||
def write(self, topic: str, timestamp: int, data: bytes):
|
||||
connection = Connection(
|
||||
id=len(self.connections.values()) + 1,
|
||||
count=0,
|
||||
topic=topic,
|
||||
msgtype=msgtype,
|
||||
serialization_format=serialization_format,
|
||||
offered_qos_profiles=offered_qos_profiles,
|
||||
)
|
||||
if connection in self.connections.values():
|
||||
raise WriterError(f'Connection can only be added once: {connection!r}.')
|
||||
|
||||
self.connections[connection.id] = connection
|
||||
meta = (connection.id, topic, msgtype, serialization_format, offered_qos_profiles)
|
||||
self.cursor.execute('INSERT INTO topics VALUES(?, ?, ?, ?, ?)', meta)
|
||||
return connection
|
||||
|
||||
def write(self, connection: Connection, timestamp: int, data: bytes):
|
||||
"""Write message to rosbag2.
|
||||
|
||||
Args:
|
||||
topic: Topic message belongs to.
|
||||
connection: Connection to write message to.
|
||||
timestamp: Message timestamp (ns).
|
||||
data: Serialized message data.
|
||||
|
||||
@@ -161,19 +178,18 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
if not self.cursor:
|
||||
raise WriterError('Bag was not opened.')
|
||||
if topic not in self.topics:
|
||||
raise WriterError(f'Tried to write to unknown topic {topic!r}.')
|
||||
if connection not in self.connections.values():
|
||||
raise WriterError(f'Tried to write to unknown connection {connection!r}.')
|
||||
|
||||
if self.compression_mode == 'message':
|
||||
assert self.compressor
|
||||
data = self.compressor.compress(data)
|
||||
|
||||
tmeta = self.topics[topic]
|
||||
self.cursor.execute(
|
||||
'INSERT INTO messages (topic_id, timestamp, data) VALUES(?, ?, ?)',
|
||||
(tmeta[0], timestamp, data),
|
||||
(connection.id, timestamp, data),
|
||||
)
|
||||
tmeta[-1] += 1
|
||||
connection.count += 1
|
||||
|
||||
def close(self):
|
||||
"""Close rosbag2 after writing.
|
||||
@@ -214,13 +230,13 @@ class Writer: # pylint: disable=too-many-instance-attributes
|
||||
'topics_with_message_count': [
|
||||
{
|
||||
'topic_metadata': {
|
||||
'name': x[1],
|
||||
'type': x[2],
|
||||
'serialization_format': x[3],
|
||||
'offered_qos_profiles': x[4],
|
||||
'name': x.topic,
|
||||
'type': x.msgtype,
|
||||
'serialization_format': x.serialization_format,
|
||||
'offered_qos_profiles': x.offered_qos_profiles,
|
||||
},
|
||||
'message_count': x[5],
|
||||
} for x in self.topics.values()
|
||||
'message_count': x.count,
|
||||
} for x in self.connections.values()
|
||||
],
|
||||
'compression_format': self.compression_format,
|
||||
'compression_mode': self.compression_mode,
|
||||
|
||||
Reference in New Issue
Block a user