Skip to content

Commit 32d34ef

Browse files
committed
internal: include clientID in auth style cache key
Fixes #654 Change-Id: I735891f2a77c3797662b2eadab7e7828ff14bf5f Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/666915 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Junyang Shao <shaojunyang@google.com> Reviewed-by: Matt Hickford <matt.hickford@gmail.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
1 parent 2d34e30 commit 32d34ef

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

internal/token.go

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
105105
return nil
106106
}
107107

108-
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
109-
//
110-
// Deprecated: this function no longer does anything. Caller code that
111-
// wants to avoid potential extra HTTP requests made during
112-
// auto-probing of the provider's auth style should set
113-
// Endpoint.AuthStyle.
114-
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
115-
116108
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
117109
type AuthStyle int
118110

@@ -149,33 +141,38 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
149141
return c
150142
}
151143

144+
type authStyleCacheKey struct {
145+
url string
146+
clientID string
147+
}
148+
152149
// AuthStyleCache is the set of tokenURLs we've successfully used via
153150
// RetrieveToken and which style auth we ended up using.
154151
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
155152
// the set of OAuth2 servers a program contacts over time is fixed and
156153
// small.
157154
type AuthStyleCache struct {
158155
mu sync.Mutex
159-
m map[string]AuthStyle // keyed by tokenURL
156+
m map[authStyleCacheKey]AuthStyle
160157
}
161158

162159
// lookupAuthStyle reports which auth style we last used with tokenURL
163160
// when calling RetrieveToken and whether we have ever done so.
164-
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
161+
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) {
165162
c.mu.Lock()
166163
defer c.mu.Unlock()
167-
style, ok = c.m[tokenURL]
164+
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}]
168165
return
169166
}
170167

171168
// setAuthStyle adds an entry to authStyleCache, documented above.
172-
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
169+
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) {
173170
c.mu.Lock()
174171
defer c.mu.Unlock()
175172
if c.m == nil {
176-
c.m = make(map[string]AuthStyle)
173+
c.m = make(map[authStyleCacheKey]AuthStyle)
177174
}
178-
c.m[tokenURL] = v
175+
c.m[authStyleCacheKey{tokenURL, clientID}] = v
179176
}
180177

181178
// newTokenRequest returns a new *http.Request to retrieve a new token
@@ -218,7 +215,7 @@ func cloneURLValues(v url.Values) url.Values {
218215
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
219216
needsAuthStyleProbe := authStyle == AuthStyleUnknown
220217
if needsAuthStyleProbe {
221-
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
218+
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok {
222219
authStyle = style
223220
needsAuthStyleProbe = false
224221
} else {
@@ -248,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
248245
token, err = doTokenRoundTrip(ctx, req)
249246
}
250247
if needsAuthStyleProbe && err == nil {
251-
styleCache.setAuthStyle(tokenURL, authStyle)
248+
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
252249
}
253250
// Don't overwrite `RefreshToken` with an empty value
254251
// if this was a token refreshing request.

internal/token_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) {
7575
t.Errorf("expiration time = %v; want %v", e, want)
7676
}
7777
}
78+
79+
func TestAuthStyleCache(t *testing.T) {
80+
var c LazyAuthStyleCache
81+
82+
cases := []struct {
83+
url string
84+
clientID string
85+
style AuthStyle
86+
}{
87+
{
88+
"https://host1.example.com/token",
89+
"client_1",
90+
AuthStyleInHeader,
91+
}, {
92+
"https://host2.example.com/token",
93+
"client_2",
94+
AuthStyleInParams,
95+
}, {
96+
"https://host1.example.com/token",
97+
"client_3",
98+
AuthStyleInParams,
99+
},
100+
}
101+
102+
for _, tt := range cases {
103+
t.Run(tt.clientID, func(t *testing.T) {
104+
cc := c.Get()
105+
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
106+
if ok {
107+
t.Fatalf("unexpected auth style found on first request: %v", got)
108+
}
109+
110+
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
111+
112+
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
113+
if !ok {
114+
t.Fatalf("auth style not found in cache")
115+
}
116+
117+
if got != tt.style {
118+
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
119+
}
120+
})
121+
}
122+
}

0 commit comments

Comments
 (0)