Code refactoring for bpa operator
[icn.git] / cmd / bpa-operator / vendor / github.com / Azure / go-autorest / autorest / adal / token.go
1 package adal
2
3 // Copyright 2017 Microsoft Corporation
4 //
5 //  Licensed under the Apache License, Version 2.0 (the "License");
6 //  you may not use this file except in compliance with the License.
7 //  You may obtain a copy of the License at
8 //
9 //      http://www.apache.org/licenses/LICENSE-2.0
10 //
11 //  Unless required by applicable law or agreed to in writing, software
12 //  distributed under the License is distributed on an "AS IS" BASIS,
13 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 //  See the License for the specific language governing permissions and
15 //  limitations under the License.
16
17 import (
18         "context"
19         "crypto/rand"
20         "crypto/rsa"
21         "crypto/sha1"
22         "crypto/x509"
23         "encoding/base64"
24         "encoding/json"
25         "errors"
26         "fmt"
27         "io/ioutil"
28         "math"
29         "net"
30         "net/http"
31         "net/url"
32         "strings"
33         "sync"
34         "time"
35
36         "github.com/Azure/go-autorest/autorest/date"
37         "github.com/Azure/go-autorest/tracing"
38         "github.com/dgrijalva/jwt-go"
39 )
40
41 const (
42         defaultRefresh = 5 * time.Minute
43
44         // OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
45         OAuthGrantTypeDeviceCode = "device_code"
46
47         // OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
48         OAuthGrantTypeClientCredentials = "client_credentials"
49
50         // OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
51         OAuthGrantTypeUserPass = "password"
52
53         // OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
54         OAuthGrantTypeRefreshToken = "refresh_token"
55
56         // OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
57         OAuthGrantTypeAuthorizationCode = "authorization_code"
58
59         // metadataHeader is the header required by MSI extension
60         metadataHeader = "Metadata"
61
62         // msiEndpoint is the well known endpoint for getting MSI authentications tokens
63         msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
64
65         // the default number of attempts to refresh an MSI authentication token
66         defaultMaxMSIRefreshAttempts = 5
67 )
68
69 // OAuthTokenProvider is an interface which should be implemented by an access token retriever
70 type OAuthTokenProvider interface {
71         OAuthToken() string
72 }
73
74 // TokenRefreshError is an interface used by errors returned during token refresh.
75 type TokenRefreshError interface {
76         error
77         Response() *http.Response
78 }
79
80 // Refresher is an interface for token refresh functionality
81 type Refresher interface {
82         Refresh() error
83         RefreshExchange(resource string) error
84         EnsureFresh() error
85 }
86
87 // RefresherWithContext is an interface for token refresh functionality
88 type RefresherWithContext interface {
89         RefreshWithContext(ctx context.Context) error
90         RefreshExchangeWithContext(ctx context.Context, resource string) error
91         EnsureFreshWithContext(ctx context.Context) error
92 }
93
94 // TokenRefreshCallback is the type representing callbacks that will be called after
95 // a successful token refresh
96 type TokenRefreshCallback func(Token) error
97
98 // Token encapsulates the access token used to authorize Azure requests.
99 // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
100 type Token struct {
101         AccessToken  string `json:"access_token"`
102         RefreshToken string `json:"refresh_token"`
103
104         ExpiresIn json.Number `json:"expires_in"`
105         ExpiresOn json.Number `json:"expires_on"`
106         NotBefore json.Number `json:"not_before"`
107
108         Resource string `json:"resource"`
109         Type     string `json:"token_type"`
110 }
111
112 func newToken() Token {
113         return Token{
114                 ExpiresIn: "0",
115                 ExpiresOn: "0",
116                 NotBefore: "0",
117         }
118 }
119
120 // IsZero returns true if the token object is zero-initialized.
121 func (t Token) IsZero() bool {
122         return t == Token{}
123 }
124
125 // Expires returns the time.Time when the Token expires.
126 func (t Token) Expires() time.Time {
127         s, err := t.ExpiresOn.Float64()
128         if err != nil {
129                 s = -3600
130         }
131
132         expiration := date.NewUnixTimeFromSeconds(s)
133
134         return time.Time(expiration).UTC()
135 }
136
137 // IsExpired returns true if the Token is expired, false otherwise.
138 func (t Token) IsExpired() bool {
139         return t.WillExpireIn(0)
140 }
141
142 // WillExpireIn returns true if the Token will expire after the passed time.Duration interval
143 // from now, false otherwise.
144 func (t Token) WillExpireIn(d time.Duration) bool {
145         return !t.Expires().After(time.Now().Add(d))
146 }
147
148 //OAuthToken return the current access token
149 func (t *Token) OAuthToken() string {
150         return t.AccessToken
151 }
152
153 // ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
154 // that is submitted when acquiring an oAuth token.
155 type ServicePrincipalSecret interface {
156         SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
157 }
158
159 // ServicePrincipalNoSecret represents a secret type that contains no secret
160 // meaning it is not valid for fetching a fresh token. This is used by Manual
161 type ServicePrincipalNoSecret struct {
162 }
163
164 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret
165 // It only returns an error for the ServicePrincipalNoSecret type
166 func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
167         return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
168 }
169
170 // MarshalJSON implements the json.Marshaler interface.
171 func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
172         type tokenType struct {
173                 Type string `json:"type"`
174         }
175         return json.Marshal(tokenType{
176                 Type: "ServicePrincipalNoSecret",
177         })
178 }
179
180 // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
181 type ServicePrincipalTokenSecret struct {
182         ClientSecret string `json:"value"`
183 }
184
185 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
186 // It will populate the form submitted during oAuth Token Acquisition using the client_secret.
187 func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
188         v.Set("client_secret", tokenSecret.ClientSecret)
189         return nil
190 }
191
192 // MarshalJSON implements the json.Marshaler interface.
193 func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
194         type tokenType struct {
195                 Type  string `json:"type"`
196                 Value string `json:"value"`
197         }
198         return json.Marshal(tokenType{
199                 Type:  "ServicePrincipalTokenSecret",
200                 Value: tokenSecret.ClientSecret,
201         })
202 }
203
204 // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
205 type ServicePrincipalCertificateSecret struct {
206         Certificate *x509.Certificate
207         PrivateKey  *rsa.PrivateKey
208 }
209
210 // SignJwt returns the JWT signed with the certificate's private key.
211 func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
212         hasher := sha1.New()
213         _, err := hasher.Write(secret.Certificate.Raw)
214         if err != nil {
215                 return "", err
216         }
217
218         thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
219
220         // The jti (JWT ID) claim provides a unique identifier for the JWT.
221         jti := make([]byte, 20)
222         _, err = rand.Read(jti)
223         if err != nil {
224                 return "", err
225         }
226
227         token := jwt.New(jwt.SigningMethodRS256)
228         token.Header["x5t"] = thumbprint
229         x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
230         token.Header["x5c"] = x5c
231         token.Claims = jwt.MapClaims{
232                 "aud": spt.inner.OauthConfig.TokenEndpoint.String(),
233                 "iss": spt.inner.ClientID,
234                 "sub": spt.inner.ClientID,
235                 "jti": base64.URLEncoding.EncodeToString(jti),
236                 "nbf": time.Now().Unix(),
237                 "exp": time.Now().Add(time.Hour * 24).Unix(),
238         }
239
240         signedString, err := token.SignedString(secret.PrivateKey)
241         return signedString, err
242 }
243
244 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
245 // It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
246 func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
247         jwt, err := secret.SignJwt(spt)
248         if err != nil {
249                 return err
250         }
251
252         v.Set("client_assertion", jwt)
253         v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
254         return nil
255 }
256
257 // MarshalJSON implements the json.Marshaler interface.
258 func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
259         return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
260 }
261
262 // ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
263 type ServicePrincipalMSISecret struct {
264 }
265
266 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
267 func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
268         return nil
269 }
270
271 // MarshalJSON implements the json.Marshaler interface.
272 func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
273         return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
274 }
275
276 // ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
277 type ServicePrincipalUsernamePasswordSecret struct {
278         Username string `json:"username"`
279         Password string `json:"password"`
280 }
281
282 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
283 func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
284         v.Set("username", secret.Username)
285         v.Set("password", secret.Password)
286         return nil
287 }
288
289 // MarshalJSON implements the json.Marshaler interface.
290 func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
291         type tokenType struct {
292                 Type     string `json:"type"`
293                 Username string `json:"username"`
294                 Password string `json:"password"`
295         }
296         return json.Marshal(tokenType{
297                 Type:     "ServicePrincipalUsernamePasswordSecret",
298                 Username: secret.Username,
299                 Password: secret.Password,
300         })
301 }
302
303 // ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
304 type ServicePrincipalAuthorizationCodeSecret struct {
305         ClientSecret      string `json:"value"`
306         AuthorizationCode string `json:"authCode"`
307         RedirectURI       string `json:"redirect"`
308 }
309
310 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
311 func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
312         v.Set("code", secret.AuthorizationCode)
313         v.Set("client_secret", secret.ClientSecret)
314         v.Set("redirect_uri", secret.RedirectURI)
315         return nil
316 }
317
318 // MarshalJSON implements the json.Marshaler interface.
319 func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
320         type tokenType struct {
321                 Type     string `json:"type"`
322                 Value    string `json:"value"`
323                 AuthCode string `json:"authCode"`
324                 Redirect string `json:"redirect"`
325         }
326         return json.Marshal(tokenType{
327                 Type:     "ServicePrincipalAuthorizationCodeSecret",
328                 Value:    secret.ClientSecret,
329                 AuthCode: secret.AuthorizationCode,
330                 Redirect: secret.RedirectURI,
331         })
332 }
333
334 // ServicePrincipalToken encapsulates a Token created for a Service Principal.
335 type ServicePrincipalToken struct {
336         inner            servicePrincipalToken
337         refreshLock      *sync.RWMutex
338         sender           Sender
339         refreshCallbacks []TokenRefreshCallback
340         // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
341         MaxMSIRefreshAttempts int
342 }
343
344 // MarshalTokenJSON returns the marshalled inner token.
345 func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
346         return json.Marshal(spt.inner.Token)
347 }
348
349 // SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
350 func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
351         spt.refreshCallbacks = callbacks
352 }
353
354 // MarshalJSON implements the json.Marshaler interface.
355 func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
356         return json.Marshal(spt.inner)
357 }
358
359 // UnmarshalJSON implements the json.Unmarshaler interface.
360 func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
361         // need to determine the token type
362         raw := map[string]interface{}{}
363         err := json.Unmarshal(data, &raw)
364         if err != nil {
365                 return err
366         }
367         secret := raw["secret"].(map[string]interface{})
368         switch secret["type"] {
369         case "ServicePrincipalNoSecret":
370                 spt.inner.Secret = &ServicePrincipalNoSecret{}
371         case "ServicePrincipalTokenSecret":
372                 spt.inner.Secret = &ServicePrincipalTokenSecret{}
373         case "ServicePrincipalCertificateSecret":
374                 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
375         case "ServicePrincipalMSISecret":
376                 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
377         case "ServicePrincipalUsernamePasswordSecret":
378                 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
379         case "ServicePrincipalAuthorizationCodeSecret":
380                 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
381         default:
382                 return fmt.Errorf("unrecognized token type '%s'", secret["type"])
383         }
384         err = json.Unmarshal(data, &spt.inner)
385         if err != nil {
386                 return err
387         }
388         // Don't override the refreshLock or the sender if those have been already set.
389         if spt.refreshLock == nil {
390                 spt.refreshLock = &sync.RWMutex{}
391         }
392         if spt.sender == nil {
393                 spt.sender = &http.Client{Transport: tracing.Transport}
394         }
395         return nil
396 }
397
398 // internal type used for marshalling/unmarshalling
399 type servicePrincipalToken struct {
400         Token         Token                  `json:"token"`
401         Secret        ServicePrincipalSecret `json:"secret"`
402         OauthConfig   OAuthConfig            `json:"oauth"`
403         ClientID      string                 `json:"clientID"`
404         Resource      string                 `json:"resource"`
405         AutoRefresh   bool                   `json:"autoRefresh"`
406         RefreshWithin time.Duration          `json:"refreshWithin"`
407 }
408
409 func validateOAuthConfig(oac OAuthConfig) error {
410         if oac.IsZero() {
411                 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
412         }
413         return nil
414 }
415
416 // NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
417 func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
418         if err := validateOAuthConfig(oauthConfig); err != nil {
419                 return nil, err
420         }
421         if err := validateStringParam(id, "id"); err != nil {
422                 return nil, err
423         }
424         if err := validateStringParam(resource, "resource"); err != nil {
425                 return nil, err
426         }
427         if secret == nil {
428                 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
429         }
430         spt := &ServicePrincipalToken{
431                 inner: servicePrincipalToken{
432                         Token:         newToken(),
433                         OauthConfig:   oauthConfig,
434                         Secret:        secret,
435                         ClientID:      id,
436                         Resource:      resource,
437                         AutoRefresh:   true,
438                         RefreshWithin: defaultRefresh,
439                 },
440                 refreshLock:      &sync.RWMutex{},
441                 sender:           &http.Client{Transport: tracing.Transport},
442                 refreshCallbacks: callbacks,
443         }
444         return spt, nil
445 }
446
447 // NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
448 func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
449         if err := validateOAuthConfig(oauthConfig); err != nil {
450                 return nil, err
451         }
452         if err := validateStringParam(clientID, "clientID"); err != nil {
453                 return nil, err
454         }
455         if err := validateStringParam(resource, "resource"); err != nil {
456                 return nil, err
457         }
458         if token.IsZero() {
459                 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
460         }
461         spt, err := NewServicePrincipalTokenWithSecret(
462                 oauthConfig,
463                 clientID,
464                 resource,
465                 &ServicePrincipalNoSecret{},
466                 callbacks...)
467         if err != nil {
468                 return nil, err
469         }
470
471         spt.inner.Token = token
472
473         return spt, nil
474 }
475
476 // NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
477 func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
478         if err := validateOAuthConfig(oauthConfig); err != nil {
479                 return nil, err
480         }
481         if err := validateStringParam(clientID, "clientID"); err != nil {
482                 return nil, err
483         }
484         if err := validateStringParam(resource, "resource"); err != nil {
485                 return nil, err
486         }
487         if secret == nil {
488                 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
489         }
490         if token.IsZero() {
491                 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
492         }
493         spt, err := NewServicePrincipalTokenWithSecret(
494                 oauthConfig,
495                 clientID,
496                 resource,
497                 secret,
498                 callbacks...)
499         if err != nil {
500                 return nil, err
501         }
502
503         spt.inner.Token = token
504
505         return spt, nil
506 }
507
508 // NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
509 // credentials scoped to the named resource.
510 func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
511         if err := validateOAuthConfig(oauthConfig); err != nil {
512                 return nil, err
513         }
514         if err := validateStringParam(clientID, "clientID"); err != nil {
515                 return nil, err
516         }
517         if err := validateStringParam(secret, "secret"); err != nil {
518                 return nil, err
519         }
520         if err := validateStringParam(resource, "resource"); err != nil {
521                 return nil, err
522         }
523         return NewServicePrincipalTokenWithSecret(
524                 oauthConfig,
525                 clientID,
526                 resource,
527                 &ServicePrincipalTokenSecret{
528                         ClientSecret: secret,
529                 },
530                 callbacks...,
531         )
532 }
533
534 // NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
535 func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
536         if err := validateOAuthConfig(oauthConfig); err != nil {
537                 return nil, err
538         }
539         if err := validateStringParam(clientID, "clientID"); err != nil {
540                 return nil, err
541         }
542         if err := validateStringParam(resource, "resource"); err != nil {
543                 return nil, err
544         }
545         if certificate == nil {
546                 return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
547         }
548         if privateKey == nil {
549                 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
550         }
551         return NewServicePrincipalTokenWithSecret(
552                 oauthConfig,
553                 clientID,
554                 resource,
555                 &ServicePrincipalCertificateSecret{
556                         PrivateKey:  privateKey,
557                         Certificate: certificate,
558                 },
559                 callbacks...,
560         )
561 }
562
563 // NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
564 func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
565         if err := validateOAuthConfig(oauthConfig); err != nil {
566                 return nil, err
567         }
568         if err := validateStringParam(clientID, "clientID"); err != nil {
569                 return nil, err
570         }
571         if err := validateStringParam(username, "username"); err != nil {
572                 return nil, err
573         }
574         if err := validateStringParam(password, "password"); err != nil {
575                 return nil, err
576         }
577         if err := validateStringParam(resource, "resource"); err != nil {
578                 return nil, err
579         }
580         return NewServicePrincipalTokenWithSecret(
581                 oauthConfig,
582                 clientID,
583                 resource,
584                 &ServicePrincipalUsernamePasswordSecret{
585                         Username: username,
586                         Password: password,
587                 },
588                 callbacks...,
589         )
590 }
591
592 // NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
593 func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
594
595         if err := validateOAuthConfig(oauthConfig); err != nil {
596                 return nil, err
597         }
598         if err := validateStringParam(clientID, "clientID"); err != nil {
599                 return nil, err
600         }
601         if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
602                 return nil, err
603         }
604         if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
605                 return nil, err
606         }
607         if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
608                 return nil, err
609         }
610         if err := validateStringParam(resource, "resource"); err != nil {
611                 return nil, err
612         }
613
614         return NewServicePrincipalTokenWithSecret(
615                 oauthConfig,
616                 clientID,
617                 resource,
618                 &ServicePrincipalAuthorizationCodeSecret{
619                         ClientSecret:      clientSecret,
620                         AuthorizationCode: authorizationCode,
621                         RedirectURI:       redirectURI,
622                 },
623                 callbacks...,
624         )
625 }
626
627 // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
628 func GetMSIVMEndpoint() (string, error) {
629         return msiEndpoint, nil
630 }
631
632 // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
633 // It will use the system assigned identity when creating the token.
634 func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
635         return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
636 }
637
638 // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
639 // It will use the specified user assigned identity when creating the token.
640 func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
641         return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
642 }
643
644 func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
645         if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
646                 return nil, err
647         }
648         if err := validateStringParam(resource, "resource"); err != nil {
649                 return nil, err
650         }
651         if userAssignedID != nil {
652                 if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
653                         return nil, err
654                 }
655         }
656         // We set the oauth config token endpoint to be MSI's endpoint
657         msiEndpointURL, err := url.Parse(msiEndpoint)
658         if err != nil {
659                 return nil, err
660         }
661
662         v := url.Values{}
663         v.Set("resource", resource)
664         v.Set("api-version", "2018-02-01")
665         if userAssignedID != nil {
666                 v.Set("client_id", *userAssignedID)
667         }
668         msiEndpointURL.RawQuery = v.Encode()
669
670         spt := &ServicePrincipalToken{
671                 inner: servicePrincipalToken{
672                         Token: newToken(),
673                         OauthConfig: OAuthConfig{
674                                 TokenEndpoint: *msiEndpointURL,
675                         },
676                         Secret:        &ServicePrincipalMSISecret{},
677                         Resource:      resource,
678                         AutoRefresh:   true,
679                         RefreshWithin: defaultRefresh,
680                 },
681                 refreshLock:           &sync.RWMutex{},
682                 sender:                &http.Client{Transport: tracing.Transport},
683                 refreshCallbacks:      callbacks,
684                 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
685         }
686
687         if userAssignedID != nil {
688                 spt.inner.ClientID = *userAssignedID
689         }
690
691         return spt, nil
692 }
693
694 // internal type that implements TokenRefreshError
695 type tokenRefreshError struct {
696         message string
697         resp    *http.Response
698 }
699
700 // Error implements the error interface which is part of the TokenRefreshError interface.
701 func (tre tokenRefreshError) Error() string {
702         return tre.message
703 }
704
705 // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
706 func (tre tokenRefreshError) Response() *http.Response {
707         return tre.resp
708 }
709
710 func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
711         return tokenRefreshError{message: message, resp: resp}
712 }
713
714 // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
715 // RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
716 func (spt *ServicePrincipalToken) EnsureFresh() error {
717         return spt.EnsureFreshWithContext(context.Background())
718 }
719
720 // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
721 // RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
722 func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
723         if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
724                 // take the write lock then check to see if the token was already refreshed
725                 spt.refreshLock.Lock()
726                 defer spt.refreshLock.Unlock()
727                 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
728                         return spt.refreshInternal(ctx, spt.inner.Resource)
729                 }
730         }
731         return nil
732 }
733
734 // InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
735 func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
736         if spt.refreshCallbacks != nil {
737                 for _, callback := range spt.refreshCallbacks {
738                         err := callback(spt.inner.Token)
739                         if err != nil {
740                                 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
741                         }
742                 }
743         }
744         return nil
745 }
746
747 // Refresh obtains a fresh token for the Service Principal.
748 // This method is not safe for concurrent use and should be syncrhonized.
749 func (spt *ServicePrincipalToken) Refresh() error {
750         return spt.RefreshWithContext(context.Background())
751 }
752
753 // RefreshWithContext obtains a fresh token for the Service Principal.
754 // This method is not safe for concurrent use and should be syncrhonized.
755 func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
756         spt.refreshLock.Lock()
757         defer spt.refreshLock.Unlock()
758         return spt.refreshInternal(ctx, spt.inner.Resource)
759 }
760
761 // RefreshExchange refreshes the token, but for a different resource.
762 // This method is not safe for concurrent use and should be syncrhonized.
763 func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
764         return spt.RefreshExchangeWithContext(context.Background(), resource)
765 }
766
767 // RefreshExchangeWithContext refreshes the token, but for a different resource.
768 // This method is not safe for concurrent use and should be syncrhonized.
769 func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
770         spt.refreshLock.Lock()
771         defer spt.refreshLock.Unlock()
772         return spt.refreshInternal(ctx, resource)
773 }
774
775 func (spt *ServicePrincipalToken) getGrantType() string {
776         switch spt.inner.Secret.(type) {
777         case *ServicePrincipalUsernamePasswordSecret:
778                 return OAuthGrantTypeUserPass
779         case *ServicePrincipalAuthorizationCodeSecret:
780                 return OAuthGrantTypeAuthorizationCode
781         default:
782                 return OAuthGrantTypeClientCredentials
783         }
784 }
785
786 func isIMDS(u url.URL) bool {
787         imds, err := url.Parse(msiEndpoint)
788         if err != nil {
789                 return false
790         }
791         return u.Host == imds.Host && u.Path == imds.Path
792 }
793
794 func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
795         req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
796         if err != nil {
797                 return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
798         }
799         req.Header.Add("User-Agent", UserAgent())
800         req = req.WithContext(ctx)
801         if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
802                 v := url.Values{}
803                 v.Set("client_id", spt.inner.ClientID)
804                 v.Set("resource", resource)
805
806                 if spt.inner.Token.RefreshToken != "" {
807                         v.Set("grant_type", OAuthGrantTypeRefreshToken)
808                         v.Set("refresh_token", spt.inner.Token.RefreshToken)
809                         // web apps must specify client_secret when refreshing tokens
810                         // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
811                         if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
812                                 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
813                                 if err != nil {
814                                         return err
815                                 }
816                         }
817                 } else {
818                         v.Set("grant_type", spt.getGrantType())
819                         err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
820                         if err != nil {
821                                 return err
822                         }
823                 }
824
825                 s := v.Encode()
826                 body := ioutil.NopCloser(strings.NewReader(s))
827                 req.ContentLength = int64(len(s))
828                 req.Header.Set(contentType, mimeTypeFormPost)
829                 req.Body = body
830         }
831
832         if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
833                 req.Method = http.MethodGet
834                 req.Header.Set(metadataHeader, "true")
835         }
836
837         var resp *http.Response
838         if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
839                 resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
840         } else {
841                 resp, err = spt.sender.Do(req)
842         }
843         if err != nil {
844                 return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
845         }
846
847         defer resp.Body.Close()
848         rb, err := ioutil.ReadAll(resp.Body)
849
850         if resp.StatusCode != http.StatusOK {
851                 if err != nil {
852                         return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
853                 }
854                 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
855         }
856
857         // for the following error cases don't return a TokenRefreshError.  the operation succeeded
858         // but some transient failure happened during deserialization.  by returning a generic error
859         // the retry logic will kick in (we don't retry on TokenRefreshError).
860
861         if err != nil {
862                 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
863         }
864         if len(strings.Trim(string(rb), " ")) == 0 {
865                 return fmt.Errorf("adal: Empty service principal token received during refresh")
866         }
867         var token Token
868         err = json.Unmarshal(rb, &token)
869         if err != nil {
870                 return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
871         }
872
873         spt.inner.Token = token
874
875         return spt.InvokeRefreshCallbacks(token)
876 }
877
878 // retry logic specific to retrieving a token from the IMDS endpoint
879 func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
880         // copied from client.go due to circular dependency
881         retries := []int{
882                 http.StatusRequestTimeout,      // 408
883                 http.StatusTooManyRequests,     // 429
884                 http.StatusInternalServerError, // 500
885                 http.StatusBadGateway,          // 502
886                 http.StatusServiceUnavailable,  // 503
887                 http.StatusGatewayTimeout,      // 504
888         }
889         // extra retry status codes specific to IMDS
890         retries = append(retries,
891                 http.StatusNotFound,
892                 http.StatusGone,
893                 // all remaining 5xx
894                 http.StatusNotImplemented,
895                 http.StatusHTTPVersionNotSupported,
896                 http.StatusVariantAlsoNegotiates,
897                 http.StatusInsufficientStorage,
898                 http.StatusLoopDetected,
899                 http.StatusNotExtended,
900                 http.StatusNetworkAuthenticationRequired)
901
902         // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
903
904         const maxDelay time.Duration = 60 * time.Second
905
906         attempt := 0
907         delay := time.Duration(0)
908
909         for attempt < maxAttempts {
910                 resp, err = sender.Do(req)
911                 // retry on temporary network errors, e.g. transient network failures.
912                 // if we don't receive a response then assume we can't connect to the
913                 // endpoint so we're likely not running on an Azure VM so don't retry.
914                 if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
915                         return
916                 }
917
918                 // perform exponential backoff with a cap.
919                 // must increment attempt before calculating delay.
920                 attempt++
921                 // the base value of 2 is the "delta backoff" as specified in the guidance doc
922                 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
923                 if delay > maxDelay {
924                         delay = maxDelay
925                 }
926
927                 select {
928                 case <-time.After(delay):
929                         // intentionally left blank
930                 case <-req.Context().Done():
931                         err = req.Context().Err()
932                         return
933                 }
934         }
935         return
936 }
937
938 // returns true if the specified error is a temporary network error or false if it's not.
939 // if the error doesn't implement the net.Error interface the return value is true.
940 func isTemporaryNetworkError(err error) bool {
941         if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
942                 return true
943         }
944         return false
945 }
946
947 // returns true if slice ints contains the value n
948 func containsInt(ints []int, n int) bool {
949         for _, i := range ints {
950                 if i == n {
951                         return true
952                 }
953         }
954         return false
955 }
956
957 // SetAutoRefresh enables or disables automatic refreshing of stale tokens.
958 func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
959         spt.inner.AutoRefresh = autoRefresh
960 }
961
962 // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
963 // refresh the token.
964 func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
965         spt.inner.RefreshWithin = d
966         return
967 }
968
969 // SetSender sets the http.Client used when obtaining the Service Principal token. An
970 // undecorated http.Client is used by default.
971 func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
972
973 // OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
974 func (spt *ServicePrincipalToken) OAuthToken() string {
975         spt.refreshLock.RLock()
976         defer spt.refreshLock.RUnlock()
977         return spt.inner.Token.OAuthToken()
978 }
979
980 // Token returns a copy of the current token.
981 func (spt *ServicePrincipalToken) Token() Token {
982         spt.refreshLock.RLock()
983         defer spt.refreshLock.RUnlock()
984         return spt.inner.Token
985 }