Make reader1 API match reader2

This commit is contained in:
Marko Durkovic 2021-09-13 10:49:57 +02:00
parent 5d0aa8277c
commit 885900df39
3 changed files with 12 additions and 8 deletions

View File

@ -55,8 +55,8 @@ class Connection(NamedTuple):
cid: int cid: int
topic: str topic: str
msgtype: str msgtype: str
md5sum: str
msgdef: str msgdef: str
md5sum: str
callerid: Optional[str] callerid: Optional[str]
latching: Optional[int] latching: Optional[int]
indexes: list indexes: list
@ -487,8 +487,8 @@ class Reader:
conn, conn,
topic, topic,
normalize_msgtype(typ), normalize_msgtype(typ),
md5sum,
msgdef, msgdef,
md5sum,
callerid, callerid,
latching, latching,
[], [],
@ -573,15 +573,15 @@ class Reader:
def messages( def messages(
self, self,
topics: Optional[Iterable[str]] = None, connections: Iterable[Connection] = (),
start: Optional[int] = None, start: Optional[int] = None,
stop: Optional[int] = None, stop: Optional[int] = None,
) -> Generator[tuple[Connection, int, bytes], None, None]: ) -> Generator[tuple[Connection, int, bytes], None, None]:
"""Read messages from bag. """Read messages from bag.
Args: Args:
topics: Iterable with topic names to filter for. An empty iterable connections: Iterable with connections to filter for. An empty
yields all messages. iterable disables filtering on connections.
start: Yield only messages at or after this timestamp (ns). start: Yield only messages at or after this timestamp (ns).
stop: Yield only messages before this timestamp (ns). stop: Yield only messages before this timestamp (ns).
@ -595,7 +595,10 @@ class Reader:
if not self.bio: if not self.bio:
raise ReaderError('Rosbag is not open.') raise ReaderError('Rosbag is not open.')
indexes = [x.indexes for x in self.connections.values() if not topics or x.topic in topics] if not connections:
connections = self.connections.values()
indexes = [x.indexes for x in connections]
for entry in heapq.merge(*indexes): for entry in heapq.merge(*indexes):
if start and entry.time < start: if start and entry.time < start:
continue continue

View File

@ -245,8 +245,8 @@ class Writer: # pylint: disable=too-many-instance-attributes
len(self.connections), len(self.connections),
topic, topic,
denormalize_msgtype(msgtype), denormalize_msgtype(msgtype),
md5sum,
msgdef, msgdef,
md5sum,
callerid, callerid,
latching, latching,
[], [],

View File

@ -254,7 +254,8 @@ def test_reader(tmp_path): # pylint: disable=too-many-statements
assert msgs[0][2] == b'MSGCONTENT5' assert msgs[0][2] == b'MSGCONTENT5'
assert msgs[1][2] == b'MSGCONTENT10' assert msgs[1][2] == b'MSGCONTENT10'
msgs = list(reader.messages(['/topic0'])) connections = [x for x in reader.connections.values() if x.topic == '/topic0']
msgs = list(reader.messages(connections))
assert len(msgs) == 1 assert len(msgs) == 1
assert msgs[0][2] == b'MSGCONTENT10' assert msgs[0][2] == b'MSGCONTENT10'