diff --git a/docs/img/join-asof.png b/docs/img/join-asof.png
new file mode 100644
index 000000000..5d086b418
Binary files /dev/null and b/docs/img/join-asof.png differ
diff --git a/docs/joins.md b/docs/joins.md
new file mode 100644
index 000000000..8440db3bc
--- /dev/null
+++ b/docs/joins.md
@@ -0,0 +1,146 @@
+# Joins
+
+## Join as-of
+
+Use `StreamingDataFrame.join_asof()` to join two topics into a new stream where each left record
+is merged with the right record with the same key whose timestamp is less than or equal to the left timestamp.
+
+This join is built with time series enrichment use cases in mind, where the left side represents some measurements and the right side represents events or metadata.
+
+Some examples:
+
+- Matching sensor measurements with events in the system.
+- Joining purchases with the effective prices of the goods.
+
+During an as-of join, the records on the right side are stored in a lookup table in the state, and the records from the left side query this state for matches.
+
+![As-of join](img/join-asof.png)
+
+### Requirements
+To perform a join, the underlying topics must follow these requirements:
+
+1. **Both topics must have the same number of partitions.**
+Join is a stateful operation, and it requires the partitions of the left and right topics to be assigned to the same application instance during processing.
+
+2. **Keys in both topics must be distributed across partitions using the same algorithm.**
+For example, messages with the key `A` must go to the same partition number for both the left and right topics. This is Kafka's default behaviour.
+
+### Example
+
+Join records from the topic "measurements" with the latest effective records from
+the topic "metadata" using the "inner" join strategy and a grace period of 14 days:
+
+```python
+from datetime import timedelta
+
+from quixstreams import Application
+
+app = Application(...)
+
+sdf_measurements = app.dataframe(app.topic("measurements"))
+sdf_metadata = app.dataframe(app.topic("metadata"))
+
+# Join records from the topic "measurements"
+# with the latest effective records from the topic "metadata"
+# using the "inner" join strategy and keeping the "metadata" records stored for 14 days in event time.
+sdf_joined = sdf_measurements.join_asof(
+    right=sdf_metadata,
+    how="inner",  # Emit updates only if the match is found in the store.
+    on_merge="keep-left",  # Prefer the columns from the left dataframe if they overlap with the right.
+    grace_ms=timedelta(days=14),  # Keep the state for 14 days (measured in event time, as with windows).
+)
+
+if __name__ == '__main__':
+    app.run()
+```
+
+### How it works
+
+Here is a description of the as-of join algorithm:
+
+- Records from the right side are written to the state store without emitting any updates downstream.
+- Records on the left side query the right store for values with the same **key** and a timestamp lower than or equal to the record's timestamp.
+- If a match is found, the two records are merged together into a new one according to the `on_merge` logic.
+- The size of the right store is controlled by `grace_ms`:
+a newly added "right" record expires older values with the same key whose timestamps are below `new_timestamp - grace_ms`.
+
+#### Joining strategies
+As-of join supports the following joining strategies:
+
+- `inner` - emit the output for the left record only when a match is found (default).
+- `left` - emit the output for the left record even without a match (see the example below).
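+
+For example, here is a minimal sketch (reusing the topic names from the example above) of a `left` join that emits measurements even when no matching metadata has been stored yet:
+
+```python
+from quixstreams import Application
+
+app = Application(...)
+
+sdf_measurements = app.dataframe(app.topic("measurements"))
+sdf_metadata = app.dataframe(app.topic("metadata"))
+
+# With the "left" strategy, measurements are emitted even when no metadata is found.
+# In that case the right side is missing, so the merged output contains only the left columns.
+sdf_joined = sdf_measurements.join_asof(
+    right=sdf_metadata,
+    how="left",  # Emit the left record even without a match on the right side.
+    on_merge="keep-left",  # Prefer the left columns if they overlap with the right ones.
+)
+
+if __name__ == '__main__':
+    app.run()
+```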
+
+
+#### Merging records together
+When a match is found, the two records are merged according to the `on_merge` parameter.
+
+The out-of-the-box implementations assume that records are **dictionaries**.
+To merge other data types (as well as to customize the behavior), use the callback option.
+
+Possible values:
+
+- `raise` - merge two records together into a new dictionary and raise an exception if the same keys are found in both dictionaries.
+This is the default behavior.
+
+- `keep-left` - merge two records together into a new dictionary and prefer keys from the **left** record in case of overlap.
+
+- `keep-right` - merge two records together into a new dictionary and prefer keys from the **right** record in case of overlap.
+
+- custom callback - pass a callback `(left, right) -> merged` to merge the records manually.
+Use it when non-dictionary types are expected, or when you want to customize the returned object:
+
+```python
+from typing import Optional
+
+from quixstreams import Application
+
+app = Application(...)
+
+sdf_measurements = app.dataframe(app.topic("measurements"))
+sdf_metadata = app.dataframe(app.topic("metadata"))
+
+
+def on_merge(left: int, right: Optional[str]) -> dict:
+    """
+    Merge non-dictionary items into a dict
+    """
+    return {'measurement': left, 'metadata': right}
+
+
+sdf_joined = sdf_measurements.join_asof(right=sdf_metadata, on_merge=on_merge)
+
+if __name__ == '__main__':
+    app.run()
+```
+
+#### State expiration
+`StreamingDataFrame.join_asof` stores the right records in the state.
+The `grace_ms` parameter regulates the state's lifetime (default - 7 days) to prevent it from growing forever.
+
+It shares some similarities with `grace_ms` in [Windows](windowing.md#lateness-and-out-of-order-processing):
+
+- The timestamps are obtained from the records.
+- The state keeps track of the maximum observed timestamp for **each individual key**.
+- Older values are expired only when a larger timestamp is stored to the state.
+
+Adjust `grace_ms` based on the expected time gap between the left and the right side of the join (see the sketch at the end of this page).
+
+### Limitations
+
+- Joining dataframes belonging to the same topics (aka "self-join") is not supported.
+- As-of join preserves headers only for the left dataframe.
+If you need headers of the right-side records, consider adding them to the value.
+
+## Message ordering between partitions
+Joins use [`StreamingDataFrame.concat()`](concatenating.md) under the hood, which means that the application's internal consumer goes into a special "buffered" mode
+when the join is used.
+
+In this mode, it buffers messages per partition in order to process them in timestamp order across different topics.
+Timestamp alignment is effective only for partitions **with the same numbers**: partition zero is aligned with other zero partitions, but not with partition one.
+
+Note that message ordering is applied only as messages are consumed from the topics.
+If you change the timestamps of records during processing, they will still be processed in their original order.
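+
+To illustrate the state expiration described above, here is a minimal sketch (with the same illustrative topic names as before) of how `grace_ms` limits which right records remain available for matching:
+
+```python
+from datetime import timedelta
+
+from quixstreams import Application
+
+app = Application(...)
+
+sdf_measurements = app.dataframe(app.topic("measurements"))
+sdf_metadata = app.dataframe(app.topic("metadata"))
+
+# Keep the right ("metadata") records for 1 hour of event time per key.
+# When a new "metadata" record with timestamp T arrives for a key,
+# older records for that key with timestamps below T - 1 hour are expired
+# and can no longer be matched by "measurements" records.
+sdf_joined = sdf_measurements.join_asof(
+    right=sdf_metadata,
+    how="left",
+    grace_ms=timedelta(hours=1),  # An int of milliseconds (e.g. 3600000) is also accepted.
+)
+
+if __name__ == '__main__':
+    app.run()
+```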
diff --git a/mkdocs.yml b/mkdocs.yml index 6c9c0ed6e..32093720c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,6 +44,7 @@ nav: - Windowing: windowing.md - Aggregations: aggregations.md - Concatenating Topics: concatenating.md + - Joins: joins.md - Branching StreamingDataFrames: branching.md - Configuration: configuration.md diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 10e12c913..31fd8d05f 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -49,13 +49,15 @@ from quixstreams.models.serializers import DeserializerType, SerializerType from quixstreams.sinks import BaseSink from quixstreams.state.base import State +from quixstreams.state.manager import StoreTypes from quixstreams.utils.printing import ( DEFAULT_COLUMN_NAME, DEFAULT_LIVE, DEFAULT_LIVE_SLOWDOWN, ) -from .exceptions import InvalidOperation, TopicPartitionsMismatch +from .exceptions import InvalidOperation +from .joins import JoinAsOf, JoinAsOfHow, OnOverlap from .registry import DataFrameRegistry from .series import StreamingSeries from .utils import ensure_milliseconds @@ -275,7 +277,7 @@ def func(d: dict, state: State): Default - `False`. """ if stateful: - self._register_store() + self.register_store() # Force the callback to accept metadata if metadata: with_metadata_func = cast(ApplyWithMetadataCallbackStateful, func) @@ -284,11 +286,7 @@ def func(d: dict, state: State): cast(ApplyCallbackStateful, func) ) - stateful_func = _as_stateful( - func=with_metadata_func, - processing_context=self._processing_context, - stream_id=self.stream_id, - ) + stateful_func = _as_stateful(with_metadata_func, self) stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) # type: ignore[call-overload] else: stream = self.stream.add_apply( @@ -384,7 +382,7 @@ def func(values: list, state: State): :return: the updated StreamingDataFrame instance (reassignment NOT required). 
""" if stateful: - self._register_store() + self.register_store() # Force the callback to accept metadata if metadata: with_metadata_func = cast(UpdateWithMetadataCallbackStateful, func) @@ -393,11 +391,7 @@ def func(values: list, state: State): cast(UpdateCallbackStateful, func) ) - stateful_func = _as_stateful( - func=with_metadata_func, - processing_context=self._processing_context, - stream_id=self.stream_id, - ) + stateful_func = _as_stateful(with_metadata_func, self) return self._add_update(stateful_func, metadata=True) else: return self._add_update( @@ -486,7 +480,7 @@ def func(d: dict, state: State): """ if stateful: - self._register_store() + self.register_store() # Force the callback to accept metadata if metadata: with_metadata_func = cast(FilterWithMetadataCallbackStateful, func) @@ -495,11 +489,7 @@ def func(d: dict, state: State): cast(FilterCallbackStateful, func) ) - stateful_func = _as_stateful( - func=with_metadata_func, - processing_context=self._processing_context, - stream_id=self.stream_id, - ) + stateful_func = _as_stateful(with_metadata_func, self) stream = self.stream.add_filter(stateful_func, metadata=True) else: stream = self.stream.add_filter( # type: ignore[call-overload] @@ -1656,16 +1646,79 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame": *self.topics, *other.topics, stream=merged_stream ) - def ensure_topics_copartitioned(self): - partitions_counts = set(t.broker_config.num_partitions for t in self._topics) - if len(partitions_counts) > 1: - msg = ", ".join( - f'"{t.name}" ({t.broker_config.num_partitions} partitions)' - for t in self._topics - ) - raise TopicPartitionsMismatch( - f"The underlying topics must have the same number of partitions to use State; got {msg}" - ) + def join_asof( + self, + right: "StreamingDataFrame", + how: JoinAsOfHow = "inner", + on_merge: Union[OnOverlap, Callable[[Any, Any], Any]] = "raise", + grace_ms: Union[int, timedelta] = timedelta(days=7), + name: Optional[str] = None, + ) -> "StreamingDataFrame": + """ + Join the left dataframe with the records of the right dataframe with + the same key whose timestamp is less than or equal to the left timestamp. + This join is built with the enrichment use case in mind, where the left side + represents some measurements and the right side is metadata. + + To be joined, the underlying topics of the dataframes must have the same number of partitions + and use the same partitioner (all keys should be distributed across partitions using the same algorithm). + + Joining dataframes belonging to the same topics (aka "self-join") is not supported as of now. + + How it works: + - Records from the right side get written to the state store without emitting any updates downstream. + - Records on the left side query the right store for the values with the same **key** and the timestamp lower or equal to the record's timestamp. + Left side emits data downstream. + - If the match is found, the two records are merged together into a new one according to the `on_merge` logic. + - The size of the right store is controlled by the "grace_ms": + a newly added "right" record expires other values with the same key with timestamps below " - ". + + :param right: a StreamingDataFrame to join with. + + :param how: the join strategy. Can be one of: + - "inner" - emits the result when the match on the right side is found for the left record. + - "left" - emits the result for each left record even if there is no match on the right side. + Default - `"inner"`. 
+ + :param on_merge: how to merge the matched records together assuming they are dictionaries: + - "raise" - fail with an error if the same keys are found in both dictionaries + - "keep-left" - prefer the keys from the left record. + - "keep-right" - prefer the keys from the right record + - callback - a callback in form "(, ) -> " to merge the records manually. + Use it to customize the merging logic or when one of the records is not a dictionary. + + :param grace_ms: how long to keep the right records in the store in event time. + (the time is taken from the records' timestamps). + It can be specified as either an `int` representing milliseconds or as a `timedelta` object. + The records are expired per key when the new record gets added. + Default - 7 days. + + :param name: The unique identifier of the underlying state store for the "right" dataframe. + If not provided, it will be generated based on the underlying topic names. + Provide a custom name if you need to join the same right dataframe multiple times + within the application. + + Example: + + ```python + from datetime import timedelta + from quixstreams import Application + + app = Application() + + sdf_measurements = app.dataframe(app.topic("measurements")) + sdf_metadata = app.dataframe(app.topic("metadata")) + + # Join records from the topic "measurements" + # with the latest effective records from the topic "metadata" + # using the "inner" join strategy and keeping the "metadata" records stored for 14 days in event time. + sdf_joined = sdf_measurements.join_asof(sdf_metadata, how="inner", grace_ms=timedelta(days=14)) + ``` + + """ + return JoinAsOf( + how=how, on_merge=on_merge, grace_ms=grace_ms, store_name=name + ).join(self, right) def _produce( self, @@ -1689,17 +1742,19 @@ def _add_update( self._stream = self._stream.add_update(func, metadata=metadata) # type: ignore[call-overload] return self - def _register_store(self): + def register_store(self, store_type: Optional[StoreTypes] = None) -> None: """ Register the default store for the current stream_id in StateStoreManager. """ - self.ensure_topics_copartitioned() + TopicManager.ensure_topics_copartitioned(*self._topics) # Generate a changelog topic config based on the underlying topics. 
changelog_topic_config = self._topic_manager.derive_topic_config(self._topics) self._processing_context.state_manager.register_store( - stream_id=self.stream_id, changelog_config=changelog_topic_config + stream_id=self.stream_id, + store_type=store_type, + changelog_config=changelog_topic_config, ) def _groupby_key( @@ -1847,19 +1902,16 @@ def wrapper( def _as_stateful( func: Callable[[Any, Any, int, Any, State], T], - processing_context: ProcessingContext, - stream_id: str, + sdf: StreamingDataFrame, ) -> Callable[[Any, Any, int, Any], T]: @functools.wraps(func) def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: - ctx = message_context() - transaction = processing_context.checkpoint.get_store_transaction( - stream_id=stream_id, - partition=ctx.partition, - ) # Pass a State object with an interface limited to the key updates only # and prefix all the state keys by the message key - state = transaction.as_state(prefix=key) + state = sdf.processing_context.checkpoint.get_store_transaction( + stream_id=sdf.stream_id, + partition=message_context().partition, + ).as_state(prefix=key) return func(value, key, timestamp, headers, state) return wrapper diff --git a/quixstreams/dataframe/exceptions.py b/quixstreams/dataframe/exceptions.py index 2bb691f4e..0534fc3f4 100644 --- a/quixstreams/dataframe/exceptions.py +++ b/quixstreams/dataframe/exceptions.py @@ -7,7 +7,6 @@ "ColumnDoesNotExist", "StreamingDataFrameDuplicate", "GroupByDuplicate", - "TopicPartitionsMismatch", ) @@ -27,6 +26,3 @@ class GroupByDuplicate(QuixException): ... class StreamingDataFrameDuplicate(QuixException): ... - - -class TopicPartitionsMismatch(QuixException): ... diff --git a/quixstreams/dataframe/joins/__init__.py b/quixstreams/dataframe/joins/__init__.py new file mode 100644 index 000000000..3df77d58f --- /dev/null +++ b/quixstreams/dataframe/joins/__init__.py @@ -0,0 +1,3 @@ +from .join_asof import JoinAsOf as JoinAsOf +from .join_asof import JoinAsOfHow as JoinAsOfHow +from .join_asof import OnOverlap as OnOverlap diff --git a/quixstreams/dataframe/joins/join_asof.py b/quixstreams/dataframe/joins/join_asof.py new file mode 100644 index 000000000..b6939c69b --- /dev/null +++ b/quixstreams/dataframe/joins/join_asof.py @@ -0,0 +1,115 @@ +import typing +from datetime import timedelta +from typing import Any, Callable, Literal, Optional, Union, cast, get_args + +from quixstreams.context import message_context +from quixstreams.dataframe.utils import ensure_milliseconds +from quixstreams.models.topics.manager import TopicManager +from quixstreams.state.rocksdb.timestamped import TimestampedPartitionTransaction + +from .utils import keep_left_merger, keep_right_merger, raise_merger + +if typing.TYPE_CHECKING: + from quixstreams.dataframe.dataframe import StreamingDataFrame + + +__all__ = ("JoinAsOfHow", "OnOverlap", "JoinAsOf") + +DISCARDED = object() +JoinAsOfHow = Literal["inner", "left"] +JoinAsOfHow_choices = get_args(JoinAsOfHow) + +OnOverlap = Literal["keep-left", "keep-right", "raise"] +OnOverlap_choices = get_args(OnOverlap) + + +class JoinAsOf: + def __init__( + self, + how: JoinAsOfHow, + on_merge: Union[OnOverlap, Callable[[Any, Any], Any]], + grace_ms: Union[int, timedelta], + store_name: Optional[str] = None, + ): + if how not in JoinAsOfHow_choices: + raise ValueError( + f'Invalid "how" value: {how}. ' + f"Valid choices are: {', '.join(JoinAsOfHow_choices)}." 
+ ) + self._how = how + + if callable(on_merge): + self._merger = on_merge + elif on_merge == "keep-left": + self._merger = keep_left_merger + elif on_merge == "keep-right": + self._merger = keep_right_merger + elif on_merge == "raise": + self._merger = raise_merger + else: + raise ValueError( + f'Invalid "on_merge" value: {on_merge}. ' + f"Provide either one of {', '.join(OnOverlap_choices)} or " + f"a callable to merge records manually." + ) + + self._retention_ms = ensure_milliseconds(grace_ms) + self._store_name = store_name or "join" + + def join( + self, + left: "StreamingDataFrame", + right: "StreamingDataFrame", + ) -> "StreamingDataFrame": + if left.stream_id == right.stream_id: + raise ValueError( + "Joining dataframes originating from " + "the same topic is not yet supported.", + ) + TopicManager.ensure_topics_copartitioned(*left.topics, *right.topics) + + changelog_config = TopicManager.derive_topic_config(right.topics) + right.processing_context.state_manager.register_timestamped_store( + stream_id=right.stream_id, + store_name=self._store_name, + changelog_config=changelog_config, + ) + + is_inner_join = self._how == "inner" + + def left_func(value, key, timestamp, headers): + tx = cast( + TimestampedPartitionTransaction, + right.processing_context.checkpoint.get_store_transaction( + stream_id=right.stream_id, + partition=message_context().partition, + store_name=self._store_name, + ), + ) + + right_value = tx.get_latest(timestamp=timestamp, prefix=key) + if is_inner_join and not right_value: + return DISCARDED + return self._merger(value, right_value) + + def right_func(value, key, timestamp, headers): + tx = cast( + TimestampedPartitionTransaction, + right.processing_context.checkpoint.get_store_transaction( + stream_id=right.stream_id, + partition=message_context().partition, + store_name=self._store_name, + ), + ) + tx.set_for_timestamp( + timestamp=timestamp, + value=value, + prefix=key, + retention_ms=self._retention_ms, + ) + + right = right.update(right_func, metadata=True).filter(lambda value: False) + left = left.apply(left_func, metadata=True).filter( + lambda value: value is not DISCARDED + ) + return left.concat(right) diff --git a/quixstreams/dataframe/joins/utils.py b/quixstreams/dataframe/joins/utils.py new file mode 100644 index 000000000..49058e0ab --- /dev/null +++ b/quixstreams/dataframe/joins/utils.py @@ -0,0 +1,35 @@ +from typing import Mapping, Optional + + +def keep_left_merger(left: Optional[Mapping], right: Optional[Mapping]) -> dict: + """ + Merge two dictionaries, preferring values from the left dictionary + """ + left = left if left is not None else {} + right = right if right is not None else {} + return {**right, **left} + + +def keep_right_merger(left: Optional[Mapping], right: Optional[Mapping]) -> dict: + """ + Merge two dictionaries, preferring values from the right dictionary + """ + left = left if left is not None else {} + right = right if right is not None else {} + return {**left, **right} + + +def raise_merger(left: Optional[Mapping], right: Optional[Mapping]) -> dict: + """ + Merge two dictionaries and raise an error if overlapping keys detected + """ + left = left if left is not None else {} + right = right if right is not None else {} + if overlapping_columns := left.keys() & right.keys(): + overlapping_columns_str = ", ".join(sorted(overlapping_columns)) + raise ValueError( + f"Overlapping columns: {overlapping_columns_str}." 
+ 'You need to provide either an "on_merge" value of ' + "'keep-left' or 'keep-right' or a custom merger function." + ) + return {**left, **right} diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index d3da8ac0a..38442b883 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -70,7 +70,7 @@ def process_window( pass def register_store(self) -> None: - self._dataframe.ensure_topics_copartitioned() + TopicManager.ensure_topics_copartitioned(*self._dataframe.topics) # Create a config for the changelog topic based on the underlying SDF topics changelog_config = TopicManager.derive_topic_config(self._dataframe.topics) self._dataframe.processing_context.state_manager.register_windowed_store( diff --git a/quixstreams/models/topics/exceptions.py b/quixstreams/models/topics/exceptions.py index 94d3eb9bb..de6bee6e3 100644 --- a/quixstreams/models/topics/exceptions.py +++ b/quixstreams/models/topics/exceptions.py @@ -20,3 +20,6 @@ class TopicPermissionError(QuixException): ... class TopicConfigurationError(QuixException): ... + + +class TopicPartitionsMismatch(QuixException): ... diff --git a/quixstreams/models/topics/manager.py b/quixstreams/models/topics/manager.py index a57011f5f..598b8c3d7 100644 --- a/quixstreams/models/topics/manager.py +++ b/quixstreams/models/topics/manager.py @@ -9,6 +9,7 @@ TopicConfigurationMismatch, TopicNameLengthExceeded, TopicNotFoundError, + TopicPartitionsMismatch, ) from .topic import TimestampExtractor, Topic, TopicConfig, TopicType @@ -333,6 +334,18 @@ def derive_topic_config(cls, topics: Iterable[Topic]) -> TopicConfig: }, ) + @classmethod + def ensure_topics_copartitioned(cls, *topics: Topic): + partitions_counts = set(t.broker_config.num_partitions for t in topics) + if len(partitions_counts) > 1: + msg = ", ".join( + f'"{t.name}" ({t.broker_config.num_partitions} partitions)' + for t in topics + ) + raise TopicPartitionsMismatch( + f"The underlying topics must have the same number of partitions to use State; got {msg}" + ) + def stream_id_from_topics(self, topics: Sequence[Topic]) -> str: """ Generate a stream_id by combining names of the provided topics. diff --git a/quixstreams/state/exceptions.py b/quixstreams/state/exceptions.py index a8374a9fa..1785d2ee1 100644 --- a/quixstreams/state/exceptions.py +++ b/quixstreams/state/exceptions.py @@ -10,7 +10,7 @@ class PartitionStoreIsUsed(QuixException): ... class StoreNotRegisteredError(QuixException): ... -class WindowedStoreAlreadyRegisteredError(QuixException): ... +class StoreAlreadyRegisteredError(QuixException): ... class InvalidStoreTransactionStateError(QuixException): ... 
diff --git a/quixstreams/state/manager.py b/quixstreams/state/manager.py index 5a9667fc9..1c3c437ab 100644 --- a/quixstreams/state/manager.py +++ b/quixstreams/state/manager.py @@ -9,8 +9,8 @@ from .base import Store, StorePartition from .exceptions import ( PartitionStoreIsUsed, + StoreAlreadyRegisteredError, StoreNotRegisteredError, - WindowedStoreAlreadyRegisteredError, ) from .memory import MemoryStore from .recovery import ChangelogProducerFactory, RecoveryManager @@ -197,14 +197,6 @@ def register_store( changelog_producer_factory=changelog_producer_factory, options=self._rocksdb_options, ) - elif store_type == TimestampedStore: - store = TimestampedStore( - name=store_name, - stream_id=stream_id, - base_dir=str(self._state_dir), - changelog_producer_factory=changelog_producer_factory, - options=self._rocksdb_options, - ) elif store_type == MemoryStore: store = MemoryStore( name=store_name, @@ -216,6 +208,30 @@ def register_store( self._stores.setdefault(stream_id, {})[store_name] = store + def register_timestamped_store( + self, + stream_id: str, + store_name: str, + changelog_config: Optional[TopicConfig] = None, + ) -> None: + if self._stores.get(stream_id, {}).get(store_name): + raise StoreAlreadyRegisteredError( + f'Store "{store_name}" for stream_id "{stream_id}" is already registered; ' + f"provide a different name" + ) + store = TimestampedStore( + name=store_name, + stream_id=stream_id, + base_dir=str(self._state_dir), + changelog_producer_factory=self._setup_changelogs( + stream_id=stream_id, + store_name=store_name, + topic_config=changelog_config, + ), + options=self._rocksdb_options, + ) + self._stores.setdefault(stream_id, {})[store_name] = store + def register_windowed_store( self, stream_id: str, @@ -237,7 +253,7 @@ def register_windowed_store( store = self._stores.get(stream_id, {}).get(store_name) if store: - raise WindowedStoreAlreadyRegisteredError( + raise StoreAlreadyRegisteredError( "This window range and type combination already exists; " "to use this window, provide a unique name via the `name` parameter." ) diff --git a/quixstreams/state/metadata.py b/quixstreams/state/metadata.py index 7525a3591..09dd70e72 100644 --- a/quixstreams/state/metadata.py +++ b/quixstreams/state/metadata.py @@ -1,6 +1,7 @@ import enum SEPARATOR = b"|" +SEPARATOR_LENGTH = len(SEPARATOR) CHANGELOG_CF_MESSAGE_HEADER = "__column_family__" CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER = "__processed_tp_offsets__" diff --git a/quixstreams/state/rocksdb/partition.py b/quixstreams/state/rocksdb/partition.py index eed038d5c..67e50b237 100644 --- a/quixstreams/state/rocksdb/partition.py +++ b/quixstreams/state/rocksdb/partition.py @@ -20,7 +20,7 @@ from quixstreams.state.exceptions import ColumnFamilyDoesNotExist from quixstreams.state.metadata import METADATA_CF_NAME, Marker from quixstreams.state.recovery import ChangelogProducer -from quixstreams.state.serialization import int_from_int64_bytes, int_to_int64_bytes +from quixstreams.state.serialization import int_from_bytes, int_to_bytes from .exceptions import ColumnFamilyAlreadyExists from .metadata import ( @@ -197,8 +197,10 @@ def iter_items( # is not respected by Rdict for some reason. We need to manually # filter it here. 
for key, value in items: - if lower_bound <= key: - yield key, value + if key < lower_bound: + # Exit early if the key falls below the lower bound + break + yield key, value def begin(self) -> PartitionTransaction: return PartitionTransaction( @@ -229,7 +231,7 @@ def get_changelog_offset(self) -> Optional[int]: if offset_bytes is None: return None - return int_from_int64_bytes(offset_bytes) + return int_from_bytes(offset_bytes) def write_changelog_offset(self, offset: int): """ @@ -384,7 +386,7 @@ def _init_rocksdb(self) -> Rdict: raise logger.warning( - f"Failed to open rocksdb partition, cannot acquire a lock. " + f"Failed to open rocksdb partition , cannot acquire a lock. " f"Retrying in {self._open_retry_backoff}sec." ) @@ -394,7 +396,7 @@ def _init_rocksdb(self) -> Rdict: def _update_changelog_offset(self, batch: WriteBatch, offset: int): batch.put( CHANGELOG_OFFSET_KEY, - int_to_int64_bytes(offset), + int_to_bytes(offset), self.get_column_family_handle(METADATA_CF_NAME), ) diff --git a/quixstreams/state/rocksdb/timestamped.py b/quixstreams/state/rocksdb/timestamped.py index ecc20e2f8..b74b3330b 100644 --- a/quixstreams/state/rocksdb/timestamped.py +++ b/quixstreams/state/rocksdb/timestamped.py @@ -12,7 +12,7 @@ from quixstreams.state.serialization import ( DumpsFunc, LoadsFunc, - int_to_int64_bytes, + encode_integer_pair, serialize, ) @@ -70,7 +70,7 @@ def __init__( ) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def get_last(self, timestamp: int, prefix: Any) -> Optional[Any]: + def get_latest(self, timestamp: int, prefix: Any) -> Optional[Any]: """Get the latest value for a prefix up to a given timestamp. Searches both the transaction's update cache and the underlying RocksDB store @@ -211,7 +211,9 @@ def _ensure_bytes(self, prefix: Any) -> bytes: def _serialize_key(self, key: Union[int, bytes], prefix: bytes) -> bytes: if isinstance(key, int): - return prefix + SEPARATOR + int_to_int64_bytes(key) + # TODO: Currently using constant 0, but will be + # replaced with a global counter in the future + return prefix + SEPARATOR + encode_integer_pair(key, 0) elif isinstance(key, bytes): return prefix + SEPARATOR + key raise TypeError(f"Invalid key type: {type(key)}") @@ -228,9 +230,13 @@ def _get_min_eligible_timestamp(self, prefix: bytes) -> int: :return: The minimum eligible timestamp (int). 
""" cache = self._min_eligible_timestamps - return ( - cache.timestamps.get(prefix) or self.get(key=cache.key, prefix=prefix) or 0 - ) + cached = cache.timestamps.get(prefix) + if cached is not None: + return cached + stored = self.get(key=cache.key, prefix=prefix) or 0 + # Write the timestamp back to cache since it is known now + cache.timestamps[prefix] = stored + return stored def _set_min_eligible_timestamp(self, prefix: bytes, timestamp: int) -> None: """ diff --git a/quixstreams/state/rocksdb/windowed/serialization.py b/quixstreams/state/rocksdb/windowed/serialization.py index b75795918..7bf83c67f 100644 --- a/quixstreams/state/rocksdb/windowed/serialization.py +++ b/quixstreams/state/rocksdb/windowed/serialization.py @@ -1,20 +1,13 @@ -import struct - -from quixstreams.state.metadata import SEPARATOR +from quixstreams.state.metadata import SEPARATOR, SEPARATOR_LENGTH from quixstreams.state.serialization import ( - int_to_int64_bytes, + decode_integer_pair, + int_to_bytes, ) -__all__ = ("parse_window_key", "encode_integer_pair", "append_integer") - -_TIMESTAMP_BYTE_LENGTH = len(int_to_int64_bytes(0)) -_SEPARATOR_LENGTH = len(SEPARATOR) -_TIMESTAMPS_SEGMENT_LEN = _TIMESTAMP_BYTE_LENGTH * 2 + _SEPARATOR_LENGTH +__all__ = ("parse_window_key", "append_integer") -_window_pack_format = ">q" + "c" * _SEPARATOR_LENGTH + "q" -_window_packer = struct.Struct(_window_pack_format) -_window_pack = _window_packer.pack -_window_unpack = _window_packer.unpack +_TIMESTAMP_BYTE_LENGTH = len(int_to_bytes(0)) +_TIMESTAMPS_SEGMENT_LEN = _TIMESTAMP_BYTE_LENGTH * 2 + SEPARATOR_LENGTH def parse_window_key(key: bytes) -> tuple[bytes, int, int]: @@ -33,24 +26,10 @@ def parse_window_key(key: bytes) -> tuple[bytes, int, int]: key[-_TIMESTAMPS_SEGMENT_LEN:], ) - start_ms, _, end_ms = _window_unpack(timestamps_bytes) + start_ms, end_ms = decode_integer_pair(timestamps_bytes) return message_key, start_ms, end_ms -def encode_integer_pair(integer_1: int, integer_2: int) -> bytes: - """ - Encode a pair of integers into bytes of the following format: - ```|``` - - Encoding integers this way make them sortable in RocksDB within the same prefix. 
- - :param integer_1: first integer - :param integer_2: second integer - :return: integers as bytes - """ - return _window_pack(integer_1, SEPARATOR, integer_2) - - def append_integer(base_bytes: bytes, integer: int) -> bytes: """ Append integer to the base bytes @@ -61,4 +40,4 @@ def append_integer(base_bytes: bytes, integer: int) -> bytes: :param integer: integer to append :return: bytes """ - return base_bytes + SEPARATOR + int_to_int64_bytes(integer) + return base_bytes + SEPARATOR + int_to_bytes(integer) diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py index c7c70abde..329b281de 100644 --- a/quixstreams/state/rocksdb/windowed/transaction.py +++ b/quixstreams/state/rocksdb/windowed/transaction.py @@ -12,6 +12,8 @@ from quixstreams.state.serialization import ( DumpsFunc, LoadsFunc, + encode_integer_pair, + int_to_bytes, serialize, ) from quixstreams.state.types import ExpiredWindowDetail, WindowDetail @@ -29,12 +31,7 @@ LATEST_TIMESTAMPS_CF_NAME, VALUES_CF_NAME, ) -from .serialization import ( - append_integer, - encode_integer_pair, - int_to_int64_bytes, - parse_window_key, -) +from .serialization import append_integer, parse_window_key from .state import WindowedTransactionState if TYPE_CHECKING: @@ -330,7 +327,7 @@ def expire_all_windows( if not windows: return last_expired = windows[-1] # windows are ordered - suffixes: set[bytes] = set(int_to_int64_bytes(window) for window in windows) + suffixes: set[bytes] = set(int_to_bytes(window) for window in windows) for key in self.keys(): if key[-8:] in suffixes: prefix, start, end = parse_window_key(key) diff --git a/quixstreams/state/serialization.py b/quixstreams/state/serialization.py index d672bfda1..0f97c6f4e 100644 --- a/quixstreams/state/serialization.py +++ b/quixstreams/state/serialization.py @@ -2,20 +2,28 @@ from typing import Any, Callable from .exceptions import StateSerializationError +from .metadata import SEPARATOR, SEPARATOR_LENGTH __all__ = ( "DumpsFunc", "LoadsFunc", "serialize", "deserialize", - "int_to_int64_bytes", - "int_from_int64_bytes", + "int_to_bytes", + "int_from_bytes", + "encode_integer_pair", + "decode_integer_pair", ) -_int_packer = struct.Struct(">q") +_int_packer = struct.Struct(">Q") _int_pack = _int_packer.pack _int_unpack = _int_packer.unpack +_int_pair_pack_format = ">Q" + "c" * SEPARATOR_LENGTH + "Q" +_int_pair_packer = struct.Struct(_int_pair_pack_format) +_int_pair_pack = _int_pair_packer.pack +_int_pair_unpack = _int_pair_packer.unpack + DumpsFunc = Callable[[Any], bytes] LoadsFunc = Callable[[bytes], Any] @@ -36,9 +44,35 @@ def deserialize(value: bytes, loads: LoadsFunc) -> Any: ) from exc -def int_to_int64_bytes(value: int) -> bytes: +def int_to_bytes(value: int) -> bytes: return _int_pack(value) -def int_from_int64_bytes(value: bytes) -> int: +def int_from_bytes(value: bytes) -> int: return _int_unpack(value)[0] + + +def encode_integer_pair(integer_1: int, integer_2: int) -> bytes: + """ + Encode a pair of integers into bytes of the following format: + ```|``` + + Encoding integers this way make them sortable in RocksDB within the same prefix. 
+ + :param integer_1: first integer + :param integer_2: second integer + :return: integers as bytes + """ + return _int_pair_pack(integer_1, SEPARATOR, integer_2) + + +def decode_integer_pair(value: bytes) -> tuple[int, int]: + """ + Decode a pair of integers from bytes of the following format: + ```|``` + + :param value: bytes + :return: tuple of integers + """ + integer_1, _, integer_2 = _int_pair_unpack(value) + return integer_1, integer_2 diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index 075c0e022..78291ca91 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -4,6 +4,7 @@ import warnings from collections import namedtuple from datetime import timedelta +from functools import partial from typing import Any from unittest import mock @@ -14,11 +15,12 @@ GroupByDuplicate, GroupByNestingLimit, InvalidOperation, - TopicPartitionsMismatch, ) from quixstreams.dataframe.registry import DataFrameRegistry from quixstreams.dataframe.windows.base import WindowResult from quixstreams.models import TopicConfig +from quixstreams.models.topics.exceptions import TopicPartitionsMismatch +from quixstreams.state.exceptions import StoreAlreadyRegisteredError from tests.utils import DummySink RecordStub = namedtuple("RecordStub", ("value", "key", "timestamp")) @@ -2641,3 +2643,297 @@ def test_concat_stateful_mismatching_partitions_fails( match="The underlying topics must have the same number of partitions to use State", ): sdf1.concat(sdf2).update(lambda v, state: None, stateful=True) + + +class TestStreamingDataFrameJoinAsOf: + @pytest.fixture + def topic_manager(self, topic_manager_factory): + return topic_manager_factory() + + # TODO: Check if we already have a fixture for that to avoid the pollution + @pytest.fixture + def create_topic(self, topic_manager): + def _create_topic(num_partitions=1): + return topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, + replication_factor=1, + ), + ) + + return _create_topic + + @pytest.fixture + def create_sdf(self, dataframe_factory, state_manager): + def _create_sdf(topic): + return dataframe_factory(topic=topic, state_manager=state_manager) + + return _create_sdf + + @pytest.fixture + def assign_partition(self, state_manager): + def _assign_partition(sdf): + state_manager.on_partition_assign( + stream_id=sdf.stream_id, + partition=0, + committed_offsets={}, + ) + + return _assign_partition + + @pytest.fixture + def publish(self, message_context_factory): + def _publish(sdf, topic, value, key, timestamp): + return sdf.test( + value=value, + key=key, + timestamp=timestamp, + topic=topic, + ctx=message_context_factory(topic=topic.name), + ) + + return _publish + + @pytest.mark.parametrize( + "how, right, left, expected", + [ + ( + "inner", + {"right": 2}, + {"left": 1}, + [({"left": 1, "right": 2}, b"key", 2, None)], + ), + ( + "inner", + None, + {"left": 1}, + [], + ), + ( + "inner", + {}, + {"left": 1}, + [], + ), + ( + "left", + {"right": 2}, + {"left": 1}, + [({"left": 1, "right": 2}, b"key", 2, None)], + ), + ( + "left", + None, + {"left": 1}, + [({"left": 1}, b"key", 2, None)], + ), + ( + "left", + {}, + {"left": 1}, + [({"left": 1}, b"key", 2, None)], + ), + ], + ) + def test_how( + self, + create_topic, + create_sdf, + assign_partition, + publish, + how, + right, + left, + expected, + ): + left_topic, right_topic = create_topic(), create_topic() 
+ left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic) + joined_sdf = left_sdf.join_asof(right_sdf, how=how) + assign_partition(right_sdf) + + publish(joined_sdf, right_topic, value=right, key=b"key", timestamp=1) + joined_value = publish( + joined_sdf, left_topic, value=left, key=b"key", timestamp=2 + ) + assert joined_value == expected + + def test_how_invalid_value(self, create_topic, create_sdf): + left_topic, right_topic = create_topic(), create_topic() + left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic) + + match = 'Invalid "how" value' + with pytest.raises(ValueError, match=match): + left_sdf.join_asof(right_sdf, how="invalid") + + def test_mismatching_partitions_fails(self, create_topic, create_sdf): + left_topic, right_topic = create_topic(), create_topic(num_partitions=2) + left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic) + + with pytest.raises(TopicPartitionsMismatch): + left_sdf.join_asof(right_sdf) + + @pytest.mark.parametrize( + "on_merge, right, left, expected", + [ + ( + "keep-left", + None, + {"A": 1}, + {"A": 1}, + ), + ( + "keep-left", + {"B": "right", "C": 2}, + {"A": 1, "B": "left"}, + {"A": 1, "B": "left", "C": 2}, + ), + ( + "keep-right", + None, + {"A": 1}, + {"A": 1}, + ), + ( + "keep-right", + {"B": "right", "C": 2}, + {"A": 1, "B": "left"}, + {"A": 1, "B": "right", "C": 2}, + ), + ( + "raise", + None, + {"A": 1}, + {"A": 1}, + ), + ( + "raise", + {"B": 2}, + {"A": 1}, + {"A": 1, "B": 2}, + ), + ( + "raise", + {"B": "right B", "C": "right C"}, + {"A": 1, "B": "left B", "C": "left C"}, + ValueError("Overlapping columns: B, C."), + ), + ], + ) + def test_on_merge( + self, + create_topic, + create_sdf, + assign_partition, + publish, + on_merge, + right, + left, + expected, + ): + left_topic, right_topic = create_topic(), create_topic() + left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic) + joined_sdf = left_sdf.join_asof(right_sdf, how="left", on_merge=on_merge) + assign_partition(right_sdf) + + publish(joined_sdf, right_topic, value=right, key=b"key", timestamp=1) + + if isinstance(expected, Exception): + with pytest.raises(expected.__class__, match=expected.args[0]): + publish(joined_sdf, left_topic, value=left, key=b"key", timestamp=2) + else: + joined_value = publish( + joined_sdf, left_topic, value=left, key=b"key", timestamp=2 + ) + assert joined_value == [(expected, b"key", 2, None)] + + def test_on_merge_invalid_value(self, create_topic, create_sdf): + left_topic, right_topic = create_topic(), create_topic() + left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic) + + match = 'Invalid "on_merge"' + with pytest.raises(ValueError, match=match): + left_sdf.join_asof(right_sdf, on_merge="invalid") + + def test_on_merge_callback( + self, create_topic, create_sdf, assign_partition, publish + ): + left_topic, right_topic = create_topic(), create_topic() + left_sdf, right_sdf = create_sdf(left_topic), create_sdf(right_topic) + + def on_merge(left, right): + return {"left": left, "right": right} + + joined_sdf = left_sdf.join_asof(right_sdf, on_merge=on_merge) + assign_partition(right_sdf) + + publish(joined_sdf, right_topic, value=1, key=b"key", timestamp=1) + joined_value = publish(joined_sdf, left_topic, value=2, key=b"key", timestamp=2) + assert joined_value == [({"left": 2, "right": 1}, b"key", 2, None)] + + def test_grace_ms( + self, + create_topic, + create_sdf, + assign_partition, + publish, + ): + left_topic, right_topic = create_topic(), create_topic() + left_sdf, 
right_sdf = create_sdf(left_topic), create_sdf(right_topic) + + joined_sdf = left_sdf.join_asof(right_sdf, grace_ms=10) + assign_partition(right_sdf) + + # min eligible timestamp is 15 - 10 = 5 + publish(joined_sdf, right_topic, value={"right": 1}, key=b"key", timestamp=15) + + # min eligible timestamp is still 5 + publish(joined_sdf, right_topic, value={"right": 3}, key=b"key", timestamp=4) + publish(joined_sdf, right_topic, value={"right": 2}, key=b"key", timestamp=5) + + publish_left = partial( + publish, + joined_sdf, + left_topic, + value={"left": 4}, + key=b"key", + ) + + assert publish_left(timestamp=4) == [] + assert publish_left(timestamp=5) == [({"left": 4, "right": 2}, b"key", 5, None)] + + def test_self_join_not_supported(self, create_topic, create_sdf): + topic = create_topic() + match = ( + "Joining dataframes originating from the same topic is not yet supported." + ) + + # The very same sdf object + sdf = create_sdf(topic) + with pytest.raises(ValueError, match=match): + sdf.join_asof(sdf) + + # Same topic, different branch + sdf2 = sdf.apply(lambda v: v) + with pytest.raises(ValueError, match=match): + sdf.join_asof(sdf2) + + def test_join_same_topic_multiple_times_fails(self, create_topic, create_sdf): + topic1 = create_topic() + topic2 = create_topic() + topic3 = create_topic() + + sdf1 = create_sdf(topic1) + sdf2 = create_sdf(topic2) + sdf3 = create_sdf(topic3) + + # Join topic1 with topic2 once + sdf1.join_asof(sdf2) + + # Repeat the join + with pytest.raises(StoreAlreadyRegisteredError): + sdf1.join_asof(sdf2) + + # Try joining topic2 with another sdf + with pytest.raises(StoreAlreadyRegisteredError): + sdf3.join_asof(sdf2) diff --git a/tests/test_quixstreams/test_state/test_manager.py b/tests/test_quixstreams/test_state/test_manager.py index c0b920378..bfe995e1b 100644 --- a/tests/test_quixstreams/test_state/test_manager.py +++ b/tests/test_quixstreams/test_state/test_manager.py @@ -10,8 +10,8 @@ from quixstreams.models import TopicConfig from quixstreams.state.exceptions import ( PartitionStoreIsUsed, + StoreAlreadyRegisteredError, StoreNotRegisteredError, - WindowedStoreAlreadyRegisteredError, ) from quixstreams.state.manager import SUPPORTED_STORES from quixstreams.state.rocksdb import RocksDBStore @@ -101,9 +101,14 @@ def test_register_store_twice(self, state_manager): def test_register_windowed_store_twice(self, state_manager): state_manager.register_windowed_store("stream_id", "store") - with pytest.raises(WindowedStoreAlreadyRegisteredError): + with pytest.raises(StoreAlreadyRegisteredError): state_manager.register_windowed_store("stream_id", "store") + def test_register_timestamped_store_twice(self, state_manager): + state_manager.register_timestamped_store("stream_id", "store") + with pytest.raises(StoreAlreadyRegisteredError): + state_manager.register_timestamped_store("stream_id", "store") + def test_get_store_not_registered(self, state_manager): with pytest.raises(StoreNotRegisteredError): state_manager.get_store("topic", "store") diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_rocksdb_partition.py b/tests/test_quixstreams/test_state/test_rocksdb/test_rocksdb_partition.py index c31a67ba5..054e2cd31 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_rocksdb_partition.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_rocksdb_partition.py @@ -14,7 +14,7 @@ from quixstreams.state.rocksdb.windowed.serialization import append_integer -class TestRocksdbStorePartition: +class TestRocksDBStorePartition: def 
test_open_db_locked_retries(self, store_partition_factory, executor): db1 = store_partition_factory("db") @@ -67,37 +67,45 @@ def test_open_arbitrary_exception_fails(self, store_partition_factory): assert str(raised.value) == "some exception" - def test_create_and_get_column_family(self, store_partition): + def test_create_and_get_column_family(self, store_partition: RocksDBStorePartition): store_partition.create_column_family("cf") assert store_partition.get_column_family("cf") - def test_create_column_family_already_exists(self, store_partition): + def test_create_column_family_already_exists( + self, store_partition: RocksDBStorePartition + ): store_partition.create_column_family("cf") with pytest.raises(ColumnFamilyAlreadyExists): store_partition.create_column_family("cf") - def test_get_column_family_doesnt_exist(self, store_partition): + def test_get_column_family_doesnt_exist( + self, store_partition: RocksDBStorePartition + ): with pytest.raises(ColumnFamilyDoesNotExist): store_partition.get_column_family("cf") - def test_get_column_family_cached(self, store_partition): + def test_get_column_family_cached(self, store_partition: RocksDBStorePartition): store_partition.create_column_family("cf") cf1 = store_partition.get_column_family("cf") cf2 = store_partition.get_column_family("cf") assert cf1 is cf2 - def test_create_and_drop_column_family(self, store_partition): + def test_create_and_drop_column_family( + self, store_partition: RocksDBStorePartition + ): store_partition.create_column_family("cf") store_partition.drop_column_family("cf") with pytest.raises(ColumnFamilyDoesNotExist): store_partition.get_column_family("cf") - def test_drop_column_family_doesnt_exist(self, store_partition): + def test_drop_column_family_doesnt_exist( + self, store_partition: RocksDBStorePartition + ): with pytest.raises(ColumnFamilyDoesNotExist): store_partition.drop_column_family("cf") - def test_list_column_families(self, store_partition): + def test_list_column_families(self, store_partition: RocksDBStorePartition): store_partition.create_column_family("cf1") store_partition.create_column_family("cf2") cfs = store_partition.list_column_families() @@ -121,7 +129,9 @@ def test_custom_options(self, store_partition_factory, tmp_path): assert logs_dir.is_dir() assert len(list(logs_dir.rglob("*"))) == 1 - def test_list_column_families_defaults(self, store_partition): + def test_list_column_families_defaults( + self, store_partition: RocksDBStorePartition + ): cfs = store_partition.list_column_families() assert cfs == [ # "default" CF is always present in RocksDB @@ -130,7 +140,7 @@ def test_list_column_families_defaults(self, store_partition): "__metadata__", ] - def test_ensure_metadata_cf(self, store_partition): + def test_ensure_metadata_cf(self, store_partition: RocksDBStorePartition): assert store_partition.get_column_family("__metadata__") @pytest.mark.parametrize( @@ -155,7 +165,7 @@ def test_ensure_metadata_cf(self, store_partition): ], ) def test_iter_items_returns_ordered_items( - self, store_partition, cache, backwards, expected + self, store_partition: RocksDBStorePartition, cache, backwards, expected ): for key, value in expected: cache.set(key=key, value=value, prefix=b"prefix") @@ -177,7 +187,9 @@ def test_iter_items_returns_ordered_items( == expected ) - def test_iter_items_exclusive_upper_bound(self, store_partition, cache): + def test_iter_items_exclusive_upper_bound( + self, store_partition: RocksDBStorePartition, cache + ): cache.set(key=b"prefix|1", value=b"value1", 
prefix=b"prefix") cache.set(key=b"prefix|2", value=b"value2", prefix=b"prefix") store_partition.write(cache=cache, changelog_offset=None) @@ -188,3 +200,53 @@ def test_iter_items_exclusive_upper_bound(self, store_partition, cache): upper_bound=b"prefix|2", ) ) == [(b"prefix|1", b"value1")] + + def test_iter_items_backwards_lower_bound( + self, store_partition: RocksDBStorePartition, cache + ): + """ + Test that keys below the lower bound are filtered + """ + prefix = b"2" + lower_bound = b"3" + upper_bound = b"4" + + cache.set(key=prefix + b"|" + b"test1", value=b"", prefix=prefix) + cache.set(key=prefix + b"|" + b"test2", value=b"", prefix=prefix) + store_partition.write(cache=cache, changelog_offset=None) + + assert ( + list( + store_partition.iter_items( + lower_bound=lower_bound, + upper_bound=upper_bound, + backwards=True, + ) + ) + == [] + ) + + def test_iter_items_backwards_upper_bound( + self, store_partition: RocksDBStorePartition, cache + ): + """ + Test that keys above the upper bound are filtered + """ + prefix = b"4" + lower_bound = b"3" + upper_bound = b"4" + + cache.set(key=prefix + b"|" + b"test1", value=b"", prefix=prefix) + cache.set(key=prefix + b"|" + b"test2", value=b"", prefix=prefix) + store_partition.write(cache=cache, changelog_offset=None) + + assert ( + list( + store_partition.iter_items( + lower_bound=lower_bound, + upper_bound=upper_bound, + backwards=True, + ) + ) + == [] + ) diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_timestamped.py b/tests/test_quixstreams/test_state/test_rocksdb/test_timestamped.py index ac4acb602..7a730e77d 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_timestamped.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_timestamped.py @@ -37,7 +37,7 @@ class TestTimestampedPartitionTransaction: pytest.param(10, 9, None, id="set_timestamp_greater_than_get_timestamp"), ], ) - def test_get_last_from_cache( + def test_get_latest_from_cache( self, transaction: TimestampedPartitionTransaction, set_timestamp: int, @@ -46,7 +46,7 @@ def test_get_last_from_cache( ): with transaction() as tx: tx.set_for_timestamp(timestamp=set_timestamp, value="value", prefix=b"key") - assert tx.get_last(timestamp=get_timestamp, prefix=b"key") == expected + assert tx.get_latest(timestamp=get_timestamp, prefix=b"key") == expected @pytest.mark.parametrize( ["set_timestamp", "get_timestamp", "expected"], @@ -56,7 +56,7 @@ def test_get_last_from_cache( pytest.param(10, 9, None, id="set_timestamp_greater_than_get_timestamp"), ], ) - def test_get_last_from_store( + def test_get_latest_from_store( self, transaction: TimestampedPartitionTransaction, set_timestamp: int, @@ -67,7 +67,7 @@ def test_get_last_from_store( tx.set_for_timestamp(timestamp=set_timestamp, value="value", prefix=b"key") with transaction() as tx: - assert tx.get_last(timestamp=get_timestamp, prefix=b"key") == expected + assert tx.get_latest(timestamp=get_timestamp, prefix=b"key") == expected @pytest.mark.parametrize( ["set_timestamp_stored", "set_timestamp_cached", "get_timestamp", "expected"], @@ -76,7 +76,7 @@ def test_get_last_from_store( pytest.param(2, 3, 5, "cached", id="cached-greater-than-stored"), ], ) - def test_get_last_returns_value_for_greater_timestamp( + def test_get_latest_returns_value_for_greater_timestamp( self, transaction: TimestampedPartitionTransaction, set_timestamp_stored: int, @@ -93,17 +93,17 @@ def test_get_last_returns_value_for_greater_timestamp( tx.set_for_timestamp( timestamp=set_timestamp_cached, value="cached", prefix=b"key" ) 
- assert tx.get_last(timestamp=get_timestamp, prefix=b"key") == expected + assert tx.get_latest(timestamp=get_timestamp, prefix=b"key") == expected - def test_get_last_prefix_not_bytes( + def test_get_latest_prefix_not_bytes( self, transaction: TimestampedPartitionTransaction ): with transaction() as tx: tx.set_for_timestamp(timestamp=10, value="value", prefix="key") - assert tx.get_last(timestamp=10, prefix="key") == "value" - assert tx.get_last(timestamp=10, prefix=b'"key"') == "value" + assert tx.get_latest(timestamp=10, prefix="key") == "value" + assert tx.get_latest(timestamp=10, prefix=b'"key"') == "value" - def test_get_last_for_out_of_order_timestamp( + def test_get_latest_for_out_of_order_timestamp( self, transaction: TimestampedPartitionTransaction, ): @@ -111,7 +111,7 @@ def test_get_last_for_out_of_order_timestamp( tx.set_for_timestamp( timestamp=10, value="value10", prefix=b"key", retention_ms=5 ) - assert tx.get_last(timestamp=10, prefix=b"key") == "value10" + assert tx.get_latest(timestamp=10, prefix=b"key") == "value10" tx.set_for_timestamp( timestamp=5, value="value5", prefix=b"key", retention_ms=5 ) @@ -120,10 +120,10 @@ def test_get_last_for_out_of_order_timestamp( ) with transaction() as tx: - assert tx.get_last(timestamp=5, prefix=b"key") == "value5" + assert tx.get_latest(timestamp=5, prefix=b"key") == "value5" # Retention watermark is 10 - 5 = 5 so everything lower is ignored - assert tx.get_last(timestamp=4, prefix=b"key") is None + assert tx.get_latest(timestamp=4, prefix=b"key") is None def test_set_for_timestamp_with_prefix_not_bytes( self, @@ -131,8 +131,8 @@ def test_set_for_timestamp_with_prefix_not_bytes( ): with transaction() as tx: tx.set_for_timestamp(timestamp=10, value="value", prefix="key") - assert tx.get_last(timestamp=10, prefix="key") == "value" - assert tx.get_last(timestamp=10, prefix=b'"key"') == "value" + assert tx.get_latest(timestamp=10, prefix="key") == "value" + assert tx.get_latest(timestamp=10, prefix=b'"key"') == "value" def test_set_for_timestamp_with_retention_cached( self, @@ -141,8 +141,8 @@ def test_set_for_timestamp_with_retention_cached( with transaction() as tx: tx.set_for_timestamp(timestamp=2, value="v2", prefix=b"key", retention_ms=2) tx.set_for_timestamp(timestamp=5, value="v5", prefix=b"key", retention_ms=2) - assert tx.get_last(timestamp=2, prefix=b"key") is None - assert tx.get_last(timestamp=5, prefix=b"key") == "v5" + assert tx.get_latest(timestamp=2, prefix=b"key") is None + assert tx.get_latest(timestamp=5, prefix=b"key") == "v5" def test_set_for_timestamp_with_retention_stored( self, @@ -153,8 +153,8 @@ def test_set_for_timestamp_with_retention_stored( tx.set_for_timestamp(timestamp=5, value="v5", prefix=b"key", retention_ms=2) with transaction() as tx: - assert tx.get_last(timestamp=2, prefix=b"key") is None - assert tx.get_last(timestamp=5, prefix=b"key") == "v5" + assert tx.get_latest(timestamp=2, prefix=b"key") is None + assert tx.get_latest(timestamp=5, prefix=b"key") == "v5" def test_expire_multiple_keys(self, transaction: TimestampedPartitionTransaction): with transaction() as tx: @@ -178,8 +178,8 @@ def test_expire_multiple_keys(self, transaction: TimestampedPartitionTransaction assert tx.get(key=12, prefix=b"key2") == "212" # Expiration advances only on `set_for_timestamp` calls - assert tx.get_last(timestamp=30, prefix=b"key1") == "112" - assert tx.get_last(timestamp=30, prefix=b"key2") == "212" + assert tx.get_latest(timestamp=30, prefix=b"key1") == "112" + assert tx.get_latest(timestamp=30, prefix=b"key2") 
== "212" def test_set_for_timestamp_overwrites_value_with_same_timestamp( self, @@ -188,7 +188,7 @@ def test_set_for_timestamp_overwrites_value_with_same_timestamp( with transaction() as tx: tx.set_for_timestamp(timestamp=1, value="11", prefix=b"key") tx.set_for_timestamp(timestamp=1, value="21", prefix=b"key") - assert tx.get_last(timestamp=1, prefix=b"key") == "21" + assert tx.get_latest(timestamp=1, prefix=b"key") == "21" with transaction() as tx: - assert tx.get_last(timestamp=1, prefix=b"key") == "21" + assert tx.get_latest(timestamp=1, prefix=b"key") == "21" diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_serialization.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_serialization.py index 73a5291f2..2d6007d19 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_serialization.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_serialization.py @@ -1,29 +1,8 @@ import pytest from quixstreams.state.metadata import SEPARATOR -from quixstreams.state.rocksdb.windowed.serialization import ( - append_integer, - encode_integer_pair, - parse_window_key, -) - - -@pytest.mark.parametrize( - "start, end", - [ - (0, 0), - (1, 2), - (-1, 2), - (2, -2), - ], -) -def test_encode_integer_pair(start, end): - key = encode_integer_pair(start, end) - assert isinstance(key, bytes) - - prefix, decoded_start, decoded_end = parse_window_key(key) - assert decoded_start == start - assert decoded_end == end +from quixstreams.state.rocksdb.windowed.serialization import append_integer +from quixstreams.state.serialization import encode_integer_pair @pytest.mark.parametrize("base_bytes", [b"", b"base_bytes"]) diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py index a44ce8ae9..5217b5961 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py @@ -3,7 +3,7 @@ import pytest from quixstreams.state.rocksdb.windowed.metadata import VALUES_CF_NAME -from quixstreams.state.rocksdb.windowed.serialization import encode_integer_pair +from quixstreams.state.serialization import encode_integer_pair @pytest.fixture diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py index a58906367..6808b0fef 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py @@ -4,7 +4,7 @@ CHANGELOG_CF_MESSAGE_HEADER, CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, ) -from quixstreams.state.rocksdb.windowed.serialization import encode_integer_pair +from quixstreams.state.serialization import encode_integer_pair from quixstreams.utils.json import dumps diff --git a/tests/test_quixstreams/test_state/test_serialization.py b/tests/test_quixstreams/test_state/test_serialization.py new file mode 100644 index 000000000..f8162f741 --- /dev/null +++ b/tests/test_quixstreams/test_state/test_serialization.py @@ -0,0 +1,23 @@ +import pytest + +from quixstreams.state.serialization import ( + decode_integer_pair, + encode_integer_pair, +) + + +@pytest.mark.parametrize( + "start, end", + [ + (0, 0), + (1, 18446744073709551615), + ], +) +def test_encode_integer_pair(start, end): + # This test also covers 
decode_integer_pair function + key = encode_integer_pair(start, end) + assert isinstance(key, bytes) + + decoded_start, decoded_end = decode_integer_pair(key) + assert decoded_start == start + assert decoded_end == end