2 Copyright 2017 The Kubernetes Authors.
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
27 "github.com/Azure/go-autorest/autorest"
28 "github.com/Azure/go-autorest/autorest/adal"
29 "github.com/Azure/go-autorest/autorest/azure"
32 "k8s.io/apimachinery/pkg/util/net"
33 restclient "k8s.io/client-go/rest"
// azureTokenKey is the single key under which Token stores and looks up the
// token in the package-level in-memory cache.
37 azureTokenKey = "azureTokenKey"
// authHeader is the HTTP header that RoundTrip checks for and sets.
39 authHeader = "Authorization"
// Keys of the auth-provider config map. newAzureAuthProvider and
// retrieveTokenFromCfg read them; storeTokenInCfg writes all of them except
// cfgEnvironment.
41 cfgClientID = "client-id"
42 cfgTenantID = "tenant-id"
43 cfgAccessToken = "access-token"
44 cfgRefreshToken = "refresh-token"
45 cfgExpiresIn = "expires-in"
46 cfgExpiresOn = "expires-on"
47 cfgEnvironment = "environment"
48 cfgApiserverID = "apiserver-id"
// Register this plugin under the name "azure" with newAzureAuthProvider as its
// factory. A registration error is treated as fatal (klog.Fatalf).
52 if err := restclient.RegisterAuthProviderPlugin("azure", newAzureAuthProvider); err != nil {
53 klog.Fatalf("Failed to register azure auth plugin: %v", err)
// cache is the process-wide token cache. newAzureAuthProvider hands it to
// every azureTokenSource it creates.
// NOTE(review): azureTokenSource.Token always uses the fixed key azureTokenKey,
// so providers configured with different client/tenant/apiserver IDs in the
// same process would read each other's cached token — confirm this sharing is
// intended, or derive the key from those IDs.
57 var cache = newAzureTokenCache()
// azureTokenCache maps a cache key to the most recently stored azureToken.
59 type azureTokenCache struct {
61 cache map[string]*azureToken
// newAzureTokenCache returns an empty cache with its map already allocated.
64 func newAzureTokenCache() *azureTokenCache {
65 return &azureTokenCache{cache: make(map[string]*azureToken)}
// getToken returns c.cache[tokenKey], which is nil when nothing is stored under
// tokenKey.
68 func (c *azureTokenCache) getToken(tokenKey string) *azureToken {
71 return c.cache[tokenKey]
// setToken stores token under tokenKey, replacing any previous entry.
74 func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) {
77 c.cache[tokenKey] = token
// newAzureAuthProvider is the factory registered for the "azure" auth plugin.
// It does the following:
//   - Resolves the Azure environment named by cfg[cfgEnvironment]; azure.PublicCloud
//     is used instead, presumably when the name cannot be resolved.
//   - Builds a device-code token source from the client-id, tenant-id and
//     apiserver-id entries of cfg.
//   - Wraps that source in an azureTokenSource backed by the package-level cache
//     and by persister.
//
// The first argument is ignored.
80 func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
83 environment, err := azure.EnvironmentFromName(cfg[cfgEnvironment])
85 environment = azure.PublicCloud
87 ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID])
89 return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
// Layer in-memory caching and config persistence over the interactive source.
91 cacheSource := newAzureTokenSource(ts, cache, cfg, persister)
93 return &azureAuthProvider{
94 tokenSource: cacheSource,
// azureAuthProvider is the restclient.AuthProvider returned by
// newAzureAuthProvider. It authorizes requests with tokens from tokenSource.
98 type azureAuthProvider struct {
99 tokenSource tokenSource
// Login is not supported; it always returns an error.
102 func (p *azureAuthProvider) Login() error {
103 return errors.New("not yet implemented")
// WrapTransport returns an azureRoundTripper that authorizes requests with
// tokens from p.tokenSource, presumably delegating the actual call to rt.
106 func (p *azureAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
107 return &azureRoundTripper{
108 tokenSource: p.tokenSource,
// azureRoundTripper adds an Authorization header carrying an Azure token to
// each request that does not already have one, then forwards the request to
// roundTripper.
113 type azureRoundTripper struct {
114 tokenSource tokenSource
115 roundTripper http.RoundTripper
// Compile-time check that azureRoundTripper exposes the transport it wraps.
118 var _ net.RoundTripperWrapper = &azureRoundTripper{}
// RoundTrip forwards req unchanged when it already has an Authorization header.
// Otherwise it obtains a token from r.tokenSource and returns an error if that
// fails. On success it sends a clone of req whose Authorization header is
// "<tokenType> <access token>", leaving req's own headers untouched.
120 func (r *azureRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
121 if len(req.Header.Get(authHeader)) != 0 {
122 return r.roundTripper.RoundTrip(req)
125 token, err := r.tokenSource.Token()
// NOTE(review): the error is both logged here and returned, so callers that
// also log it will report the same failure twice.
127 klog.Errorf("Failed to acquire a token: %v", err)
128 return nil, fmt.Errorf("acquiring a token for authorization header: %v", err)
131 // clone the request in order to avoid modifying the headers of the original request
132 req2 := new(http.Request)
// Copy every header value slice, so that setting a header on req2 cannot write
// through to req.Header.
134 req2.Header = make(http.Header, len(req.Header))
135 for k, s := range req.Header {
136 req2.Header[k] = append([]string(nil), s...)
139 req2.Header.Set(authHeader, fmt.Sprintf("%s %s", tokenType, token.token.AccessToken))
141 return r.roundTripper.RoundTrip(req2)
// WrappedRoundTripper returns the transport this round tripper delegates to.
144 func (r *azureRoundTripper) WrappedRoundTripper() http.RoundTripper { return r.roundTripper }
// azureToken pairs the underlying OAuth token (accessed as .token) with the
// client, tenant and apiserver IDs that were used to obtain it. refreshToken
// reuses those IDs to refresh the token.
146 type azureToken struct {
// tokenSource yields azureTokens. Two types return it: azureTokenSource (the
// cache/config layer) and azureTokenSourceDeviceCode (interactive device-code
// login).
153 type tokenSource interface {
154 Token() (*azureToken, error)
// azureTokenSource layers three sources of tokens, in order:
//   - the in-memory cache;
//   - the persisted auth-provider config (cfg);
//   - its wrapped source (ts.source).
//
// It refreshes expired tokens and writes new tokens back through persister.
// Token holds this struct's lock for the duration of each call.
157 type azureTokenSource struct {
159 cache *azureTokenCache
161 cfg map[string]string
162 persister restclient.AuthProviderConfigPersister
// newAzureTokenSource wraps source with in-memory caching (cache) and config
// persistence (cfg, persister).
165 func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, persister restclient.AuthProviderConfigPersister) tokenSource {
166 return &azureTokenSource{
170 persister: persister,
174 // Token fetches a token from the cache or configuration if present otherwise
175 // acquires a new token from the configured source. Automatically refreshes
176 // the token if expired.
// A newly obtained token that is still valid, and any refreshed token, is
// stored in the in-memory cache under azureTokenKey and persisted through
// storeTokenInCfg. ts.lock is held until Token returns.
177 func (ts *azureTokenSource) Token() (*azureToken, error) {
179 defer ts.lock.Unlock()
182 token := ts.cache.getToken(azureTokenKey)
184 token, err = ts.retrieveTokenFromCfg()
186 token, err = ts.source.Token()
188 return nil, fmt.Errorf("acquiring a new fresh token: %v", err)
// Only cache and persist a token that is still valid. An expired token is
// handled by the refresh step below.
191 if !token.token.IsExpired() {
192 ts.cache.setToken(azureTokenKey, token)
193 err = ts.storeTokenInCfg(token)
195 return nil, fmt.Errorf("storing the token in configuration: %v", err)
// Refresh an expired token and record the result in both the cache and the
// persisted config.
199 if token.token.IsExpired() {
200 token, err = ts.refreshToken(token)
202 return nil, fmt.Errorf("refreshing the expired token: %v", err)
204 ts.cache.setToken(azureTokenKey, token)
205 err = ts.storeTokenInCfg(token)
207 return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
// retrieveTokenFromCfg rebuilds an azureToken from the values persisted in
// ts.cfg. The following keys must all be non-empty: access-token,
// refresh-token, client-id, tenant-id, apiserver-id, expires-in and
// expires-on. Otherwise an error is returned; its message names the missing
// key, except for apiserver-id (see the NOTE below).
213 func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
214 accessToken := ts.cfg[cfgAccessToken]
215 if accessToken == "" {
216 return nil, fmt.Errorf("no access token in cfg: %s", cfgAccessToken)
218 refreshToken := ts.cfg[cfgRefreshToken]
219 if refreshToken == "" {
220 return nil, fmt.Errorf("no refresh token in cfg: %s", cfgRefreshToken)
222 clientID := ts.cfg[cfgClientID]
224 return nil, fmt.Errorf("no client ID in cfg: %s", cfgClientID)
226 tenantID := ts.cfg[cfgTenantID]
228 return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
230 apiserverID := ts.cfg[cfgApiserverID]
231 if apiserverID == "" {
// NOTE(review): this message formats apiserverID, which is always "" on this
// branch, so the key name is lost. The sibling checks pass the key constant;
// cfgApiserverID was almost certainly intended here.
232 return nil, fmt.Errorf("no apiserver ID in cfg: %s", apiserverID)
234 expiresIn := ts.cfg[cfgExpiresIn]
236 return nil, fmt.Errorf("no expiresIn in cfg: %s", cfgExpiresIn)
238 expiresOn := ts.cfg[cfgExpiresOn]
240 return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
245 AccessToken: accessToken,
246 RefreshToken: refreshToken,
247 ExpiresIn: json.Number(expiresIn),
248 ExpiresOn: json.Number(expiresOn),
// NOTE(review): NotBefore is filled from expires-on because this plugin
// persists no not-before value. Confirm that nothing relies on NotBefore.
249 NotBefore: json.Number(expiresOn),
// The resource is reconstructed as "spn:<apiserver-id>".
250 Resource: fmt.Sprintf("spn:%s", apiserverID),
255 apiserverID: apiserverID,
// storeTokenInCfg builds a fresh config map from token and persists it through
// ts.persister. The map holds the access and refresh tokens, the expires-in and
// expires-on values, and the client/tenant/apiserver IDs. A persistence error
// is wrapped and returned.
// NOTE(review): cfgEnvironment is not written into newCfg. If the persister
// replaces the stored config instead of merging into it, a configured
// environment would be lost after the first save. Confirm the persister's
// semantics.
259 func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
260 newCfg := make(map[string]string)
261 newCfg[cfgAccessToken] = token.token.AccessToken
262 newCfg[cfgRefreshToken] = token.token.RefreshToken
263 newCfg[cfgClientID] = token.clientID
264 newCfg[cfgTenantID] = token.tenantID
265 newCfg[cfgApiserverID] = token.apiserverID
266 newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
267 newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
269 err := ts.persister.Persist(newCfg)
271 return fmt.Errorf("persisting the configuration: %v", err)
// refreshToken builds an adal service principal token from the existing token
// for token.tenantID and calls Refresh on it. It returns a new azureToken that
// carries the same client, tenant and apiserver IDs.
// NOTE(review): the OAuth endpoint is always azure.PublicCloud's, although
// newAzureAuthProvider resolves the environment from cfg[cfgEnvironment].
// Refreshes for non-public clouds would therefore target the wrong authority.
// Consider carrying the configured environment on the token and using it here.
277 func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) {
278 oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, token.tenantID)
280 return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
283 callback := func(t adal.Token) error {
286 spt, err := adal.NewServicePrincipalTokenFromManualToken(
293 return nil, fmt.Errorf("creating new service principal for token refresh: %v", err)
296 if err := spt.Refresh(); err != nil {
297 return nil, fmt.Errorf("refreshing token: %v", err)
302 clientID: token.clientID,
303 tenantID: token.tenantID,
304 apiserverID: token.apiserverID,
// azureTokenSourceDeviceCode acquires tokens interactively through the OAuth
// device-code flow, using the Active Directory endpoint of its environment.
308 type azureTokenSourceDeviceCode struct {
309 environment azure.Environment
// newAzureTokenSourceDeviceCode returns a device-code token source for the
// given environment and IDs. It returns an error if clientID, tenantID or
// apiserverID is empty.
315 func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string) (tokenSource, error) {
317 return nil, errors.New("client-id is empty")
320 return nil, errors.New("tenant-id is empty")
322 if apiserverID == "" {
323 return nil, errors.New("apiserver-id is empty")
325 return &azureTokenSourceDeviceCode{
326 environment: environment,
329 apiserverID: apiserverID,
// Token runs the device-code flow. It:
//   - initiates device authorization for ts.clientID with ts.apiserverID as the
//     resource;
//   - prints the service-provided sign-in message to stderr;
//   - waits for the user to complete authentication.
//
// It returns the resulting token together with the client, tenant and
// apiserver IDs.
333 func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
334 oauthConfig, err := adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
336 return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
// NOTE(review): a zero-value autorest.Client is used, so no timeout or retry
// policy is configured here. Confirm that the defaults are acceptable.
338 client := &autorest.Client{}
339 deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
// NOTE(review): "initialing" in this message is a typo for "initiating".
341 return nil, fmt.Errorf("initialing the device code authentication: %v", err)
// The message tells the user where to sign in and which code to enter.
// NOTE(review): deviceCode.Message is dereferenced without a nil check.
344 _, err = fmt.Fprintln(os.Stderr, *deviceCode.Message)
346 return nil, fmt.Errorf("prompting the device code message: %v", err)
349 token, err := adal.WaitForUserCompletion(client, deviceCode)
351 return nil, fmt.Errorf("waiting for device code authentication to complete: %v", err)
356 clientID: ts.clientID,
357 tenantID: ts.tenantID,
358 apiserverID: ts.apiserverID,