github.com/vicanso/pike@v1.0.1-0.20210630235453-9099e041f6ec/location/location.go (about)

     1  // MIT License
     2  
     3  // Copyright (c) 2020 Tree Xie
     4  
     5  // Permission is hereby granted, free of charge, to any person obtaining a copy
     6  // of this software and associated documentation files (the "Software"), to deal
     7  // in the Software without restriction, including without limitation the rights
     8  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     9  // copies of the Software, and to permit persons to whom the Software is
    10  // furnished to do so, subject to the following conditions:
    11  
    12  // The above copyright notice and this permission notice shall be included in all
    13  // copies or substantial portions of the Software.
    14  
    15  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    16  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    17  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    18  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    19  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    20  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    21  // SOFTWARE.
    22  
// Location helpers: determine which location the current request belongs to, based on its host and URL.
    24  
    25  package location
    26  
    27  import (
    28  	"net/http"
    29  	"net/url"
    30  	"os"
    31  	"regexp"
    32  	"sort"
    33  	"strconv"
    34  	"strings"
    35  	"sync"
    36  	"time"
    37  
    38  	"github.com/vicanso/pike/config"
    39  	"github.com/vicanso/pike/log"
    40  	"go.uber.org/atomic"
    41  	"go.uber.org/zap"
    42  )
    43  
    44  // Location location config
    45  type (
    46  	Location struct {
    47  		Name     string
    48  		Upstream string
    49  		Prefixes []string
    50  		Rewrites []string
    51  		Hosts    []string
    52  		// Querystrings   []string
    53  		ProxyTimeout   time.Duration
    54  		ResponseHeader http.Header
    55  		RequestHeader  http.Header
    56  		Query          url.Values
    57  		URLRewriter    Rewriter
    58  		priority       atomic.Int32
    59  	}
    60  	rewriteRegexp struct {
    61  		Regexp *regexp.Regexp
    62  		Value  string
    63  	}
    64  )
    65  type Rewriter func(req *http.Request)
    66  
// Locations is an ordered list of locations that can be replaced and read
// concurrently: Set swaps the whole list under the write lock and
// GetLocations reads it under the read lock.
type Locations struct {
	// mutex guards locations. It is a pointer, so the zero Locations value
	// is not usable (Set would dereference a nil mutex); use NewLocations.
	mutex *sync.RWMutex
	// locations is kept sorted by ascending priority by Set.
	locations []*Location
}
    72  
