-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathwiznet5k_dns.py
252 lines (212 loc) · 7.7 KB
/
wiznet5k_dns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# SPDX-FileCopyrightText: 2009-2010 MCQN Ltd
# SPDX-FileCopyrightText: Brent Rubell for Adafruit Industries
# SPDX-FileCopyrightText: 2021 Vincenzo D'Angelo
#
# SPDX-License-Identifier: MIT
"""
`wiznet5k_dns`
================================================================================
Pure-Python implementation of the Micropython on ESP32 DNS client for WIZnet 5k-based
ethernet modules.
* Author(s): MCQN Ltd, Brent Rubell, Vincenzo D'Angelo
"""
import time
import wiznet5k_socket as socket
from wiznet5k_socket import htons
from random import getrandbits
from micropython import const
QUERY_FLAG = const(0x00)
OPCODE_STANDARD_QUERY = const(0x00)
RECURSION_DESIRED_FLAG = 1 << 8
TYPE_A = const(0x0001)
CLASS_IN = const(0x0001)
DATA_LEN = const(0x0004)
# Return codes for gethostbyname
SUCCESS = const(1)
TIMED_OUT = const(-1)
INVALID_SERVER = const(-2)
TRUNCATED = const(-3)
INVALID_RESPONSE = const(-4)
DNS_PORT = const(0x35) # port used for DNS request
class DNS:
"""W5K DNS implementation.
:param iface: Network interface
"""
def __init__(self, iface, dns_address, debug=False):
self._debug = debug
self._iface = iface
socket.set_interface(iface)
self._sock = socket.socket(type=socket.SOCK_DGRAM)
self._sock.settimeout(1)
self._dns_server = dns_address
self._host = 0
self._request_id = 0 # request identifier
self._pkt_buf = bytearray()
def gethostbyname(self, hostname):
"""Translate a host name to IPv4 address format.
:param str hostname: Desired host name to connect to.
Returns the IPv4 address as a bytearray if successful, -1 otherwise.
"""
if self._dns_server is None:
return INVALID_SERVER
self._host = hostname
# build DNS request packet
self._build_dns_header()
self._build_dns_question()
# Send DNS request packet
self._sock.bind((None, DNS_PORT))
self._sock.connect((self._dns_server, DNS_PORT))
if self._debug:
print("* DNS: Sending request packet...")
self._sock.send(self._pkt_buf)
# wait and retry 3 times for a response
retries = 0
addr = -1
while (retries < 5) and (addr == -1):
addr = self._parse_dns_response()
if addr == -1 and self._debug:
print("* DNS ERROR: Failed to resolve DNS response, retrying...")
retries += 1
self._sock.close()
return addr
# pylint: disable=too-many-return-statements, too-many-branches, too-many-statements, too-many-locals
def _parse_dns_response(self):
"""Receives and parses DNS query response.
Returns desired hostname address if obtained, -1 otherwise.
"""
# wait for a response
start_time = time.time()
packet_sz = self._sock.available()
while packet_sz <= 0:
packet_sz = self._sock.available()
if (time.time() - start_time) > 1.0:
if self._debug:
print("* DNS ERROR: Did not receive DNS response!")
return -1
time.sleep(0.05)
# recv packet into buf
self._pkt_buf = self._sock.recv()
if self._debug:
print("DNS Packet Received: ", self._pkt_buf)
# Validate request identifier
xid = int.from_bytes(self._pkt_buf[0:2], "l")
if not xid == self._request_id:
if self._debug:
print("* DNS ERROR: Received request identifer {} does not match expected {}".format(xid, self._request_id))
return -1
# Validate flags
flags = int.from_bytes(self._pkt_buf[2:4], "l")
if not flags in (0x8180, 0x8580):
if self._debug:
print("* DNS ERROR: Invalid flags, ", flags)
return -1
# Number of questions
qr_count = int.from_bytes(self._pkt_buf[4:6], "l")
if not qr_count >= 1:
if self._debug:
print("* DNS ERROR: Question count >=1, ", qr_count)
return -1
# Number of answers
an_count = int.from_bytes(self._pkt_buf[6:8], "l")
if self._debug:
print("* DNS Answer Count: ", an_count)
if not an_count >= 1:
return -1
# Parse query
ptr = 12
name_len = 1
while name_len > 0:
# read the length of the name
name_len = self._pkt_buf[ptr]
if name_len == 0x00:
# we reached the end of this name
ptr += 1 # inc. pointer by 0x00
break
# advance pointer
ptr += name_len + 1
# Validate Query is Type A
q_type = int.from_bytes(self._pkt_buf[ptr : ptr + 2], "l")
if not q_type == TYPE_A:
if self._debug:
print("* DNS ERROR: Incorrect Query Type: ", q_type)
return -1
ptr += 2
# Validate Query is Type A
q_class = int.from_bytes(self._pkt_buf[ptr : ptr + 2], "l")
if not q_class == TYPE_A:
if self._debug:
print("* DNS ERROR: Incorrect Query Class: ", q_class)
return -1
ptr += 2
# Let's take the first type-a answer
if self._pkt_buf[ptr] != 0xC0:
return -1
ptr += 1
if self._pkt_buf[ptr] != 0xC:
return -1
ptr += 1
# Validate Answer Type A
ans_type = int.from_bytes(self._pkt_buf[ptr : ptr + 2], "l")
if not ans_type == TYPE_A:
if self._debug:
print("* DNS ERROR: Incorrect Answer Type: ", ans_type)
return -1
ptr += 2
# Validate Answer Class IN
ans_class = int.from_bytes(self._pkt_buf[ptr : ptr + 2], "l")
if not ans_class == TYPE_A:
if self._debug:
print("* DNS ERROR: Incorrect Answer Class: ", ans_class)
return -1
ptr += 2
# skip over TTL
ptr += 4
# Validate addr is IPv4
data_len = int.from_bytes(self._pkt_buf[ptr : ptr + 2], "l")
if not data_len == DATA_LEN:
if self._debug:
print("* DNS ERROR: Unexpected Data Length: ", data_len)
return -1
ptr += 2
# Return address
return self._pkt_buf[ptr : ptr + 4]
def _build_dns_header(self):
"""Builds DNS header."""
# generate a random, 16-bit, request identifier
self._request_id = getrandbits(16)
# ID, 16-bit identifier
self._pkt_buf.append(self._request_id >> 8)
self._pkt_buf.append(self._request_id & 0xFF)
# Flags (0x0100)
self._pkt_buf.append(0x01)
self._pkt_buf.append(0x00)
# QDCOUNT
self._pkt_buf.append(0x00)
self._pkt_buf.append(0x01)
# ANCOUNT
self._pkt_buf.append(0x00)
self._pkt_buf.append(0x00)
# NSCOUNT
self._pkt_buf.append(0x00)
self._pkt_buf.append(0x00)
# ARCOUNT
self._pkt_buf.append(0x00)
self._pkt_buf.append(0x00)
def _build_dns_question(self):
"""Build DNS question"""
host = self._host.decode("utf-8")
host = host.split(".")
# write out each section of host
for i, _ in enumerate(host):
# append the sz of the section
self._pkt_buf.append(len(host[i]))
# append the section data
self._pkt_buf += host[i]
# end of the name
self._pkt_buf.append(0x00)
# Type A record
self._pkt_buf.append(htons(TYPE_A) & 0xFF)
self._pkt_buf.append(htons(TYPE_A) >> 8)
# Class IN
self._pkt_buf.append(htons(CLASS_IN) & 0xFF)
self._pkt_buf.append(htons(CLASS_IN) >> 8)