Code refactoring for bpa operator
[icn.git] / cmd / bpa-operator / vendor / k8s.io / client-go / plugin / pkg / client / auth / oidc / oidc.go
1 /*
2 Copyright 2016 The Kubernetes Authors.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8     http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 package oidc
18
19 import (
20         "context"
21         "encoding/base64"
22         "encoding/json"
23         "errors"
24         "fmt"
25         "io/ioutil"
26         "net/http"
27         "strings"
28         "sync"
29         "time"
30
31         "golang.org/x/oauth2"
32         "k8s.io/apimachinery/pkg/util/net"
33         restclient "k8s.io/client-go/rest"
34         "k8s.io/klog"
35 )
36
37 const (
38         cfgIssuerUrl                = "idp-issuer-url"
39         cfgClientID                 = "client-id"
40         cfgClientSecret             = "client-secret"
41         cfgCertificateAuthority     = "idp-certificate-authority"
42         cfgCertificateAuthorityData = "idp-certificate-authority-data"
43         cfgIDToken                  = "id-token"
44         cfgRefreshToken             = "refresh-token"
45
46         // Unused. Scopes aren't sent during refreshing.
47         cfgExtraScopes = "extra-scopes"
48 )
49
50 func init() {
51         if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil {
52                 klog.Fatalf("Failed to register oidc auth plugin: %v", err)
53         }
54 }
55
56 // expiryDelta determines how earlier a token should be considered
57 // expired than its actual expiration time. It is used to avoid late
58 // expirations due to client-server time mismatches.
59 //
60 // NOTE(ericchiang): this is take from golang.org/x/oauth2
61 const expiryDelta = 10 * time.Second
62
63 var cache = newClientCache()
64
65 // Like TLS transports, keep a cache of OIDC clients indexed by issuer URL. This ensures
66 // current requests from different clients don't concurrently attempt to refresh the same
67 // set of credentials.
68 type clientCache struct {
69         mu sync.RWMutex
70
71         cache map[cacheKey]*oidcAuthProvider
72 }
73
74 func newClientCache() *clientCache {
75         return &clientCache{cache: make(map[cacheKey]*oidcAuthProvider)}
76 }
77
78 type cacheKey struct {
79         // Canonical issuer URL string of the provider.
80         issuerURL string
81         clientID  string
82 }
83
84 func (c *clientCache) getClient(issuer, clientID string) (*oidcAuthProvider, bool) {
85         c.mu.RLock()
86         defer c.mu.RUnlock()
87         client, ok := c.cache[cacheKey{issuer, clientID}]
88         return client, ok
89 }
90
91 // setClient attempts to put the client in the cache but may return any clients
92 // with the same keys set before. This is so there's only ever one client for a provider.
93 func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider {
94         c.mu.Lock()
95         defer c.mu.Unlock()
96         key := cacheKey{issuer, clientID}
97
98         // If another client has already initialized a client for the given provider we want
99         // to use that client instead of the one we're trying to set. This is so all transports
100         // share a client and can coordinate around the same mutex when refreshing and writing
101         // to the kubeconfig.
102         if oldClient, ok := c.cache[key]; ok {
103                 return oldClient
104         }
105
106         c.cache[key] = client
107         return client
108 }
109
110 func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
111         issuer := cfg[cfgIssuerUrl]
112         if issuer == "" {
113                 return nil, fmt.Errorf("Must provide %s", cfgIssuerUrl)
114         }
115
116         clientID := cfg[cfgClientID]
117         if clientID == "" {
118                 return nil, fmt.Errorf("Must provide %s", cfgClientID)
119         }
120
121         // Check cache for existing provider.
122         if provider, ok := cache.getClient(issuer, clientID); ok {
123                 return provider, nil
124         }
125
126         if len(cfg[cfgExtraScopes]) > 0 {
127                 klog.V(2).Infof("%s auth provider field depricated, refresh request don't send scopes",
128                         cfgExtraScopes)
129         }
130
131         var certAuthData []byte
132         var err error
133         if cfg[cfgCertificateAuthorityData] != "" {
134                 certAuthData, err = base64.StdEncoding.DecodeString(cfg[cfgCertificateAuthorityData])
135                 if err != nil {
136                         return nil, err
137                 }
138         }
139
140         clientConfig := restclient.Config{
141                 TLSClientConfig: restclient.TLSClientConfig{
142                         CAFile: cfg[cfgCertificateAuthority],
143                         CAData: certAuthData,
144                 },
145         }
146
147         trans, err := restclient.TransportFor(&clientConfig)
148         if err != nil {
149                 return nil, err
150         }
151         hc := &http.Client{Transport: trans}
152
153         provider := &oidcAuthProvider{
154                 client:    hc,
155                 now:       time.Now,
156                 cfg:       cfg,
157                 persister: persister,
158         }
159
160         return cache.setClient(issuer, clientID, provider), nil
161 }
162
163 type oidcAuthProvider struct {
164         client *http.Client
165
166         // Method for determining the current time.
167         now func() time.Time
168
169         // Mutex guards persisting to the kubeconfig file and allows synchronized
170         // updates to the in-memory config. It also ensures concurrent calls to
171         // the RoundTripper only trigger a single refresh request.
172         mu        sync.Mutex
173         cfg       map[string]string
174         persister restclient.AuthProviderConfigPersister
175 }
176
177 func (p *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
178         return &roundTripper{
179                 wrapped:  rt,
180                 provider: p,
181         }
182 }
183
184 func (p *oidcAuthProvider) Login() error {
185         return errors.New("not yet implemented")
186 }
187
188 type roundTripper struct {
189         provider *oidcAuthProvider
190         wrapped  http.RoundTripper
191 }
192
193 var _ net.RoundTripperWrapper = &roundTripper{}
194
195 func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
196         if len(req.Header.Get("Authorization")) != 0 {
197                 return r.wrapped.RoundTrip(req)
198         }
199         token, err := r.provider.idToken()
200         if err != nil {
201                 return nil, err
202         }
203
204         // shallow copy of the struct
205         r2 := new(http.Request)
206         *r2 = *req
207         // deep copy of the Header so we don't modify the original
208         // request's Header (as per RoundTripper contract).
209         r2.Header = make(http.Header)
210         for k, s := range req.Header {
211                 r2.Header[k] = s
212         }
213         r2.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
214
215         return r.wrapped.RoundTrip(r2)
216 }
217
218 func (t *roundTripper) WrappedRoundTripper() http.RoundTripper { return t.wrapped }
219
220 func (p *oidcAuthProvider) idToken() (string, error) {
221         p.mu.Lock()
222         defer p.mu.Unlock()
223
224         if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
225                 valid, err := idTokenExpired(p.now, idToken)
226                 if err != nil {
227                         return "", err
228                 }
229                 if valid {
230                         // If the cached id token is still valid use it.
231                         return idToken, nil
232                 }
233         }
234
235         // Try to request a new token using the refresh token.
236         rt, ok := p.cfg[cfgRefreshToken]
237         if !ok || len(rt) == 0 {
238                 return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
239         }
240
241         // Determine provider's OAuth2 token endpoint.
242         tokenURL, err := tokenEndpoint(p.client, p.cfg[cfgIssuerUrl])
243         if err != nil {
244                 return "", err
245         }
246
247         config := oauth2.Config{
248                 ClientID:     p.cfg[cfgClientID],
249                 ClientSecret: p.cfg[cfgClientSecret],
250                 Endpoint:     oauth2.Endpoint{TokenURL: tokenURL},
251         }
252
253         ctx := context.WithValue(context.Background(), oauth2.HTTPClient, p.client)
254         token, err := config.TokenSource(ctx, &oauth2.Token{RefreshToken: rt}).Token()
255         if err != nil {
256                 return "", fmt.Errorf("failed to refresh token: %v", err)
257         }
258
259         idToken, ok := token.Extra("id_token").(string)
260         if !ok {
261                 // id_token isn't a required part of a refresh token response, so some
262                 // providers (Okta) don't return this value.
263                 //
264                 // See https://github.com/kubernetes/kubernetes/issues/36847
265                 return "", fmt.Errorf("token response did not contain an id_token, either the scope \"openid\" wasn't requested upon login, or the provider doesn't support id_tokens as part of the refresh response.")
266         }
267
268         // Create a new config to persist.
269         newCfg := make(map[string]string)
270         for key, val := range p.cfg {
271                 newCfg[key] = val
272         }
273
274         // Update the refresh token if the server returned another one.
275         if token.RefreshToken != "" && token.RefreshToken != rt {
276                 newCfg[cfgRefreshToken] = token.RefreshToken
277         }
278         newCfg[cfgIDToken] = idToken
279
280         // Persist new config and if successful, update the in memory config.
281         if err = p.persister.Persist(newCfg); err != nil {
282                 return "", fmt.Errorf("could not persist new tokens: %v", err)
283         }
284         p.cfg = newCfg
285
286         return idToken, nil
287 }
288
289 // tokenEndpoint uses OpenID Connect discovery to determine the OAuth2 token
290 // endpoint for the provider, the endpoint the client will use the refresh
291 // token against.
292 func tokenEndpoint(client *http.Client, issuer string) (string, error) {
293         // Well known URL for getting OpenID Connect metadata.
294         //
295         // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
296         wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
297         resp, err := client.Get(wellKnown)
298         if err != nil {
299                 return "", err
300         }
301         defer resp.Body.Close()
302
303         body, err := ioutil.ReadAll(resp.Body)
304         if err != nil {
305                 return "", err
306         }
307         if resp.StatusCode != http.StatusOK {
308                 // Don't produce an error that's too huge (e.g. if we get HTML back for some reason).
309                 const n = 80
310                 if len(body) > n {
311                         body = append(body[:n], []byte("...")...)
312                 }
313                 return "", fmt.Errorf("oidc: failed to query metadata endpoint %s: %q", resp.Status, body)
314         }
315
316         // Metadata object. We only care about the token_endpoint, the thing endpoint
317         // we'll be refreshing against.
318         //
319         // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
320         var metadata struct {
321                 TokenURL string `json:"token_endpoint"`
322         }
323         if err := json.Unmarshal(body, &metadata); err != nil {
324                 return "", fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
325         }
326         if metadata.TokenURL == "" {
327                 return "", fmt.Errorf("oidc: discovery object doesn't contain a token_endpoint")
328         }
329         return metadata.TokenURL, nil
330 }
331
332 func idTokenExpired(now func() time.Time, idToken string) (bool, error) {
333         parts := strings.Split(idToken, ".")
334         if len(parts) != 3 {
335                 return false, fmt.Errorf("ID Token is not a valid JWT")
336         }
337
338         payload, err := base64.RawURLEncoding.DecodeString(parts[1])
339         if err != nil {
340                 return false, err
341         }
342         var claims struct {
343                 Expiry jsonTime `json:"exp"`
344         }
345         if err := json.Unmarshal(payload, &claims); err != nil {
346                 return false, fmt.Errorf("parsing claims: %v", err)
347         }
348
349         return now().Add(expiryDelta).Before(time.Time(claims.Expiry)), nil
350 }
351
352 // jsonTime is a json.Unmarshaler that parses a unix timestamp.
353 // Because JSON numbers don't differentiate between ints and floats,
354 // we want to ensure we can parse either.
355 type jsonTime time.Time
356
357 func (j *jsonTime) UnmarshalJSON(b []byte) error {
358         var n json.Number
359         if err := json.Unmarshal(b, &n); err != nil {
360                 return err
361         }
362         var unix int64
363
364         if t, err := n.Int64(); err == nil {
365                 unix = t
366         } else {
367                 f, err := n.Float64()
368                 if err != nil {
369                         return err
370                 }
371                 unix = int64(f)
372         }
373         *j = jsonTime(time.Unix(unix, 0))
374         return nil
375 }
376
377 func (j jsonTime) MarshalJSON() ([]byte, error) {
378         return json.Marshal(time.Time(j).Unix())
379 }