Skip to content

Feature: Join Latest #874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
5f6d11e
[JOIN] Split _as_stateful into helper functions
gwaramadze Apr 7, 2025
fa9a869
[JOIN] Refactor _as_stateful to accept sdf
gwaramadze Apr 8, 2025
92fb704
[JOIN] Implement join
gwaramadze Apr 9, 2025
4fa3557
Use sdf.register_store method
gwaramadze Apr 21, 2025
14c94f0
First test
gwaramadze Apr 21, 2025
6901e2a
Refactor test into fixtures
gwaramadze Apr 21, 2025
ae883cd
Ensure both sides are copartitioned
gwaramadze Apr 22, 2025
6b59a97
Add on_overlap and merger params
gwaramadze Apr 22, 2025
7548ffd
Add how param with inner and left options
gwaramadze Apr 23, 2025
bc1d81c
Invert left<>right in tests
gwaramadze Apr 23, 2025
c64bc70
Add retention_ms param
gwaramadze Apr 23, 2025
d3167d3
Correct after rebase
gwaramadze May 6, 2025
5a1c06f
Rename to join_latest
gwaramadze May 8, 2025
707ee03
Change order right > left
gwaramadze May 9, 2025
6ef5a5f
Fix return type
gwaramadze May 9, 2025
4c7cfb8
Self joins not supported
gwaramadze May 9, 2025
21b1c03
Optimize TimestampedPartitionTransaction._get_min_eligible_timestamp
daniil-quix May 14, 2025
637aab2
Optimize RocksDBStorePartition.iter_items(backwards=True)
daniil-quix May 15, 2025
40dbaa7
WIP: refactor StreamingDataFrame.join_latest
daniil-quix May 19, 2025
20f93b2
join_latest: register store only for the right side
daniil-quix May 19, 2025
1ddfab5
join_latest: rename on_overlap -> on_merge
daniil-quix May 19, 2025
c6a20ba
Move ensure_topics_copartitioned() to TopicManager
daniil-quix May 19, 2025
155ee63
Replace WindowStoreAlreadyRegisteredError
daniil-quix May 19, 2025
dfb7247
TimestampedStore: Rename get_last -> get_latest
daniil-quix May 19, 2025
3d6cd33
Tests for test_register_timestamped_store_twice
daniil-quix May 19, 2025
c58b1ef
join_latest: update docstring
daniil-quix May 19, 2025
8be9263
join_latest: use is None check in mergers
daniil-quix May 19, 2025
a2c5790
Rename join_latest -> join_asof
daniil-quix May 20, 2025
c129d8d
join_latest: docs
daniil-quix May 21, 2025
e1f5dd8
Add counter to timestamped store (#886)
gwaramadze May 21, 2025
ada71fb
join_latest: remove newlines in docs
daniil-quix May 21, 2025
0711de7
join_latest: add custom on_merge example
daniil-quix May 21, 2025
583341f
join_latest: docs spelling
daniil-quix May 21, 2025
cc2f5b1
Update docs/joins.md
daniil-quix May 21, 2025
4e5629b
Update docs/joins.md
daniil-quix May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 70 additions & 31 deletions quixstreams/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@
from quixstreams.models.serializers import DeserializerType, SerializerType
from quixstreams.sinks import BaseSink
from quixstreams.state.base import State
from quixstreams.state.manager import StoreTypes
from quixstreams.utils.printing import (
DEFAULT_COLUMN_NAME,
DEFAULT_LIVE,
DEFAULT_LIVE_SLOWDOWN,
)

from . import joins
from .exceptions import InvalidOperation, TopicPartitionsMismatch
from .registry import DataFrameRegistry
from .series import StreamingSeries
Expand Down Expand Up @@ -275,7 +277,7 @@ def func(d: dict, state: State):
Default - `False`.
"""
if stateful:
self._register_store()
self.register_store()
# Force the callback to accept metadata
if metadata:
with_metadata_func = cast(ApplyWithMetadataCallbackStateful, func)
Expand All @@ -284,11 +286,7 @@ def func(d: dict, state: State):
cast(ApplyCallbackStateful, func)
)

stateful_func = _as_stateful(
func=with_metadata_func,
processing_context=self._processing_context,
stream_id=self.stream_id,
)
stateful_func = _as_stateful(with_metadata_func, self)
stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) # type: ignore[call-overload]
else:
stream = self.stream.add_apply(
Expand Down Expand Up @@ -384,7 +382,7 @@ def func(values: list, state: State):
:return: the updated StreamingDataFrame instance (reassignment NOT required).
"""
if stateful:
self._register_store()
self.register_store()
# Force the callback to accept metadata
if metadata:
with_metadata_func = cast(UpdateWithMetadataCallbackStateful, func)
Expand All @@ -393,11 +391,7 @@ def func(values: list, state: State):
cast(UpdateCallbackStateful, func)
)