// defaultLocations is the package-level location list used by Reset and Get.
var defaultLocations = NewLocations()
    74  
    75  func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
    76  	groups := pattern.FindAllStringSubmatch(input, -1)
    77  	if groups == nil {
    78  		return nil
    79  	}
    80  	values := groups[0][1:]
    81  	replace := make([]string, 2*len(values))
    82  	for i, v := range values {
    83  		j := 2 * i
    84  		replace[j] = "$" + strconv.Itoa(i+1)
    85  		replace[j+1] = v
    86  	}
    87  	return strings.NewReplacer(replace...)
    88  }
    89  
    90  // generateURLRewriter generate url rewriter
    91  func generateURLRewriter(arr []string) Rewriter {
    92  	size := len(arr)
    93  	if size == 0 {
    94  		return nil
    95  	}
    96  	rewrites := make([]*rewriteRegexp, 0, size)
    97  
    98  	for _, value := range arr {
    99  		arr := strings.Split(value, ":")
   100  		if len(arr) != 2 {
   101  			continue
   102  		}
   103  		k := arr[0]
   104  		v := arr[1]
   105  		k = strings.Replace(k, "*", "(\\S*)", -1)
   106  		reg, err := regexp.Compile(k)
   107  		if err != nil {
   108  			log.Default().Error("rewrite compile error",
   109  				zap.String("value", k),
   110  				zap.Error(err),
   111  			)
   112  			continue
   113  		}
   114  		rewrites = append(rewrites, &rewriteRegexp{
   115  			Regexp: reg,
   116  			Value:  v,
   117  		})
   118  	}
   119  	if len(rewrites) == 0 {
   120  		return nil
   121  	}
   122  	return func(req *http.Request) {
   123  		urlPath := req.URL.Path
   124  		for _, rewrite := range rewrites {
   125  			replacer := captureTokens(rewrite.Regexp, urlPath)
   126  			if replacer != nil {
   127  				urlPath = replacer.Replace(rewrite.Value)
   128  			}
   129  		}
   130  		req.URL.Path = urlPath
   131  	}
   132  }
   133  
   134  // Match check location's hosts and prefixes match host/url
   135  func (l *Location) Match(host, url string) bool {
   136  	if len(l.Hosts) != 0 {
   137  		found := false
   138  		for _, item := range l.Hosts {
   139  			if item == host {
   140  				found = true
   141  				break
   142  			}
   143  		}
   144  		if !found {
   145  			return false
   146  		}
   147  	}
   148  	if len(l.Prefixes) != 0 {
   149  		found := false
   150  		for _, item := range l.Prefixes {
   151  			if strings.HasPrefix(url, item) {
   152  				found = true
   153  				break
   154  			}
   155  		}
   156  		if !found {
   157  			return false
   158  		}
   159  	}
   160  	return true
   161  }
   162  
   163  func (l *Location) mergeHeader(dst, src http.Header) {
   164  	for key, values := range src {
   165  		for _, value := range values {
   166  			dst.Add(key, value)
   167  		}
   168  	}
   169  }
   170  
   171  // AddRequestHeader add request header
   172  func (l *Location) AddRequestHeader(header http.Header) {
   173  	l.mergeHeader(header, l.RequestHeader)
   174  }
   175  
   176  // AddResponseHeader add response header
   177  func (l *Location) AddResponseHeader(header http.Header) {
   178  	l.mergeHeader(header, l.ResponseHeader)
   179  }
   180  
   181  // ShouldModifyQuery should modify query
   182  func (l *Location) ShouldModifyQuery() bool {
   183  	return len(l.Query) != 0
   184  }
   185  
   186  // AddQuery add query to request
   187  func (l *Location) AddQuery(req *http.Request) {
   188  	query := req.URL.Query()
   189  	for key, values := range l.Query {
   190  		for _, value := range values {
   191  			query.Add(key, value)
   192  		}
   193  	}
   194  	req.URL.RawQuery = query.Encode()
   195  }
   196  
   197  func (l *Location) getPriority() int {
   198  	priority := l.priority.Load()
   199  	if priority != 0 {
   200  		return int(priority)
   201  	}
   202  	// 默认设置为8
   203  	priority = 8
   204  	if len(l.Prefixes) != 0 {
   205  		priority -= 4
   206  	}
   207  	if len(l.Hosts) != 0 {
   208  		priority -= 2
   209  	}
   210  	l.priority.Store(priority)
   211  	return int(priority)
   212  }
   213  
   214  // NewLocations new a location list
   215  func NewLocations(opts ...Location) *Locations {
   216  	ls := &Locations{
   217  		mutex: &sync.RWMutex{},
   218  	}
   219  	ls.Set(opts)
   220  	return ls
   221  }
   222  
   223  // Set set location list
   224  func (ls *Locations) Set(locations []Location) {
   225  	data := make([]*Location, len(locations))
   226  	for index := range locations {
   227  		// 需要注意,golang 的range 返回的item是复用同一块内存的,
   228  		// 需要对数据另外获取保存,因此使用index来获取元素
   229  		p := &locations[index]
   230  		p.URLRewriter = generateURLRewriter(p.Rewrites)
   231  		data[index] = p
   232  	}
   233  
   234  	// Sort sort locations
   235  	sort.Slice(data, func(i, j int) bool {
   236  		return data[i].getPriority() < data[j].getPriority()
   237  	})
   238  	ls.mutex.Lock()
   239  	defer ls.mutex.Unlock()
   240  	ls.locations = data
   241  }
   242  
   243  // GetLocations get locations
   244  func (ls *Locations) GetLocations() []*Location {
   245  	ls.mutex.RLock()
   246  	defer ls.mutex.RUnlock()
   247  	locations := ls.locations
   248  	return locations
   249  }
   250  
   251  // Get get match location
   252  func (ls *Locations) Get(host, url string, names ...string) *Location {
   253  	locations := ls.GetLocations()
   254  	for _, item := range locations {
   255  		for _, name := range names {
   256  			if item.Name == name && item.Match(host, url) {
   257  				return item
   258  			}
   259  		}
   260  	}
   261  	return nil
   262  }
   263  
   264  // enhanceGetValue 如果以$开头,则优先从env中获取,如果获取失败,则直接返回原值
   265  func enhanceGetValue(key string) string {
   266  	if strings.HasPrefix(key, "$") {
   267  		return os.Getenv(key[1:])
   268  	}
   269  	return key
   270  }
   271  
   272  func convertConfigs(configs []config.LocationConfig) []Location {
   273  	locations := make([]Location, 0)
   274  	fn := func(arr []string) http.Header {
   275  		h := make(http.Header)
   276  		for _, value := range arr {
   277  			arr := strings.Split(value, ":")
   278  			if len(arr) != 2 {
   279  				continue
   280  			}
   281  			h.Add(enhanceGetValue(arr[0]), enhanceGetValue(arr[1]))
   282  		}
   283  		return h
   284  	}
   285  	// 将配置转换为header与url.values
   286  	for _, item := range configs {
   287  		d, _ := time.ParseDuration(item.ProxyTimeout)
   288  		l := Location{
   289  			Name:         item.Name,
   290  			Upstream:     item.Upstream,
   291  			Prefixes:     item.Prefixes,
   292  			Rewrites:     item.Rewrites,
   293  			Hosts:        item.Hosts,
   294  			ProxyTimeout: d,
   295  		}
   296  		l.ResponseHeader = fn(item.RespHeaders)
   297  		l.RequestHeader = fn(item.ReqHeaders)
   298  		if len(item.QueryStrings) != 0 {
   299  			query := make(url.Values)
   300  			for _, str := range item.QueryStrings {
   301  				arr := strings.Split(str, ":")
   302  				if len(arr) != 2 {
   303  					continue
   304  				}
   305  				query.Add(enhanceGetValue(arr[0]), enhanceGetValue(arr[1]))
   306  			}
   307  			l.Query = query
   308  		}
   309  
   310  		locations = append(locations, l)
   311  	}
   312  	return locations
   313  }
   314  
   315  // Reset reset location list to default
   316  func Reset(configs []config.LocationConfig) {
   317  	defaultLocations.Set(convertConfigs(configs))
   318  }
   319  
   320  // Get get location form default locations
   321  func Get(host, url string, names ...string) *Location {
   322  	return defaultLocations.Get(host, url, names...)
   323  }