Skip to content

Commit 4fb8006

Browse files
committed
[kafka] feat: Adding end_time functionality to kafka_consumer
This commit adds end_time functionality to the kafka_consumer function, which makes it more batch-processing friendly, as it allows the user to achieve idempotency
1 parent cd2f2c1 commit 4fb8006

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

sources/kafka/__init__.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def kafka_consumer(
3535
batch_size: Optional[int] = 3000,
3636
batch_timeout: Optional[int] = 3,
3737
start_from: Optional[TAnyDateTime] = None,
38+
end_time: Optional[TAnyDateTime] = None,
3839
) -> Iterable[TDataItem]:
3940
"""Extract recent messages from the given Kafka topics.
4041
@@ -78,7 +79,18 @@ def kafka_consumer(
7879
if start_from is not None:
7980
start_from = ensure_pendulum_datetime(start_from)
8081

81-
tracker = OffsetTracker(consumer, topics, dlt.current.resource_state(), start_from)
82+
if end_time is not None:
83+
end_time = ensure_pendulum_datetime(end_time)
84+
85+
if end_time is not None and start_from is None:
86+
raise ValueError("`start_from` must be provided if `end_time` is provided")
87+
elif end_time is not None and start_from is not None:
88+
if start_from > end_time:
89+
raise ValueError("`start_from` must be before `end_time`")
90+
91+
tracker = OffsetTracker(
92+
consumer, topics, dlt.current.resource_state(), start_from, end_time
93+
)
8294

8395
# read messages up to the maximum offsets,
8496
# not waiting for new messages
@@ -97,7 +109,19 @@ def kafka_consumer(
97109
else:
98110
raise err
99111
else:
100-
batch.append(msg_processor(msg))
101-
tracker.renew(msg)
112+
topic = msg.topic()
113+
partition = str(msg.partition())
114+
current_offset = msg.offset()
115+
max_offset = tracker[topic][partition]["max"]
116+
117+
# Only process the message if it's within the partition's max offset
118+
if current_offset < max_offset:
119+
batch.append(msg_processor(msg))
120+
tracker.renew(msg)
121+
else:
122+
logger.info(
123+
f"Skipping message on {topic} partition {partition} at offset {current_offset} "
124+
f"- beyond max offset {max_offset}"
125+
)
102126

103127
yield batch

sources/kafka/helpers.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, List
22

33
from confluent_kafka import Consumer, Message, TopicPartition # type: ignore
4-
from confluent_kafka.admin import AdminClient, TopicMetadata # type: ignore
4+
from confluent_kafka.admin import TopicMetadata # type: ignore
55

66
from dlt import config, secrets
77
from dlt.common import pendulum
@@ -54,15 +54,17 @@ def default_msg_processor(msg: Message) -> Dict[str, Any]:
5454
class OffsetTracker(dict): # type: ignore
5555
"""Object to control offsets of the given topics.
5656
57-
Tracks all the partitions of the given topics with two params:
58-
current offset and maximum offset (partition length).
57+
Tracks all the partitions of the given topics with three params:
58+
current offset, maximum offset (partition length), and an end time.
5959
6060
Args:
6161
consumer (confluent_kafka.Consumer): Kafka consumer.
6262
topic_names (List): Names of topics to track.
6363
pl_state (DictStrAny): Pipeline current state.
6464
start_from (Optional[pendulum.DateTime]): A timestamp, after which messages
6565
are read. Older messages are ignored.
66+
end_time (Optional[pendulum.DateTime]): A timestamp, before which messages
67+
are read. Newer messages are ignored.
6668
"""
6769

6870
def __init__(
@@ -71,6 +73,7 @@ def __init__(
7173
topic_names: List[str],
7274
pl_state: DictStrAny,
7375
start_from: pendulum.DateTime = None,
76+
end_time: pendulum.DateTime = None,
7477
):
7578
super().__init__()
7679

@@ -82,7 +85,7 @@ def __init__(
8285
"offsets", {t_name: {} for t_name in topic_names}
8386
)
8487

85-
self._init_partition_offsets(start_from)
88+
self._init_partition_offsets(start_from, end_time)
8689

8790
def _read_topics(self, topic_names: List[str]) -> Dict[str, TopicMetadata]:
8891
"""Read the given topics metadata from Kafka.
@@ -104,7 +107,9 @@ def _read_topics(self, topic_names: List[str]) -> Dict[str, TopicMetadata]:
104107

105108
return tracked_topics
106109

107-
def _init_partition_offsets(self, start_from: pendulum.DateTime) -> None:
110+
def _init_partition_offsets(
111+
self, start_from: pendulum.DateTime, end_time: pendulum.DateTime
112+
) -> None:
108113
"""Designate current and maximum offsets for every partition.
109114
110115
Current offsets are read from the state, if present. Set equal
@@ -113,6 +118,8 @@ def _init_partition_offsets(self, start_from: pendulum.DateTime) -> None:
113118
Args:
114119
start_from (pendulum.DateTime): A timestamp, at which to start
115120
reading. Older messages are ignored.
121+
end_time (pendulum.DateTime): A timestamp, before which messages
122+
are read. Newer messages are ignored.
116123
"""
117124
all_parts = []
118125
for t_name, topic in self._topics.items():
@@ -128,27 +135,38 @@ def _init_partition_offsets(self, start_from: pendulum.DateTime) -> None:
128135
for part in topic.partitions
129136
]
130137

131-
# get offsets for the timestamp, if given
132-
if start_from is not None:
133-
ts_offsets = self._consumer.offsets_for_times(parts)
138+
# get offsets for the timestamp ranges, if given
139+
if start_from is not None and end_time is not None:
140+
start_ts_offsets = self._consumer.offsets_for_times(parts)
141+
end_ts_offsets = self._consumer.offsets_for_times(
142+
[
143+
TopicPartition(t_name, part, end_time.int_timestamp * 1000)
144+
for part in topic.partitions
145+
]
146+
)
134147

135148
# designate current and maximum offsets for every partition
136149
for i, part in enumerate(parts):
137150
max_offset = self._consumer.get_watermark_offsets(part)[1]
138151

139-
if start_from is not None:
140-
if ts_offsets[i].offset != -1:
141-
cur_offset = ts_offsets[i].offset
152+
if start_from is not None and end_time is not None:
153+
if start_ts_offsets[i].offset != -1:
154+
cur_offset = start_ts_offsets[i].offset
142155
else:
143156
cur_offset = max_offset - 1
157+
if end_ts_offsets[i].offset != -1:
158+
end_offset = end_ts_offsets[i].offset
159+
else:
160+
end_offset = max_offset
144161
else:
145162
cur_offset = (
146163
self._cur_offsets[t_name].get(str(part.partition), -1) + 1
147164
)
165+
end_offset = max_offset
148166

149167
self[t_name][str(part.partition)] = {
150168
"cur": cur_offset,
151-
"max": max_offset,
169+
"max": end_offset,
152170
}
153171

154172
parts[i].offset = cur_offset

0 commit comments

Comments
 (0)