stateful_func = _as_stateful(
func=with_metadata_func,
processing_context=self._processing_context,
stream_id=self.stream_id,
)
stateful_func = _as_stateful(with_metadata_func, self)
return self._add_update(stateful_func, metadata=True)
else:
return self._add_update(
Expand Down Expand Up @@ -486,7 +480,7 @@ def func(d: dict, state: State):
"""

if stateful:
self._register_store()
self.register_store()
# Force the callback to accept metadata
if metadata:
with_metadata_func = cast(FilterWithMetadataCallbackStateful, func)
Expand All @@ -495,11 +489,7 @@ def func(d: dict, state: State):
cast(FilterCallbackStateful, func)
)

stateful_func = _as_stateful(
func=with_metadata_func,
processing_context=self._processing_context,
stream_id=self.stream_id,
)
stateful_func = _as_stateful(with_metadata_func, self)
stream = self.stream.add_filter(stateful_func, metadata=True)
else:
stream = self.stream.add_filter( # type: ignore[call-overload]
Expand Down Expand Up @@ -1656,12 +1646,62 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
*self.topics, *other.topics, stream=merged_stream
)

def ensure_topics_copartitioned(self):
partitions_counts = set(t.broker_config.num_partitions for t in self._topics)
def join_latest(
    self,
    right: "StreamingDataFrame",
    how: joins.How = "inner",
    on_overlap: Union[joins.OnOverlap, Callable[[Any, Any], Any]] = "raise",
    grace_ms: Union[int, timedelta] = timedelta(days=7),
    # TODO: Allow passing the store name here?
) -> "StreamingDataFrame":
    """
    Join this StreamingDataFrame with the latest effective values on the right side.

    This join is designed for enrichment use cases where the left side is a
    data stream and the right side carries metadata.

    The underlying topics of both dataframes must have the same number of
    partitions and use the same partitioner (i.e. keys must be distributed
    across partitions by the same method).

    Joining dataframes that originate from the same topic ("self-join") is
    not supported yet.

    How the joining works:
      - Records arriving on the right side are written to a state store and
        do not emit any updates downstream.
      - Each record on the left side queries the right store for the value
        with the same **key** and a timestamp lower than or equal to the
        record's own timestamp. The left side emits data downstream.
      - When a match is found, the two records are merged into a new one
        according to the `on_overlap` logic.
      - The size of the right store is bounded by `grace_ms`: a newly added
        "right" record expires other values with the same key whose
        timestamps fall below "<current timestamp> - <grace_ms>".

    :param right: a StreamingDataFrame to join with.

    :param how: the join strategy. One of:
        - "inner" - emit the result only when a right-side match is found
          for the left record.
        - "left" - emit a result for every left record even when there is
          no match on the right side.
        Default - `"inner"`.

    :param on_overlap: how to merge the matched records, assuming they are
        dictionaries:
        - "raise" - fail with an error if the same keys exist in both records
        - "keep-left" - prefer the keys from the left record
        - "keep-right" - prefer the keys from the right record
        - callback - a callable of the form "(<left>, <right>) -> <new record>"
          to merge the records manually. Use it to customize the merging logic
          or when one of the records is not a dictionary.

    :param grace_ms: how long to keep the right records in the store.
        Either an `int` in milliseconds or a `timedelta` object.
        Default - 7 days.
    """
    join = joins.JoinLatest(how=how, on_overlap=on_overlap, grace_ms=grace_ms)
    return join.join(self, right)

def ensure_topics_copartitioned(self, *topics: Topic):
topics = topics or self._topics
partitions_counts = set(t.broker_config.num_partitions for t in topics)
if len(partitions_counts) > 1:
msg = ", ".join(
f'"{t.name}" ({t.broker_config.num_partitions} partitions)'
for t in self._topics
for t in topics
)
raise TopicPartitionsMismatch(
f"The underlying topics must have the same number of partitions to use State; got {msg}"
Expand Down Expand Up @@ -1689,7 +1729,7 @@ def _add_update(
self._stream = self._stream.add_update(func, metadata=metadata) # type: ignore[call-overload]
return self

def _register_store(self):
def register_store(self, store_type: Optional[StoreTypes] = None) -> None:
"""
Register the default store for the current stream_id in StateStoreManager.
"""
Expand All @@ -1699,7 +1739,9 @@ def _register_store(self):
changelog_topic_config = self._topic_manager.derive_topic_config(self._topics)

self._processing_context.state_manager.register_store(
stream_id=self.stream_id, changelog_config=changelog_topic_config
stream_id=self.stream_id,
store_type=store_type,
changelog_config=changelog_topic_config,
)

def _groupby_key(
Expand Down Expand Up @@ -1847,19 +1889,16 @@ def wrapper(

def _as_stateful(
    func: Callable[[Any, Any, int, Any, State], T],
    sdf: StreamingDataFrame,
) -> Callable[[Any, Any, int, Any], T]:
    """
    Wrap a metadata-aware stateful callback so it can run on a plain stream.

    :param func: a callback accepting "(value, key, timestamp, headers, state)".
    :param sdf: the StreamingDataFrame providing the processing context and
        stream_id used to locate the store transaction.
    :return: a callback accepting "(value, key, timestamp, headers)" that
        resolves the State object per message and delegates to `func`.
    """

    @functools.wraps(func)
    def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
        # Pass a State object with an interface limited to the key updates only
        # and prefix all the state keys by the message key
        state = sdf.processing_context.checkpoint.get_store_transaction(
            stream_id=sdf.stream_id,
            partition=message_context().partition,
        ).as_state(prefix=key)
        return func(value, key, timestamp, headers, state)

    return wrapper
3 changes: 3 additions & 0 deletions quixstreams/dataframe/joins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .join_latest import How as How
from .join_latest import JoinLatest as JoinLatest
from .join_latest import OnOverlap as OnOverlap
113 changes: 113 additions & 0 deletions quixstreams/dataframe/joins/join_latest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import typing
from datetime import timedelta
from typing import Any, Callable, Literal, Union, cast, get_args

from quixstreams.context import message_context
from quixstreams.dataframe.utils import ensure_milliseconds
from quixstreams.models.topics.manager import TopicManager
from quixstreams.state.rocksdb.timestamped import TimestampedPartitionTransaction

from .utils import keep_left_merger, keep_right_merger, raise_merger

if typing.TYPE_CHECKING:
from quixstreams.dataframe.dataframe import StreamingDataFrame

# Sentinel emitted by the left-side handler when an inner join finds no match;
# such records are filtered out before reaching downstream operators.
DISCARDED = object()
# Supported join strategies; choices are derived from the Literal for validation.
How = Literal["inner", "left"]
How_choices = get_args(How)

# Supported strategies for merging overlapping keys of the matched records.
OnOverlap = Literal["keep-left", "keep-right", "raise"]
OnOverlap_choices = get_args(OnOverlap)


class JoinLatest:
    """
    An "as-of" join between two StreamingDataFrames.

    Right-side records are written to a timestamped state store and are never
    emitted downstream; each left-side record is merged with the latest
    right-side value sharing the same key whose timestamp is lower than or
    equal to the left record's timestamp, and the merged result is emitted.
    """

    def __init__(
        self,
        how: How,
        on_overlap: Union[OnOverlap, Callable[[Any, Any], Any]],
        grace_ms: Union[int, timedelta],
        store_name: str = "join",
    ):
        """
        :param how: "inner" - emit only when a right-side match is found;
            "left" - emit for every left record even without a match.
        :param on_overlap: "keep-left", "keep-right", "raise", or a callable
            "(left, right) -> merged" used to combine matched records.
        :param grace_ms: how long right-side values are retained in the store,
            as milliseconds or a timedelta.
        :param store_name: name of the timestamped store backing the join.
        :raises ValueError: on an invalid `how` or `on_overlap` value.
        """
        if how not in How_choices:
            raise ValueError(
                f'Invalid "how" value: {how}. '
                f"Valid choices are: {', '.join(How_choices)}."
            )
        self._how = how

        # Resolve the merging strategy once up front; a callable is used as-is.
        if callable(on_overlap):
            self._merger = on_overlap
        elif on_overlap == "keep-left":
            self._merger = keep_left_merger
        elif on_overlap == "keep-right":
            self._merger = keep_right_merger
        elif on_overlap == "raise":
            self._merger = raise_merger
        else:
            raise ValueError(
                f'Invalid "on_overlap" value: {on_overlap}. '
                f"Provide either one of {', '.join(OnOverlap_choices)} or "
                f"a callable to merge records manually."
            )

        self._retention_ms = ensure_milliseconds(grace_ms)
        self._store_name = store_name

    def join(
        self,
        left: "StreamingDataFrame",
        right: "StreamingDataFrame",
    ) -> "StreamingDataFrame":
        """
        Wire the join into the two dataframes and return the joined dataframe.

        :param left: the data stream to be enriched.
        :param right: the dataframe providing the lookup values.
        :return: a StreamingDataFrame emitting the merged left-side records.
        :raises ValueError: if both sides share the same stream_id (self-join).
        """
        if left.stream_id == right.stream_id:
            raise ValueError(
                "Joining dataframes originating from "
                "the same topic is not yet supported.",
            )
        # Both sides must have the same partition count so that records with
        # the same key land in the same partition number on both sides.
        left.ensure_topics_copartitioned(*left.topics, *right.topics)

        # Register a timestamped store for each side.
        # NOTE(review): only the right-side store is read/written below —
        # confirm whether registering a store for the left side is needed.
        for sdf in (left, right):
            changelog_config = TopicManager.derive_topic_config(sdf.topics)
            sdf.processing_context.state_manager.register_timestamped_store(
                stream_id=sdf.stream_id,
                store_name=self._store_name,
                changelog_config=changelog_config,
            )

        is_inner_join = self._how == "inner"

        def left_func(value, key, timestamp, headers):
            # Look up the latest right-side value for this key with a
            # timestamp lower than or equal to the left record's timestamp.
            tx = cast(
                TimestampedPartitionTransaction,
                right.processing_context.checkpoint.get_store_transaction(
                    stream_id=right.stream_id,
                    partition=message_context().partition,
                    store_name=self._store_name,
                ),
            )

            right_value = tx.get_last(timestamp=timestamp, prefix=key)
            # Use an "is None" check so that legitimately falsy stored values
            # (e.g. an empty dict) still count as a match for an inner join.
            # (Assumes get_last returns None when there is no match — confirm.)
            if is_inner_join and right_value is None:
                return DISCARDED
            return self._merger(value, right_value)

        def right_func(value, key, timestamp, headers):
            # Persist the right-side value; nothing is emitted downstream.
            tx = cast(
                TimestampedPartitionTransaction,
                right.processing_context.checkpoint.get_store_transaction(
                    stream_id=right.stream_id,
                    partition=message_context().partition,
                    store_name=self._store_name,
                ),
            )
            tx.set_for_timestamp(
                timestamp=timestamp,
                value=value,
                prefix=key,
                retention_ms=self._retention_ms,
            )

        # The right branch only updates the store and filters everything out;
        # the left branch emits merged records, dropping DISCARDED sentinels.
        right = right.update(right_func, metadata=True).filter(lambda value: False)
        left = left.apply(left_func, metadata=True).filter(
            lambda value: value is not DISCARDED
        )
        return left.concat(right)
36 changes: 36 additions & 0 deletions quixstreams/dataframe/joins/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Mapping, Optional


def keep_left_merger(left: Optional[Mapping], right: Optional[Mapping]) -> dict:
    """
    Combine two mappings into a new dict; on key collisions the left value wins.

    A `None` side is treated as an empty mapping.
    """
    # Unpack right first so that left's entries overwrite the overlapping keys.
    return {**(right or {}), **(left or {})}


def keep_right_merger(left: Optional[Mapping], right: Optional[Mapping]) -> dict:
    """
    Combine two mappings into a new dict; on key collisions the right value wins.

    A `None` side is treated as an empty mapping.
    """
    # TODO: Add try-except everywhere and tell to pass a callback if one of the objects is not a mapping
    # Unpack left first so that right's entries overwrite the overlapping keys.
    return {**(left or {}), **(right or {})}


def raise_merger(left: Optional[Mapping], right: Optional[Mapping]) -> dict:
    """
    Merge two dictionaries and raise an error if overlapping keys are detected.

    :param left: the left-side record; `None` is treated as an empty mapping.
    :param right: the right-side record; `None` is treated as an empty mapping.
    :return: a new dict containing the keys of both inputs.
    :raises ValueError: if any key is present in both inputs.
    """
    left = left or {}
    right = right or {}
    if overlapping_columns := left.keys() & right.keys():
        overlapping_columns_str = ", ".join(sorted(overlapping_columns))
        # Fix: the original implicit string concatenation produced a message
        # with no space between sentences ("...columns: a.You need...").
        raise ValueError(
            f"Overlapping columns: {overlapping_columns_str}. "
            'You need to provide either an "on_overlap" value of '
            "'keep-left' or 'keep-right' or a custom merger function."
        )
    return {**left, **right}
4 changes: 4 additions & 0 deletions quixstreams/state/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ class PartitionStoreIsUsed(QuixException): ...
class StoreNotRegisteredError(QuixException): ...


# TODO: Merge these two exceptions together?
# Marker exception: a windowed store was already registered (no extra behavior).
class WindowedStoreAlreadyRegisteredError(QuixException): ...


# Marker exception: a (non-windowed) store was already registered.
class StoreAlreadyRegisteredError(QuixException): ...


# Marker exception: a store transaction was used in an invalid state.
class InvalidStoreTransactionStateError(QuixException): ...


Expand Down
Loading