Commit 1ae852b

Optimize group by for single partition topics
Group by operations on topics with a single partition are now optimized to avoid creating a repartition topic. Instead, the messages are directly transformed to use the new key, as all messages go to the same partition.
1 parent e0f7410 commit 1ae852b
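In practice the change is transparent to callers of `StreamingDataFrame.group_by()`: the same call simply skips creating a repartition topic when the input topic has a single partition. A minimal usage sketch, assuming a local broker; the topic and column names are hypothetical, and the single-partition config mirrors the tests below:

    from quixstreams import Application
    from quixstreams.models import TopicConfig

    app = Application(broker_address="localhost:9092")

    # A hypothetical input topic declared with a single partition.
    orders = app.topic(
        "orders",
        value_deserializer="json",
        config=TopicConfig(num_partitions=1, replication_factor=1),
    )

    sdf = app.dataframe(topic=orders)
    # With this commit, group_by() on a single-partition topic re-keys messages
    # in place instead of producing them to a repartition topic first.
    sdf = sdf.group_by("user_id")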

File tree: 4 files changed, +281 / -106 lines


quixstreams/dataframe/dataframe.py

Lines changed: 38 additions & 3 deletions
@@ -92,7 +92,7 @@ class StreamingDataFrame:
     What it Does:

     - Builds a data processing pipeline, declaratively (not executed immediately)
-    - Executes this pipeline on inputs at runtime (Kafka message values)
+    - Executes this pipeline on inputs at runtime (Kafka message values)
     - Provides functions/interface similar to Pandas Dataframes/Series
     - Enables stateful processing (and manages everything related to it)

@@ -135,15 +135,20 @@ def __init__(
         registry: DataFrameRegistry,
         processing_context: ProcessingContext,
         stream: Optional[Stream] = None,
+        stream_id: Optional[str] = None,
     ):
         if not topics:
             raise ValueError("At least one Topic must be passed")

-        self._stream: Stream = stream or Stream()
         # Implicitly deduplicate Topic objects into a tuple and sort them by name
         self._topics: tuple[Topic, ...] = tuple(
             sorted({t.name: t for t in topics}.values(), key=attrgetter("name"))
         )
+
+        self._stream: Stream = stream or Stream()
+        self._stream_id: str = stream_id or topic_manager.stream_id_from_topics(
+            self.topics
+        )
         self._topic_manager = topic_manager
         self._registry = registry
         self._processing_context = processing_context
@@ -174,7 +179,7 @@ def stream_id(self) -> str:

         By default, a topic name or a combination of topic names are used as `stream_id`.
         """
-        return self._topic_manager.stream_id_from_topics(self.topics)
+        return self._stream_id

     @property
     def topics(self) -> tuple[Topic, ...]:
@@ -591,6 +596,11 @@ def func(d: dict, state: State):
         # Generate a config for the new repartition topic based on the underlying topics
         repartition_config = self._topic_manager.derive_topic_config(self._topics)

+        # If the topic has only one partition, we don't need a repartition topic;
+        # we can directly change the message keys as they all go to the same partition.
+        if repartition_config.num_partitions == 1:
+            return self._single_partition_groupby(operation, key)
+
         groupby_topic = self._topic_manager.repartition_topic(
             operation=operation,
             stream_id=self.stream_id,
@@ -606,6 +616,29 @@ def func(d: dict, state: State):
         self._registry.register_groupby(source_sdf=self, new_sdf=groupby_sdf)
         return groupby_sdf

+    def _single_partition_groupby(
+        self, operation: str, key: Union[str, Callable[[Any], Any]]
+    ) -> "StreamingDataFrame":
+        if isinstance(key, str):
+
+            def _callback(value, _, timestamp, headers):
+                return value, value[key], timestamp, headers
+        else:
+
+            def _callback(value, _, timestamp, headers):
+                return value, key(value), timestamp, headers
+
+        stream = self.stream.add_transform(_callback, expand=False)
+
+        groupby_sdf = self.__dataframe_clone__(
+            stream=stream, stream_id=f"{self.stream_id}--groupby--{operation}"
+        )
+        self._registry.register_groupby(
+            source_sdf=self, new_sdf=groupby_sdf, register_new_root=False
+        )
+
+        return groupby_sdf
+
     def contains(self, keys: Union[str, list[str]]) -> StreamingSeries:
         """
         Check if keys are present in the Row value.
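The transform callback added above only swaps the key; the value, timestamp, and headers pass through unchanged. A standalone sketch of that behaviour outside the library (the function name and sample record are made up for illustration):

    # Not library code: mimics the shape of _callback above for a string key.
    def rekey(value, key, timestamp, headers, column="user"):
        # Keep the value, replace the key with a column from the value,
        # and pass the timestamp and headers through untouched.
        return value, value[column], timestamp, headers

    record = ({"user": "abc123", "total": 10}, b"old-key", 1700000000000, None)
    print(rekey(*record))
    # ({'user': 'abc123', 'total': 10}, 'abc123', 1700000000000, None)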
@@ -1679,6 +1712,7 @@ def __dataframe_clone__(
         self,
         *topics: Topic,
         stream: Optional[Stream] = None,
+        stream_id: Optional[str] = None,
     ) -> "StreamingDataFrame":
         """
         Clone the StreamingDataFrame with a new `stream`, `topics`,
@@ -1692,6 +1726,7 @@ def __dataframe_clone__(
         clone = self.__class__(
             *(topics or self._topics),
             stream=stream,
+            stream_id=stream_id,
             processing_context=self._processing_context,
             topic_manager=self._topic_manager,
             registry=self._registry,
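Since no repartition topic exists on the optimized path, the cloned SDF takes its identity from the explicit stream_id threaded through `__dataframe_clone__`. The naming format comes from the diff above; the source stream id shown here is hypothetical:

    # Stream id used by the single-partition group_by path (format from the diff above).
    source_stream_id = "orders"   # hypothetical source topic / stream id
    operation = "user_id"         # the group_by column or name= parameter
    print(f"{source_stream_id}--groupby--{operation}")
    # orders--groupby--user_id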

quixstreams/dataframe/registry.py

Lines changed: 17 additions & 4 deletions
@@ -70,7 +70,10 @@ def register_root(
         self._registry[topic.name] = dataframe.stream

     def register_groupby(
-        self, source_sdf: "StreamingDataFrame", new_sdf: "StreamingDataFrame"
+        self,
+        source_sdf: "StreamingDataFrame",
+        new_sdf: "StreamingDataFrame",
+        register_new_root: bool = True,
     ):
         """
         Register a "groupby" SDF, which is one generated with `SDF.group_by()`.
@@ -81,16 +84,26 @@ def register_groupby(
             raise GroupByNestingLimit(
                 "Subsequent (nested) `SDF.group_by()` operations are not allowed."
             )
-        try:
-            self.register_root(new_sdf)
-        except StreamingDataFrameDuplicate:
+
+        if new_sdf.stream_id in self._repartition_origins:
             raise GroupByDuplicate(
                 "A `SDF.group_by()` operation appears to be the same as another, "
                 "either from using the same column or name parameter; "
                 "adjust by setting a unique name with `SDF.group_by(name=<NAME>)` "
             )
+
         self._repartition_origins.add(new_sdf.stream_id)

+        if register_new_root:
+            try:
+                self.register_root(new_sdf)
+            except StreamingDataFrameDuplicate:
+                raise GroupByDuplicate(
+                    "A `SDF.group_by()` operation appears to be the same as another, "
+                    "either from using the same column or name parameter; "
+                    "adjust by setting a unique name with `SDF.group_by(name=<NAME>)` "
+                )
+
     def compose_all(
         self, sink: Optional[VoidExecutor] = None
     ) -> dict[str, VoidExecutor]:
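With `register_new_root=False`, the optimized path never registers the cloned SDF as a new root stream, so duplicate detection now relies on the set of repartition origins rather than on root registration. A condensed sketch of that guard, using simplified stand-in names rather than the real DataFrameRegistry:

    # Simplified stand-in for the duplicate-groupby guard, not the real registry.
    class GroupByDuplicate(Exception):
        pass

    class MiniRegistry:
        def __init__(self):
            self._repartition_origins: set[str] = set()

        def register_groupby(self, stream_id: str, register_new_root: bool = True):
            if stream_id in self._repartition_origins:
                raise GroupByDuplicate(f"duplicate group_by for {stream_id!r}")
            self._repartition_origins.add(stream_id)
            if register_new_root:
                pass  # would also register the new SDF as a root stream here

    registry = MiniRegistry()
    registry.register_groupby("orders--groupby--user_id", register_new_root=False)
    # A second identical group_by would raise GroupByDuplicate.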

tests/test_quixstreams/test_app.py

Lines changed: 37 additions & 10 deletions
@@ -513,12 +513,14 @@ def test_state_dir_env(self):
         assert app.config.state_dir == Path("/path/to/other")


+@pytest.mark.parametrize("number_of_partitions", [1, 2])
 class TestAppGroupBy:
     def test_group_by(
         self,
         app_factory,
         internal_consumer_factory,
         executor,
+        number_of_partitions,
     ):
         """
         Test that StreamingDataFrame processes 6 messages from Kafka and groups them
@@ -540,8 +542,12 @@ def on_message_processed(*_):
         timestamp_ms = int(time.time() * 1000)
         user_id = "abc123"
         value_in = {"user": user_id}
-        expected_message_count = 1
-        total_messages = expected_message_count * 2  # groupby reproduces each message
+
+        if number_of_partitions == 1:
+            total_messages = 1  # groupby optimisation for 1 partition
+        else:
+            total_messages = 2  # groupby reproduces each message
+
         app = app_factory(
             auto_offset_reset="earliest",
             on_message_processed=on_message_processed,
@@ -551,6 +557,9 @@ def on_message_processed(*_):
             str(uuid.uuid4()),
             value_deserializer="json",
             value_serializer="json",
+            config=TopicConfig(
+                num_partitions=number_of_partitions, replication_factor=1
+            ),
         )
         app_topic_out = app.topic(
             str(uuid.uuid4()),
@@ -607,6 +616,7 @@ def test_group_by_with_window(
         internal_consumer_factory,
         executor,
         processing_guarantee,
+        number_of_partitions,
     ):
         """
         Test that StreamingDataFrame processes 6 messages from Kafka and groups them
@@ -631,8 +641,12 @@ def on_message_processed(*_):
         timestamp_ms = timestamp_ms - (timestamp_ms % window_duration_ms)
         user_id = "abc123"
         value_in = {"user": user_id}
-        expected_message_count = 1
-        total_messages = expected_message_count * 2  # groupby reproduces each message
+
+        if number_of_partitions == 1:
+            total_messages = 1  # groupby optimisation for 1 partition
+        else:
+            total_messages = 2  # groupby reproduces each message
+
         app = app_factory(
             auto_offset_reset="earliest",
             on_message_processed=on_message_processed,
@@ -643,6 +657,9 @@ def on_message_processed(*_):
             str(uuid.uuid4()),
             value_deserializer="json",
             value_serializer="json",
+            config=TopicConfig(
+                num_partitions=number_of_partitions, replication_factor=1
+            ),
         )
         app_topic_out = app.topic(
             str(uuid.uuid4()),
@@ -2380,11 +2397,9 @@ def on_message_processed(topic_, partition, offset):
         assert row.timestamp == timestamp_ms
         assert row.headers == headers

+    @pytest.mark.parametrize("number_of_partitions", [1, 2])
     def test_group_by(
-        self,
-        app_factory,
-        internal_consumer_factory,
-        executor,
+        self, app_factory, internal_consumer_factory, executor, number_of_partitions
     ):
         """
         Test that StreamingDataFrame processes 6 messages from Kafka and groups them
@@ -2411,11 +2426,17 @@ def on_message_processed(*_):
             str(uuid.uuid4()),
             value_deserializer="json",
             value_serializer="json",
+            config=TopicConfig(
+                num_partitions=number_of_partitions, replication_factor=1
+            ),
         )
         input_topic_b = app.topic(
             str(uuid.uuid4()),
             value_deserializer="json",
             value_serializer="json",
+            config=TopicConfig(
+                num_partitions=number_of_partitions, replication_factor=1
+            ),
         )
         input_topics = [input_topic_a, input_topic_b]
         output_topic_user = app.topic(
@@ -2433,8 +2454,14 @@ def on_message_processed(*_):
         user_id = "abc123"
         account_id = "def456"
         value_in = {"user": user_id, "account": account_id}
-        # expected_processed = 1 (input msg per SDF) * 3 (2 groupbys, each reprocesses input) * 2 SDFs
-        expected_processed = 6
+
+        if number_of_partitions == 1:
+            # expected_processed = 1 (input msg per SDF) * 1 (2 optimized groupbys that don't reprocess input) * 2 SDFs
+            expected_processed = 2
+        else:
+            # expected_processed = 1 (input msg per SDF) * 3 (2 groupbys, each reprocesses input) * 2 SDFs
+            expected_processed = 6
+
         expected_output_topic_count = 2

         sdf_a = app.dataframe(topic=input_topic_a)
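The expected_processed values follow from how many times each input message is consumed; a small sanity check of that arithmetic, mirroring the comments in the test rather than adding to the suite:

    # One input message per SDF, two SDFs; each of the two group_by calls
    # re-consumes the message only when a repartition topic is involved
    # (i.e. more than one partition): 1 input read + 2 repartition reads.
    def expected_processed(num_partitions: int, msgs_per_sdf: int = 1, sdfs: int = 2) -> int:
        passes = 1 if num_partitions == 1 else 3
        return msgs_per_sdf * passes * sdfs

    assert expected_processed(1) == 2
    assert expected_processed(2) == 6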
