// Source: github.com/vicanso/pike@v1.0.1-0.20210630235453-9099e041f6ec/location/location.go

// MIT License

// Copyright (c) 2020 Tree Xie

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

22 23 // Location相关处理函数,根据host,url判断当前请求所属location 24 25 package location 26 27 import ( 28 "net/http" 29 "net/url" 30 "os" 31 "regexp" 32 "sort" 33 "strconv" 34 "strings" 35 "sync" 36 "time" 37 38 "github.com/vicanso/pike/config" 39 "github.com/vicanso/pike/log" 40 "go.uber.org/atomic" 41 "go.uber.org/zap" 42 ) 43 44 // Location location config 45 type ( 46 Location struct { 47 Name string 48 Upstream string 49 Prefixes []string 50 Rewrites []string 51 Hosts []string 52 // Querystrings []string 53 ProxyTimeout time.Duration 54 ResponseHeader http.Header 55 RequestHeader http.Header 56 Query url.Values 57 URLRewriter Rewriter 58 priority atomic.Int32 59 } 60 rewriteRegexp struct { 61 Regexp *regexp.Regexp 62 Value string 63 } 64 ) 65 type Rewriter func(req *http.Request) 66 67 // Locations location list 68 type Locations struct { 69 mutex *sync.RWMutex 70 locations []*Location 71 } 72 73 var defaultLocations = NewLocations() 74 75 func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { 76 groups := pattern.FindAllStringSubmatch(input, -1) 77 if groups == nil { 78 return nil 79 } 80 values := groups[0][1:] 81 replace := make([]string, 2*len(values)) 82 for i, v := range values { 83 j := 2 * i 84 replace[j] = "$" + strconv.Itoa(i+1) 85 replace[j+1] = v 86 } 87 return strings.NewReplacer(replace...) 
88 } 89 90 // generateURLRewriter generate url rewriter 91 func generateURLRewriter(arr []string) Rewriter { 92 size := len(arr) 93 if size == 0 { 94 return nil 95 } 96 rewrites := make([]*rewriteRegexp, 0, size) 97 98 for _, value := range arr { 99 arr := strings.Split(value, ":") 100 if len(arr) != 2 { 101 continue 102 } 103 k := arr[0] 104 v := arr[1] 105 k = strings.Replace(k, "*", "(\\S*)", -1) 106 reg, err := regexp.Compile(k) 107 if err != nil { 108 log.Default().Error("rewrite compile error", 109 zap.String("value", k), 110 zap.Error(err), 111 ) 112 continue 113 } 114 rewrites = append(rewrites, &rewriteRegexp{ 115 Regexp: reg, 116 Value: v, 117 }) 118 } 119 if len(rewrites) == 0 { 120 return nil 121 } 122 return func(req *http.Request) { 123 urlPath := req.URL.Path 124 for _, rewrite := range rewrites { 125 replacer := captureTokens(rewrite.Regexp, urlPath) 126 if replacer != nil { 127 urlPath = replacer.Replace(rewrite.Value) 128 } 129 } 130 req.URL.Path = urlPath 131 } 132 } 133 134 // Match check location's hosts and prefixes match host/url 135 func (l *Location) Match(host, url string) bool { 136 if len(l.Hosts) != 0 { 137 found := false 138 for _, item := range l.Hosts { 139 if item == host { 140 found = true 141 break 142 } 143 } 144 if !found { 145 return false 146 } 147 } 148 if len(l.Prefixes) != 0 { 149 found := false 150 for _, item := range l.Prefixes { 151 if strings.HasPrefix(url, item) { 152 found = true 153 break 154 } 155 } 156 if !found { 157 return false 158 } 159 } 160 return true 161 } 162 163 func (l *Location) mergeHeader(dst, src http.Header) { 164 for key, values := range src { 165 for _, value := range values { 166 dst.Add(key, value) 167 } 168 } 169 } 170 171 // AddRequestHeader add request header 172 func (l *Location) AddRequestHeader(header http.Header) { 173 l.mergeHeader(header, l.RequestHeader) 174 } 175 176 // AddResponseHeader add response header 177 func (l *Location) AddResponseHeader(header http.Header) { 178 
l.mergeHeader(header, l.ResponseHeader) 179 } 180 181 // ShouldModifyQuery should modify query 182 func (l *Location) ShouldModifyQuery() bool { 183 return len(l.Query) != 0 184 } 185 186 // AddQuery add query to request 187 func (l *Location) AddQuery(req *http.Request) { 188 query := req.URL.Query() 189 for key, values := range l.Query { 190 for _, value := range values { 191 query.Add(key, value) 192 } 193 } 194 req.URL.RawQuery = query.Encode() 195 } 196 197 func (l *Location) getPriority() int { 198 priority := l.priority.Load() 199 if priority != 0 { 200 return int(priority) 201 } 202 // 默认设置为8 203 priority = 8 204 if len(l.Prefixes) != 0 { 205 priority -= 4 206 } 207 if len(l.Hosts) != 0 { 208 priority -= 2 209 } 210 l.priority.Store(priority) 211 return int(priority) 212 } 213 214 // NewLocations new a location list 215 func NewLocations(opts ...Location) *Locations { 216 ls := &Locations{ 217 mutex: &sync.RWMutex{}, 218 } 219 ls.Set(opts) 220 return ls 221 } 222 223 // Set set location list 224 func (ls *Locations) Set(locations []Location) { 225 data := make([]*Location, len(locations)) 226 for index := range locations { 227 // 需要注意,golang 的range 返回的item是复用同一块内存的, 228 // 需要对数据另外获取保存,因此使用index来获取元素 229 p := &locations[index] 230 p.URLRewriter = generateURLRewriter(p.Rewrites) 231 data[index] = p 232 } 233 234 // Sort sort locations 235 sort.Slice(data, func(i, j int) bool { 236 return data[i].getPriority() < data[j].getPriority() 237 }) 238 ls.mutex.Lock() 239 defer ls.mutex.Unlock() 240 ls.locations = data 241 } 242 243 // GetLocations get locations 244 func (ls *Locations) GetLocations() []*Location { 245 ls.mutex.RLock() 246 defer ls.mutex.RUnlock() 247 locations := ls.locations 248 return locations 249 } 250 251 // Get get match location 252 func (ls *Locations) Get(host, url string, names ...string) *Location { 253 locations := ls.GetLocations() 254 for _, item := range locations { 255 for _, name := range names { 256 if item.Name == name && 
item.Match(host, url) { 257 return item 258 } 259 } 260 } 261 return nil 262 } 263 264 // enhanceGetValue 如果以$开头,则优先从env中获取,如果获取失败,则直接返回原值 265 func enhanceGetValue(key string) string { 266 if strings.HasPrefix(key, "$") { 267 return os.Getenv(key[1:]) 268 } 269 return key 270 } 271 272 func convertConfigs(configs []config.LocationConfig) []Location { 273 locations := make([]Location, 0) 274 fn := func(arr []string) http.Header { 275 h := make(http.Header) 276 for _, value := range arr { 277 arr := strings.Split(value, ":") 278 if len(arr) != 2 { 279 continue 280 } 281 h.Add(enhanceGetValue(arr[0]), enhanceGetValue(arr[1])) 282 } 283 return h 284 } 285 // 将配置转换为header与url.values 286 for _, item := range configs { 287 d, _ := time.ParseDuration(item.ProxyTimeout) 288 l := Location{ 289 Name: item.Name, 290 Upstream: item.Upstream, 291 Prefixes: item.Prefixes, 292 Rewrites: item.Rewrites, 293 Hosts: item.Hosts, 294 ProxyTimeout: d, 295 } 296 l.ResponseHeader = fn(item.RespHeaders) 297 l.RequestHeader = fn(item.ReqHeaders) 298 if len(item.QueryStrings) != 0 { 299 query := make(url.Values) 300 for _, str := range item.QueryStrings { 301 arr := strings.Split(str, ":") 302 if len(arr) != 2 { 303 continue 304 } 305 query.Add(enhanceGetValue(arr[0]), enhanceGetValue(arr[1])) 306 } 307 l.Query = query 308 } 309 310 locations = append(locations, l) 311 } 312 return locations 313 } 314 315 // Reset reset location list to default 316 func Reset(configs []config.LocationConfig) { 317 defaultLocations.Set(convertConfigs(configs)) 318 } 319 320 // Get get location form default locations 321 func Get(host, url string, names ...string) *Location { 322 return defaultLocations.Get(host, url, names...) 323 }