Code refactoring for bpa operator
[icn.git] / cmd / bpa-operator / vendor / github.com / emicklei / go-restful / cors_filter.go
1 package restful
2
3 // Copyright 2013 Ernest Micklei. All rights reserved.
4 // Use of this source code is governed by a license
5 // that can be found in the LICENSE file.
6
7 import (
8         "regexp"
9         "strconv"
10         "strings"
11 )
12
// CrossOriginResourceSharing is used to create a Container Filter that implements CORS.
// Cross-origin resource sharing (CORS) is a mechanism that allows JavaScript on a web page
// to make XMLHttpRequests to another domain, not the domain the JavaScript originated from.
//
// Use its Filter method as a container filter; all configuration is read from the
// exported fields below.
//
// http://en.wikipedia.org/wiki/Cross-origin_resource_sharing
// http://enable-cors.org/server.html
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
type CrossOriginResourceSharing struct {
	// ExposeHeaders is a list of response header names. When non-empty, the
	// names are joined with "," and sent as the expose-headers CORS header.
	ExposeHeaders []string
	// AllowedHeaders is a list of request header names that a preflight
	// request may ask for. Names are compared case-insensitively.
	AllowedHeaders []string
	// AllowedDomains lists allowed values for the HTTP Origin header. If empty,
	// every (non-empty) origin is allowed. An origin is allowed when it equals
	// an entry exactly or, failing that, matches an entry interpreted as a
	// regular expression (e.g. to support subdomain matching). If any entry is
	// not a valid regular expression, only exact matches are accepted.
	AllowedDomains []string
	// AllowedMethods lists the HTTP methods accepted in a preflight request's
	// Access-Control-Request-Method header. If empty, the methods are computed
	// for each preflight request from Container (or DefaultContainer when
	// Container is nil).
	AllowedMethods []string
	// MaxAge is the number of seconds before requiring a new OPTIONS request.
	// The max-age header is sent only when MaxAge is greater than zero.
	MaxAge int
	// CookiesAllowed, when true, adds the allow-credentials header with the
	// value "true" to CORS responses.
	CookiesAllowed bool
	// Container is used to compute the allowed methods when AllowedMethods is
	// empty; nil means DefaultContainer is used instead.
	Container *Container

	// allowedOriginPatterns is an internal field for the origin regexp check;
	// it is not meant to be set by users.
	allowedOriginPatterns []*regexp.Regexp
}
31
32 // Filter is a filter function that implements the CORS flow as documented on http://enable-cors.org/server.html
33 // and http://www.html5rocks.com/static/images/cors_server_flowchart.png
34 func (c CrossOriginResourceSharing) Filter(req *Request, resp *Response, chain *FilterChain) {
35         origin := req.Request.Header.Get(HEADER_Origin)
36         if len(origin) == 0 {
37                 if trace {
38                         traceLogger.Print("no Http header Origin set")
39                 }
40                 chain.ProcessFilter(req, resp)
41                 return
42         }
43         if !c.isOriginAllowed(origin) { // check whether this origin is allowed
44                 if trace {
45                         traceLogger.Printf("HTTP Origin:%s is not part of %v, neither matches any part of %v", origin, c.AllowedDomains, c.allowedOriginPatterns)
46                 }
47                 chain.ProcessFilter(req, resp)
48                 return
49         }
50         if req.Request.Method != "OPTIONS" {
51                 c.doActualRequest(req, resp)
52                 chain.ProcessFilter(req, resp)
53                 return
54         }
55         if acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod); acrm != "" {
56                 c.doPreflightRequest(req, resp)
57         } else {
58                 c.doActualRequest(req, resp)
59                 chain.ProcessFilter(req, resp)
60                 return
61         }
62 }
63
64 func (c CrossOriginResourceSharing) doActualRequest(req *Request, resp *Response) {
65         c.setOptionsHeaders(req, resp)
66         // continue processing the response
67 }
68
69 func (c *CrossOriginResourceSharing) doPreflightRequest(req *Request, resp *Response) {
70         if len(c.AllowedMethods) == 0 {
71                 if c.Container == nil {
72                         c.AllowedMethods = DefaultContainer.computeAllowedMethods(req)
73                 } else {
74                         c.AllowedMethods = c.Container.computeAllowedMethods(req)
75                 }
76         }
77
78         acrm := req.Request.Header.Get(HEADER_AccessControlRequestMethod)
79         if !c.isValidAccessControlRequestMethod(acrm, c.AllowedMethods) {
80                 if trace {
81                         traceLogger.Printf("Http header %s:%s is not in %v",
82                                 HEADER_AccessControlRequestMethod,
83                                 acrm,
84                                 c.AllowedMethods)
85                 }
86                 return
87         }
88         acrhs := req.Request.Header.Get(HEADER_AccessControlRequestHeaders)
89         if len(acrhs) > 0 {
90                 for _, each := range strings.Split(acrhs, ",") {
91                         if !c.isValidAccessControlRequestHeader(strings.Trim(each, " ")) {
92                                 if trace {
93                                         traceLogger.Printf("Http header %s:%s is not in %v",
94                                                 HEADER_AccessControlRequestHeaders,
95                                                 acrhs,
96                                                 c.AllowedHeaders)
97                                 }
98                                 return
99                         }
100                 }
101         }
102         resp.AddHeader(HEADER_AccessControlAllowMethods, strings.Join(c.AllowedMethods, ","))
103         resp.AddHeader(HEADER_AccessControlAllowHeaders, acrhs)
104         c.setOptionsHeaders(req, resp)
105
106         // return http 200 response, no body
107 }
108
109 func (c CrossOriginResourceSharing) setOptionsHeaders(req *Request, resp *Response) {
110         c.checkAndSetExposeHeaders(resp)
111         c.setAllowOriginHeader(req, resp)
112         c.checkAndSetAllowCredentials(resp)
113         if c.MaxAge > 0 {
114                 resp.AddHeader(HEADER_AccessControlMaxAge, strconv.Itoa(c.MaxAge))
115         }
116 }
117
118 func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
119         if len(origin) == 0 {
120                 return false
121         }
122         if len(c.AllowedDomains) == 0 {
123                 return true
124         }
125
126         allowed := false
127         for _, domain := range c.AllowedDomains {
128                 if domain == origin {
129                         allowed = true
130                         break
131                 }
132         }
133
134         if !allowed {
135                 if len(c.allowedOriginPatterns) == 0 {
136                         // compile allowed domains to allowed origin patterns
137                         allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
138                         if err != nil {
139                                 return false
140                         }
141                         c.allowedOriginPatterns = allowedOriginRegexps
142                 }
143
144                 for _, pattern := range c.allowedOriginPatterns {
145                         if allowed = pattern.MatchString(origin); allowed {
146                                 break
147                         }
148                 }
149         }
150
151         return allowed
152 }
153
154 func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
155         origin := req.Request.Header.Get(HEADER_Origin)
156         if c.isOriginAllowed(origin) {
157                 resp.AddHeader(HEADER_AccessControlAllowOrigin, origin)
158         }
159 }
160
161 func (c CrossOriginResourceSharing) checkAndSetExposeHeaders(resp *Response) {
162         if len(c.ExposeHeaders) > 0 {
163                 resp.AddHeader(HEADER_AccessControlExposeHeaders, strings.Join(c.ExposeHeaders, ","))
164         }
165 }
166
167 func (c CrossOriginResourceSharing) checkAndSetAllowCredentials(resp *Response) {
168         if c.CookiesAllowed {
169                 resp.AddHeader(HEADER_AccessControlAllowCredentials, "true")
170         }
171 }
172
173 func (c CrossOriginResourceSharing) isValidAccessControlRequestMethod(method string, allowedMethods []string) bool {
174         for _, each := range allowedMethods {
175                 if each == method {
176                         return true
177                 }
178         }
179         return false
180 }
181
182 func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header string) bool {
183         for _, each := range c.AllowedHeaders {
184                 if strings.ToLower(each) == strings.ToLower(header) {
185                         return true
186                 }
187         }
188         return false
189 }
190
// compileRegexps compiles every string in regexpStrings, preserving order.
// It stops at the first invalid expression and returns the expressions compiled
// so far together with that error. On success the returned slice is never nil,
// even when regexpStrings is empty.
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
	compiled := make([]*regexp.Regexp, 0, len(regexpStrings))
	for i := 0; i < len(regexpStrings); i++ {
		re, err := regexp.Compile(regexpStrings[i])
		if err != nil {
			return compiled, err
		}
		compiled = append(compiled, re)
	}
	return compiled, nil
}