3 // Copyright 2017 Microsoft Corporation
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
36 "github.com/Azure/go-autorest/autorest/date"
37 "github.com/Azure/go-autorest/tracing"
38 "github.com/dgrijalva/jwt-go"
42 defaultRefresh = 5 * time.Minute
44 // OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
45 OAuthGrantTypeDeviceCode = "device_code"
47 // OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
48 OAuthGrantTypeClientCredentials = "client_credentials"
50 // OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
51 OAuthGrantTypeUserPass = "password"
53 // OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
54 OAuthGrantTypeRefreshToken = "refresh_token"
56 // OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
57 OAuthGrantTypeAuthorizationCode = "authorization_code"
59 // metadataHeader is the header required by MSI extension
60 metadataHeader = "Metadata"
62 // msiEndpoint is the well known endpoint for getting MSI authentications tokens
63 msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
65 // the default number of attempts to refresh an MSI authentication token
66 defaultMaxMSIRefreshAttempts = 5
69 // OAuthTokenProvider is an interface which should be implemented by an access token retriever
70 type OAuthTokenProvider interface {
74 // TokenRefreshError is an interface used by errors returned during token refresh.
75 type TokenRefreshError interface {
77 Response() *http.Response
80 // Refresher is an interface for token refresh functionality
81 type Refresher interface {
83 RefreshExchange(resource string) error
87 // RefresherWithContext is an interface for token refresh functionality
88 type RefresherWithContext interface {
89 RefreshWithContext(ctx context.Context) error
90 RefreshExchangeWithContext(ctx context.Context, resource string) error
91 EnsureFreshWithContext(ctx context.Context) error
94 // TokenRefreshCallback is the type representing callbacks that will be called after
95 // a successful token refresh
96 type TokenRefreshCallback func(Token) error
98 // Token encapsulates the access token used to authorize Azure requests.
99 // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
101 AccessToken string `json:"access_token"`
102 RefreshToken string `json:"refresh_token"`
104 ExpiresIn json.Number `json:"expires_in"`
105 ExpiresOn json.Number `json:"expires_on"`
106 NotBefore json.Number `json:"not_before"`
108 Resource string `json:"resource"`
109 Type string `json:"token_type"`
112 func newToken() Token {
120 // IsZero returns true if the token object is zero-initialized.
121 func (t Token) IsZero() bool {
125 // Expires returns the time.Time when the Token expires.
126 func (t Token) Expires() time.Time {
127 s, err := t.ExpiresOn.Float64()
132 expiration := date.NewUnixTimeFromSeconds(s)
134 return time.Time(expiration).UTC()
137 // IsExpired returns true if the Token is expired, false otherwise.
138 func (t Token) IsExpired() bool {
139 return t.WillExpireIn(0)
142 // WillExpireIn returns true if the Token will expire after the passed time.Duration interval
143 // from now, false otherwise.
144 func (t Token) WillExpireIn(d time.Duration) bool {
145 return !t.Expires().After(time.Now().Add(d))
// OAuthToken returns the current access token.
149 func (t *Token) OAuthToken() string {
153 // ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
154 // that is submitted when acquiring an oAuth token.
155 type ServicePrincipalSecret interface {
156 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
// ServicePrincipalNoSecret represents a secret type that contains no secret,
// meaning it is not valid for fetching a fresh token. It is used by manually
// created tokens (see NewServicePrincipalTokenFromManualToken).
161 type ServicePrincipalNoSecret struct {
164 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret
165 // It only returns an error for the ServicePrincipalNoSecret type
166 func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
167 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
170 // MarshalJSON implements the json.Marshaler interface.
171 func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
172 type tokenType struct {
173 Type string `json:"type"`
175 return json.Marshal(tokenType{
176 Type: "ServicePrincipalNoSecret",
180 // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
181 type ServicePrincipalTokenSecret struct {
182 ClientSecret string `json:"value"`
185 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
186 // It will populate the form submitted during oAuth Token Acquisition using the client_secret.
187 func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
188 v.Set("client_secret", tokenSecret.ClientSecret)
192 // MarshalJSON implements the json.Marshaler interface.
193 func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
194 type tokenType struct {
195 Type string `json:"type"`
196 Value string `json:"value"`
198 return json.Marshal(tokenType{
199 Type: "ServicePrincipalTokenSecret",
200 Value: tokenSecret.ClientSecret,
204 // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
205 type ServicePrincipalCertificateSecret struct {
206 Certificate *x509.Certificate
207 PrivateKey *rsa.PrivateKey
210 // SignJwt returns the JWT signed with the certificate's private key.
211 func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
213 _, err := hasher.Write(secret.Certificate.Raw)
218 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
220 // The jti (JWT ID) claim provides a unique identifier for the JWT.
221 jti := make([]byte, 20)
222 _, err = rand.Read(jti)
227 token := jwt.New(jwt.SigningMethodRS256)
228 token.Header["x5t"] = thumbprint
229 x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
230 token.Header["x5c"] = x5c
231 token.Claims = jwt.MapClaims{
232 "aud": spt.inner.OauthConfig.TokenEndpoint.String(),
233 "iss": spt.inner.ClientID,
234 "sub": spt.inner.ClientID,
235 "jti": base64.URLEncoding.EncodeToString(jti),
236 "nbf": time.Now().Unix(),
237 "exp": time.Now().Add(time.Hour * 24).Unix(),
240 signedString, err := token.SignedString(secret.PrivateKey)
241 return signedString, err
244 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
245 // It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
246 func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
247 jwt, err := secret.SignJwt(spt)
252 v.Set("client_assertion", jwt)
253 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
257 // MarshalJSON implements the json.Marshaler interface.
258 func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
259 return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
262 // ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
263 type ServicePrincipalMSISecret struct {
266 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
267 func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
271 // MarshalJSON implements the json.Marshaler interface.
272 func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
273 return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
276 // ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
277 type ServicePrincipalUsernamePasswordSecret struct {
278 Username string `json:"username"`
279 Password string `json:"password"`
282 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
283 func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
284 v.Set("username", secret.Username)
285 v.Set("password", secret.Password)
289 // MarshalJSON implements the json.Marshaler interface.
290 func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
291 type tokenType struct {
292 Type string `json:"type"`
293 Username string `json:"username"`
294 Password string `json:"password"`
296 return json.Marshal(tokenType{
297 Type: "ServicePrincipalUsernamePasswordSecret",
298 Username: secret.Username,
299 Password: secret.Password,
303 // ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
304 type ServicePrincipalAuthorizationCodeSecret struct {
305 ClientSecret string `json:"value"`
306 AuthorizationCode string `json:"authCode"`
307 RedirectURI string `json:"redirect"`
310 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
311 func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
312 v.Set("code", secret.AuthorizationCode)
313 v.Set("client_secret", secret.ClientSecret)
314 v.Set("redirect_uri", secret.RedirectURI)
318 // MarshalJSON implements the json.Marshaler interface.
319 func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
320 type tokenType struct {
321 Type string `json:"type"`
322 Value string `json:"value"`
323 AuthCode string `json:"authCode"`
324 Redirect string `json:"redirect"`
326 return json.Marshal(tokenType{
327 Type: "ServicePrincipalAuthorizationCodeSecret",
328 Value: secret.ClientSecret,
329 AuthCode: secret.AuthorizationCode,
330 Redirect: secret.RedirectURI,
334 // ServicePrincipalToken encapsulates a Token created for a Service Principal.
335 type ServicePrincipalToken struct {
336 inner servicePrincipalToken
337 refreshLock *sync.RWMutex
339 refreshCallbacks []TokenRefreshCallback
340 // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
341 MaxMSIRefreshAttempts int
344 // MarshalTokenJSON returns the marshalled inner token.
345 func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
346 return json.Marshal(spt.inner.Token)
349 // SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
350 func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
351 spt.refreshCallbacks = callbacks
354 // MarshalJSON implements the json.Marshaler interface.
355 func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
356 return json.Marshal(spt.inner)
359 // UnmarshalJSON implements the json.Unmarshaler interface.
360 func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
361 // need to determine the token type
362 raw := map[string]interface{}{}
363 err := json.Unmarshal(data, &raw)
367 secret := raw["secret"].(map[string]interface{})
368 switch secret["type"] {
369 case "ServicePrincipalNoSecret":
370 spt.inner.Secret = &ServicePrincipalNoSecret{}
371 case "ServicePrincipalTokenSecret":
372 spt.inner.Secret = &ServicePrincipalTokenSecret{}
373 case "ServicePrincipalCertificateSecret":
374 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
375 case "ServicePrincipalMSISecret":
376 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
377 case "ServicePrincipalUsernamePasswordSecret":
378 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
379 case "ServicePrincipalAuthorizationCodeSecret":
380 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
382 return fmt.Errorf("unrecognized token type '%s'", secret["type"])
384 err = json.Unmarshal(data, &spt.inner)
388 // Don't override the refreshLock or the sender if those have been already set.
389 if spt.refreshLock == nil {
390 spt.refreshLock = &sync.RWMutex{}
392 if spt.sender == nil {
393 spt.sender = &http.Client{Transport: tracing.Transport}
398 // internal type used for marshalling/unmarshalling
399 type servicePrincipalToken struct {
400 Token Token `json:"token"`
401 Secret ServicePrincipalSecret `json:"secret"`
402 OauthConfig OAuthConfig `json:"oauth"`
403 ClientID string `json:"clientID"`
404 Resource string `json:"resource"`
405 AutoRefresh bool `json:"autoRefresh"`
406 RefreshWithin time.Duration `json:"refreshWithin"`
409 func validateOAuthConfig(oac OAuthConfig) error {
411 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
416 // NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
417 func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
418 if err := validateOAuthConfig(oauthConfig); err != nil {
421 if err := validateStringParam(id, "id"); err != nil {
424 if err := validateStringParam(resource, "resource"); err != nil {
428 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
430 spt := &ServicePrincipalToken{
431 inner: servicePrincipalToken{
433 OauthConfig: oauthConfig,
438 RefreshWithin: defaultRefresh,
440 refreshLock: &sync.RWMutex{},
441 sender: &http.Client{Transport: tracing.Transport},
442 refreshCallbacks: callbacks,
447 // NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
448 func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
449 if err := validateOAuthConfig(oauthConfig); err != nil {
452 if err := validateStringParam(clientID, "clientID"); err != nil {
455 if err := validateStringParam(resource, "resource"); err != nil {
459 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
461 spt, err := NewServicePrincipalTokenWithSecret(
465 &ServicePrincipalNoSecret{},
471 spt.inner.Token = token
476 // NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
477 func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
478 if err := validateOAuthConfig(oauthConfig); err != nil {
481 if err := validateStringParam(clientID, "clientID"); err != nil {
484 if err := validateStringParam(resource, "resource"); err != nil {
488 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
491 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
493 spt, err := NewServicePrincipalTokenWithSecret(
503 spt.inner.Token = token
508 // NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
509 // credentials scoped to the named resource.
510 func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
511 if err := validateOAuthConfig(oauthConfig); err != nil {
514 if err := validateStringParam(clientID, "clientID"); err != nil {
517 if err := validateStringParam(secret, "secret"); err != nil {
520 if err := validateStringParam(resource, "resource"); err != nil {
523 return NewServicePrincipalTokenWithSecret(
527 &ServicePrincipalTokenSecret{
528 ClientSecret: secret,
// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied certificate and RSA private key.
535 func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
536 if err := validateOAuthConfig(oauthConfig); err != nil {
539 if err := validateStringParam(clientID, "clientID"); err != nil {
542 if err := validateStringParam(resource, "resource"); err != nil {
545 if certificate == nil {
546 return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
548 if privateKey == nil {
549 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
551 return NewServicePrincipalTokenWithSecret(
555 &ServicePrincipalCertificateSecret{
556 PrivateKey: privateKey,
557 Certificate: certificate,
563 // NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
564 func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
565 if err := validateOAuthConfig(oauthConfig); err != nil {
568 if err := validateStringParam(clientID, "clientID"); err != nil {
571 if err := validateStringParam(username, "username"); err != nil {
574 if err := validateStringParam(password, "password"); err != nil {
577 if err := validateStringParam(resource, "resource"); err != nil {
580 return NewServicePrincipalTokenWithSecret(
584 &ServicePrincipalUsernamePasswordSecret{
// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
// supplied client secret, authorization code and redirect URI.
593 func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
595 if err := validateOAuthConfig(oauthConfig); err != nil {
598 if err := validateStringParam(clientID, "clientID"); err != nil {
601 if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
604 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
607 if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
610 if err := validateStringParam(resource, "resource"); err != nil {
614 return NewServicePrincipalTokenWithSecret(
618 &ServicePrincipalAuthorizationCodeSecret{
619 ClientSecret: clientSecret,
620 AuthorizationCode: authorizationCode,
621 RedirectURI: redirectURI,
627 // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
628 func GetMSIVMEndpoint() (string, error) {
629 return msiEndpoint, nil
632 // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
633 // It will use the system assigned identity when creating the token.
634 func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
635 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
638 // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
639 // It will use the specified user assigned identity when creating the token.
640 func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
641 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
644 func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
645 if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
648 if err := validateStringParam(resource, "resource"); err != nil {
651 if userAssignedID != nil {
652 if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
656 // We set the oauth config token endpoint to be MSI's endpoint
657 msiEndpointURL, err := url.Parse(msiEndpoint)
663 v.Set("resource", resource)
664 v.Set("api-version", "2018-02-01")
665 if userAssignedID != nil {
666 v.Set("client_id", *userAssignedID)
668 msiEndpointURL.RawQuery = v.Encode()
670 spt := &ServicePrincipalToken{
671 inner: servicePrincipalToken{
673 OauthConfig: OAuthConfig{
674 TokenEndpoint: *msiEndpointURL,
676 Secret: &ServicePrincipalMSISecret{},
679 RefreshWithin: defaultRefresh,
681 refreshLock: &sync.RWMutex{},
682 sender: &http.Client{Transport: tracing.Transport},
683 refreshCallbacks: callbacks,
684 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
687 if userAssignedID != nil {
688 spt.inner.ClientID = *userAssignedID
694 // internal type that implements TokenRefreshError
695 type tokenRefreshError struct {
700 // Error implements the error interface which is part of the TokenRefreshError interface.
701 func (tre tokenRefreshError) Error() string {
705 // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
706 func (tre tokenRefreshError) Response() *http.Response {
710 func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
711 return tokenRefreshError{message: message, resp: resp}
714 // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
715 // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
716 func (spt *ServicePrincipalToken) EnsureFresh() error {
717 return spt.EnsureFreshWithContext(context.Background())
720 // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
721 // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
722 func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
723 if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
724 // take the write lock then check to see if the token was already refreshed
725 spt.refreshLock.Lock()
726 defer spt.refreshLock.Unlock()
727 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
728 return spt.refreshInternal(ctx, spt.inner.Resource)
734 // InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
735 func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
736 if spt.refreshCallbacks != nil {
737 for _, callback := range spt.refreshCallbacks {
738 err := callback(spt.inner.Token)
740 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
747 // Refresh obtains a fresh token for the Service Principal.
748 // This method is not safe for concurrent use and should be syncrhonized.
749 func (spt *ServicePrincipalToken) Refresh() error {
750 return spt.RefreshWithContext(context.Background())
753 // RefreshWithContext obtains a fresh token for the Service Principal.
754 // This method is not safe for concurrent use and should be syncrhonized.
755 func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
756 spt.refreshLock.Lock()
757 defer spt.refreshLock.Unlock()
758 return spt.refreshInternal(ctx, spt.inner.Resource)
761 // RefreshExchange refreshes the token, but for a different resource.
762 // This method is not safe for concurrent use and should be syncrhonized.
763 func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
764 return spt.RefreshExchangeWithContext(context.Background(), resource)
767 // RefreshExchangeWithContext refreshes the token, but for a different resource.
768 // This method is not safe for concurrent use and should be syncrhonized.
769 func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
770 spt.refreshLock.Lock()
771 defer spt.refreshLock.Unlock()
772 return spt.refreshInternal(ctx, resource)
775 func (spt *ServicePrincipalToken) getGrantType() string {
776 switch spt.inner.Secret.(type) {
777 case *ServicePrincipalUsernamePasswordSecret:
778 return OAuthGrantTypeUserPass
779 case *ServicePrincipalAuthorizationCodeSecret:
780 return OAuthGrantTypeAuthorizationCode
782 return OAuthGrantTypeClientCredentials
786 func isIMDS(u url.URL) bool {
787 imds, err := url.Parse(msiEndpoint)
791 return u.Host == imds.Host && u.Path == imds.Path
794 func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
795 req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
797 return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
799 req.Header.Add("User-Agent", UserAgent())
800 req = req.WithContext(ctx)
801 if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
803 v.Set("client_id", spt.inner.ClientID)
804 v.Set("resource", resource)
806 if spt.inner.Token.RefreshToken != "" {
807 v.Set("grant_type", OAuthGrantTypeRefreshToken)
808 v.Set("refresh_token", spt.inner.Token.RefreshToken)
809 // web apps must specify client_secret when refreshing tokens
810 // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
811 if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
812 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
818 v.Set("grant_type", spt.getGrantType())
819 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
826 body := ioutil.NopCloser(strings.NewReader(s))
827 req.ContentLength = int64(len(s))
828 req.Header.Set(contentType, mimeTypeFormPost)
832 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
833 req.Method = http.MethodGet
834 req.Header.Set(metadataHeader, "true")
837 var resp *http.Response
838 if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
839 resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
841 resp, err = spt.sender.Do(req)
844 return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
847 defer resp.Body.Close()
848 rb, err := ioutil.ReadAll(resp.Body)
850 if resp.StatusCode != http.StatusOK {
852 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
854 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
857 // for the following error cases don't return a TokenRefreshError. the operation succeeded
858 // but some transient failure happened during deserialization. by returning a generic error
859 // the retry logic will kick in (we don't retry on TokenRefreshError).
862 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
864 if len(strings.Trim(string(rb), " ")) == 0 {
865 return fmt.Errorf("adal: Empty service principal token received during refresh")
868 err = json.Unmarshal(rb, &token)
870 return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
873 spt.inner.Token = token
875 return spt.InvokeRefreshCallbacks(token)
878 // retry logic specific to retrieving a token from the IMDS endpoint
879 func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
880 // copied from client.go due to circular dependency
882 http.StatusRequestTimeout, // 408
883 http.StatusTooManyRequests, // 429
884 http.StatusInternalServerError, // 500
885 http.StatusBadGateway, // 502
886 http.StatusServiceUnavailable, // 503
887 http.StatusGatewayTimeout, // 504
889 // extra retry status codes specific to IMDS
890 retries = append(retries,
894 http.StatusNotImplemented,
895 http.StatusHTTPVersionNotSupported,
896 http.StatusVariantAlsoNegotiates,
897 http.StatusInsufficientStorage,
898 http.StatusLoopDetected,
899 http.StatusNotExtended,
900 http.StatusNetworkAuthenticationRequired)
902 // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
904 const maxDelay time.Duration = 60 * time.Second
907 delay := time.Duration(0)
909 for attempt < maxAttempts {
910 resp, err = sender.Do(req)
911 // retry on temporary network errors, e.g. transient network failures.
912 // if we don't receive a response then assume we can't connect to the
913 // endpoint so we're likely not running on an Azure VM so don't retry.
914 if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
918 // perform exponential backoff with a cap.
919 // must increment attempt before calculating delay.
921 // the base value of 2 is the "delta backoff" as specified in the guidance doc
922 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
923 if delay > maxDelay {
928 case <-time.After(delay):
929 // intentionally left blank
930 case <-req.Context().Done():
931 err = req.Context().Err()
// isTemporaryNetworkError reports whether err should be treated as a temporary
// network error. Errors that do not implement net.Error are treated as
// temporary (true); errors that do are judged by their Temporary method.
func isTemporaryNetworkError(err error) bool {
	netErr, isNetErr := err.(net.Error)
	if !isNetErr {
		return true
	}
	return netErr.Temporary()
}
// containsInt reports whether the slice ints holds the value n.
// A nil or empty slice contains nothing.
func containsInt(ints []int, n int) bool {
	for idx := 0; idx < len(ints); idx++ {
		if ints[idx] == n {
			return true
		}
	}
	return false
}
957 // SetAutoRefresh enables or disables automatic refreshing of stale tokens.
958 func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
959 spt.inner.AutoRefresh = autoRefresh
962 // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
963 // refresh the token.
964 func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
965 spt.inner.RefreshWithin = d
969 // SetSender sets the http.Client used when obtaining the Service Principal token. An
970 // undecorated http.Client is used by default.
971 func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
973 // OAuthToken implements the OAuthTokenProvider interface. It returns the current access token.
974 func (spt *ServicePrincipalToken) OAuthToken() string {
975 spt.refreshLock.RLock()
976 defer spt.refreshLock.RUnlock()
977 return spt.inner.Token.OAuthToken()
980 // Token returns a copy of the current token.
981 func (spt *ServicePrincipalToken) Token() Token {
982 spt.refreshLock.RLock()
983 defer spt.refreshLock.RUnlock()
984 return spt.inner.Token