Remove BPA from Makefile
[icn.git] / cmd / bpa-operator / vendor / k8s.io / client-go / plugin / pkg / client / auth / azure / azure.go
1 /*
2 Copyright 2017 The Kubernetes Authors.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8     http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 package azure
18
19 import (
20         "encoding/json"
21         "errors"
22         "fmt"
23         "net/http"
24         "os"
25         "sync"
26
27         "github.com/Azure/go-autorest/autorest"
28         "github.com/Azure/go-autorest/autorest/adal"
29         "github.com/Azure/go-autorest/autorest/azure"
30         "k8s.io/klog"
31
32         "k8s.io/apimachinery/pkg/util/net"
33         restclient "k8s.io/client-go/rest"
34 )
35
36 const (
37         azureTokenKey = "azureTokenKey"
38         tokenType     = "Bearer"
39         authHeader    = "Authorization"
40
41         cfgClientID     = "client-id"
42         cfgTenantID     = "tenant-id"
43         cfgAccessToken  = "access-token"
44         cfgRefreshToken = "refresh-token"
45         cfgExpiresIn    = "expires-in"
46         cfgExpiresOn    = "expires-on"
47         cfgEnvironment  = "environment"
48         cfgApiserverID  = "apiserver-id"
49 )
50
51 func init() {
52         if err := restclient.RegisterAuthProviderPlugin("azure", newAzureAuthProvider); err != nil {
53                 klog.Fatalf("Failed to register azure auth plugin: %v", err)
54         }
55 }
56
57 var cache = newAzureTokenCache()
58
59 type azureTokenCache struct {
60         lock  sync.Mutex
61         cache map[string]*azureToken
62 }
63
64 func newAzureTokenCache() *azureTokenCache {
65         return &azureTokenCache{cache: make(map[string]*azureToken)}
66 }
67
68 func (c *azureTokenCache) getToken(tokenKey string) *azureToken {
69         c.lock.Lock()
70         defer c.lock.Unlock()
71         return c.cache[tokenKey]
72 }
73
74 func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) {
75         c.lock.Lock()
76         defer c.lock.Unlock()
77         c.cache[tokenKey] = token
78 }
79
80 func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
81         var ts tokenSource
82
83         environment, err := azure.EnvironmentFromName(cfg[cfgEnvironment])
84         if err != nil {
85                 environment = azure.PublicCloud
86         }
87         ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID])
88         if err != nil {
89                 return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
90         }
91         cacheSource := newAzureTokenSource(ts, cache, cfg, persister)
92
93         return &azureAuthProvider{
94                 tokenSource: cacheSource,
95         }, nil
96 }
97
98 type azureAuthProvider struct {
99         tokenSource tokenSource
100 }
101
102 func (p *azureAuthProvider) Login() error {
103         return errors.New("not yet implemented")
104 }
105
106 func (p *azureAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
107         return &azureRoundTripper{
108                 tokenSource:  p.tokenSource,
109                 roundTripper: rt,
110         }
111 }
112
113 type azureRoundTripper struct {
114         tokenSource  tokenSource
115         roundTripper http.RoundTripper
116 }
117
118 var _ net.RoundTripperWrapper = &azureRoundTripper{}
119
120 func (r *azureRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
121         if len(req.Header.Get(authHeader)) != 0 {
122                 return r.roundTripper.RoundTrip(req)
123         }
124
125         token, err := r.tokenSource.Token()
126         if err != nil {
127                 klog.Errorf("Failed to acquire a token: %v", err)
128                 return nil, fmt.Errorf("acquiring a token for authorization header: %v", err)
129         }
130
131         // clone the request in order to avoid modifying the headers of the original request
132         req2 := new(http.Request)
133         *req2 = *req
134         req2.Header = make(http.Header, len(req.Header))
135         for k, s := range req.Header {
136                 req2.Header[k] = append([]string(nil), s...)
137         }
138
139         req2.Header.Set(authHeader, fmt.Sprintf("%s %s", tokenType, token.token.AccessToken))
140
141         return r.roundTripper.RoundTrip(req2)
142 }
143
144 func (r *azureRoundTripper) WrappedRoundTripper() http.RoundTripper { return r.roundTripper }
145
146 type azureToken struct {
147         token       adal.Token
148         clientID    string
149         tenantID    string
150         apiserverID string
151 }
152
153 type tokenSource interface {
154         Token() (*azureToken, error)
155 }
156
157 type azureTokenSource struct {
158         source    tokenSource
159         cache     *azureTokenCache
160         lock      sync.Mutex
161         cfg       map[string]string
162         persister restclient.AuthProviderConfigPersister
163 }
164
165 func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, persister restclient.AuthProviderConfigPersister) tokenSource {
166         return &azureTokenSource{
167                 source:    source,
168                 cache:     cache,
169                 cfg:       cfg,
170                 persister: persister,
171         }
172 }
173
174 // Token fetches a token from the cache of configuration if present otherwise
175 // acquires a new token from the configured source. Automatically refreshes
176 // the token if expired.
177 func (ts *azureTokenSource) Token() (*azureToken, error) {
178         ts.lock.Lock()
179         defer ts.lock.Unlock()
180
181         var err error
182         token := ts.cache.getToken(azureTokenKey)
183         if token == nil {
184                 token, err = ts.retrieveTokenFromCfg()
185                 if err != nil {
186                         token, err = ts.source.Token()
187                         if err != nil {
188                                 return nil, fmt.Errorf("acquiring a new fresh token: %v", err)
189                         }
190                 }
191                 if !token.token.IsExpired() {
192                         ts.cache.setToken(azureTokenKey, token)
193                         err = ts.storeTokenInCfg(token)
194                         if err != nil {
195                                 return nil, fmt.Errorf("storing the token in configuration: %v", err)
196                         }
197                 }
198         }
199         if token.token.IsExpired() {
200                 token, err = ts.refreshToken(token)
201                 if err != nil {
202                         return nil, fmt.Errorf("refreshing the expired token: %v", err)
203                 }
204                 ts.cache.setToken(azureTokenKey, token)
205                 err = ts.storeTokenInCfg(token)
206                 if err != nil {
207                         return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
208                 }
209         }
210         return token, nil
211 }
212
213 func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
214         accessToken := ts.cfg[cfgAccessToken]
215         if accessToken == "" {
216                 return nil, fmt.Errorf("no access token in cfg: %s", cfgAccessToken)
217         }
218         refreshToken := ts.cfg[cfgRefreshToken]
219         if refreshToken == "" {
220                 return nil, fmt.Errorf("no refresh token in cfg: %s", cfgRefreshToken)
221         }
222         clientID := ts.cfg[cfgClientID]
223         if clientID == "" {
224                 return nil, fmt.Errorf("no client ID in cfg: %s", cfgClientID)
225         }
226         tenantID := ts.cfg[cfgTenantID]
227         if tenantID == "" {
228                 return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
229         }
230         apiserverID := ts.cfg[cfgApiserverID]
231         if apiserverID == "" {
232                 return nil, fmt.Errorf("no apiserver ID in cfg: %s", apiserverID)
233         }
234         expiresIn := ts.cfg[cfgExpiresIn]
235         if expiresIn == "" {
236                 return nil, fmt.Errorf("no expiresIn in cfg: %s", cfgExpiresIn)
237         }
238         expiresOn := ts.cfg[cfgExpiresOn]
239         if expiresOn == "" {
240                 return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
241         }
242
243         return &azureToken{
244                 token: adal.Token{
245                         AccessToken:  accessToken,
246                         RefreshToken: refreshToken,
247                         ExpiresIn:    json.Number(expiresIn),
248                         ExpiresOn:    json.Number(expiresOn),
249                         NotBefore:    json.Number(expiresOn),
250                         Resource:     fmt.Sprintf("spn:%s", apiserverID),
251                         Type:         tokenType,
252                 },
253                 clientID:    clientID,
254                 tenantID:    tenantID,
255                 apiserverID: apiserverID,
256         }, nil
257 }
258
259 func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
260         newCfg := make(map[string]string)
261         newCfg[cfgAccessToken] = token.token.AccessToken
262         newCfg[cfgRefreshToken] = token.token.RefreshToken
263         newCfg[cfgClientID] = token.clientID
264         newCfg[cfgTenantID] = token.tenantID
265         newCfg[cfgApiserverID] = token.apiserverID
266         newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
267         newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
268
269         err := ts.persister.Persist(newCfg)
270         if err != nil {
271                 return fmt.Errorf("persisting the configuration: %v", err)
272         }
273         ts.cfg = newCfg
274         return nil
275 }
276
277 func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) {
278         oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, token.tenantID)
279         if err != nil {
280                 return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
281         }
282
283         callback := func(t adal.Token) error {
284                 return nil
285         }
286         spt, err := adal.NewServicePrincipalTokenFromManualToken(
287                 *oauthConfig,
288                 token.clientID,
289                 token.apiserverID,
290                 token.token,
291                 callback)
292         if err != nil {
293                 return nil, fmt.Errorf("creating new service principal for token refresh: %v", err)
294         }
295
296         if err := spt.Refresh(); err != nil {
297                 return nil, fmt.Errorf("refreshing token: %v", err)
298         }
299
300         return &azureToken{
301                 token:       spt.Token(),
302                 clientID:    token.clientID,
303                 tenantID:    token.tenantID,
304                 apiserverID: token.apiserverID,
305         }, nil
306 }
307
308 type azureTokenSourceDeviceCode struct {
309         environment azure.Environment
310         clientID    string
311         tenantID    string
312         apiserverID string
313 }
314
315 func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string) (tokenSource, error) {
316         if clientID == "" {
317                 return nil, errors.New("client-id is empty")
318         }
319         if tenantID == "" {
320                 return nil, errors.New("tenant-id is empty")
321         }
322         if apiserverID == "" {
323                 return nil, errors.New("apiserver-id is empty")
324         }
325         return &azureTokenSourceDeviceCode{
326                 environment: environment,
327                 clientID:    clientID,
328                 tenantID:    tenantID,
329                 apiserverID: apiserverID,
330         }, nil
331 }
332
333 func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
334         oauthConfig, err := adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
335         if err != nil {
336                 return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
337         }
338         client := &autorest.Client{}
339         deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
340         if err != nil {
341                 return nil, fmt.Errorf("initialing the device code authentication: %v", err)
342         }
343
344         _, err = fmt.Fprintln(os.Stderr, *deviceCode.Message)
345         if err != nil {
346                 return nil, fmt.Errorf("prompting the device code message: %v", err)
347         }
348
349         token, err := adal.WaitForUserCompletion(client, deviceCode)
350         if err != nil {
351                 return nil, fmt.Errorf("waiting for device code authentication to complete: %v", err)
352         }
353
354         return &azureToken{
355                 token:       *token,
356                 clientID:    ts.clientID,
357                 tenantID:    ts.tenantID,
358                 apiserverID: ts.apiserverID,
359         }, nil
360 }