Skip to content

Commit a8476fb

Browse files
committed
Merge remote-tracking branch 'oauth2cli/dev' into wip
2 parents 3a83efc + b2b00eb commit a8476fb

File tree

3 files changed

+52
-18
lines changed

3 files changed

+52
-18
lines changed

msal/oauth2cli/authcode.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
try: # Python 3
1616
from http.server import HTTPServer, BaseHTTPRequestHandler
1717
from urllib.parse import urlparse, parse_qs, urlencode
18+
from html import escape
1819
except ImportError: # Fall back to Python 2
1920
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
2021
from urlparse import urlparse, parse_qs
2122
from urllib import urlencode
23+
from cgi import escape
2224

2325

2426
logger = logging.getLogger(__name__)
@@ -77,25 +79,42 @@ def _qs2kv(qs):
7779
for k, v in qs.items()}
7880

7981

82+
def _is_html(text):
83+
return text.startswith("<") # Good enough for our purpose
84+
85+
86+
def _escape(key_value_pairs):
87+
return {k: escape(v) for k, v in key_value_pairs.items()}
88+
89+
8090
class _AuthCodeHandler(BaseHTTPRequestHandler):
8191
def do_GET(self):
8292
# For flexibility, we choose to not check self.path matching redirect_uri
8393
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
8494
qs = parse_qs(urlparse(self.path).query)
8595
if qs.get('code') or qs.get("error"): # So, it is an auth response
86-
self.server.auth_response = _qs2kv(qs)
87-
logger.debug("Got auth response: %s", self.server.auth_response)
88-
template = (self.server.success_template
89-
if "code" in qs else self.server.error_template)
90-
self._send_full_response(
91-
template.safe_substitute(**self.server.auth_response))
92-
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
96+
auth_response = _qs2kv(qs)
97+
logger.debug("Got auth response: %s", auth_response)
98+
if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
99+
# OAuth2 successful and error responses contain state when it was used
100+
# https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
101+
self._send_full_response("State mismatch") # Possibly an attack
102+
else:
103+
template = (self.server.success_template
104+
if "code" in qs else self.server.error_template)
105+
if _is_html(template.template):
106+
safe_data = _escape(auth_response) # Foiling an XSS attack
107+
else:
108+
safe_data = auth_response
109+
self._send_full_response(template.safe_substitute(**safe_data))
110+
self.server.auth_response = auth_response # Set it now, after the response is likely sent
93111
else:
94112
self._send_full_response(self.server.welcome_page)
113+
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
95114

96115
def _send_full_response(self, body, is_ok=True):
97116
self.send_response(200 if is_ok else 400)
98-
content_type = 'text/html' if body.startswith('<') else 'text/plain'
117+
content_type = 'text/html' if _is_html(body) else 'text/plain'
99118
self.send_header('Content-type', content_type)
100119
self.end_headers()
101120
self.wfile.write(body.encode("utf-8"))
@@ -281,16 +300,14 @@ def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None,
281300

282301
self._server.timeout = timeout # Otherwise its handle_timeout() won't work
283302
self._server.auth_response = {} # Shared with _AuthCodeHandler
303+
self._server.auth_state = state # So handler will check it before sending response
284304
while not self._closing: # Otherwise, the handle_request() attempt
285305
# would yield noisy ValueError trace
286306
# Derived from
287307
# https://docs.python.org/2/library/basehttpserver.html#more-examples
288308
self._server.handle_request()
289309
if self._server.auth_response:
290-
if state and state != self._server.auth_response.get("state"):
291-
logger.debug("State mismatch. Ignoring this noise.")
292-
else:
293-
break
310+
break
294311
result.update(self._server.auth_response) # Return via writable result param
295312

296313
def close(self):
@@ -318,6 +335,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
318335
default="https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
319336
p.add_argument('client_id', help="The client_id of your application")
320337
p.add_argument('--port', type=int, default=0, help="The port in redirect_uri")
338+
p.add_argument('--timeout', type=int, default=60, help="Timeout value, in second")
321339
p.add_argument('--host', default="127.0.0.1", help="The host of redirect_uri")
322340
p.add_argument('--scope', default=None, help="The scope list")
323341
args = parser.parse_args()
@@ -331,8 +349,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
331349
auth_uri=flow["auth_uri"],
332350
welcome_template=
333351
"<a href='$auth_uri'>Sign In</a>, or <a href='$abort_uri'>Abort</a",
334-
error_template="Oh no. $error",
352+
error_template="<html>Oh no. $error</html>",
335353
success_template="Oh yeah. Got $code",
336-
timeout=60,
354+
timeout=args.timeout,
337355
state=flow["state"], # Optional
338356
), indent=4))

msal/oauth2cli/oauth2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def _obtain_token_by_browser(
666666
**(auth_params or {}))
667667
auth_response = auth_code_receiver.get_auth_response(
668668
auth_uri=flow["auth_uri"],
669-
state=flow["state"], # Optional but we choose to do it upfront
669+
state=flow["state"], # So receiver can check it early
670670
timeout=timeout,
671671
welcome_template=welcome_template,
672672
success_template=success_template,

tests/test_authcode.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import socket
33
import sys
44

5+
import requests
6+
57
from msal.oauth2cli.authcode import AuthCodeReceiver
68

79

@@ -17,10 +19,24 @@ def test_setup_at_a_ephemeral_port_and_teardown(self):
1719
self.assertNotEqual(port, receiver.get_port())
1820

1921
def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
20-
port = 12345 # Assuming this port is available
21-
with AuthCodeReceiver(port=port) as receiver:
22+
with AuthCodeReceiver() as receiver:
2223
expected_error = OSError if sys.version_info[0] > 2 else socket.error
2324
with self.assertRaises(expected_error):
24-
with AuthCodeReceiver(port=port) as receiver2:
25+
with AuthCodeReceiver(port=receiver.get_port()):
2526
pass
2627

28+
def test_template_should_escape_input(self):
29+
with AuthCodeReceiver() as receiver:
30+
receiver._scheduled_actions = [( # Injection happens here when the port is known
31+
1, # Delay it until the receiver is activated by get_auth_response()
32+
lambda: self.assertEqual(
33+
"<html>&lt;tag&gt;foo&lt;/tag&gt;</html>",
34+
requests.get("http://localhost:{}?error=<tag>foo</tag>".format(
35+
receiver.get_port())).text,
36+
"Unsafe data in HTML should be escaped",
37+
))]
38+
receiver.get_auth_response( # Starts server and hang until timeout
39+
timeout=3,
40+
error_template="<html>$error</html>",
41+
)
42+

0 commit comments

Comments
 (0)