github.com/m-lab/locate@v0.17.6/handler/handler.go (about)

     1  // Package handler provides a client and handlers for responding to locate
     2  // requests.
     3  package handler
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"html/template"
    12  	"math/rand"
    13  	"net/http"
    14  	"net/url"
    15  	"path"
    16  	"strconv"
    17  	"strings"
    18  	"time"
    19  
    20  	"github.com/google/uuid"
    21  	log "github.com/sirupsen/logrus"
    22  	"gopkg.in/square/go-jose.v2/jwt"
    23  
    24  	"github.com/m-lab/go/rtx"
    25  	v2 "github.com/m-lab/locate/api/v2"
    26  	"github.com/m-lab/locate/clientgeo"
    27  	"github.com/m-lab/locate/heartbeat"
    28  	"github.com/m-lab/locate/limits"
    29  	"github.com/m-lab/locate/metrics"
    30  	"github.com/m-lab/locate/siteinfo"
    31  	"github.com/m-lab/locate/static"
    32  	prom "github.com/prometheus/client_golang/api/prometheus/v1"
    33  	"github.com/prometheus/common/model"
    34  )
    35  
var (
	// errFailedToLookupClient is returned by checkClientLocation when the
	// client location lookup fails. Its message is also sent to clients when
	// the located latitude or longitude cannot be parsed.
	errFailedToLookupClient = errors.New("Failed to look up client location")
	// tooManyRequests is the message sent to clients whose request is rejected
	// by the user agent limits or by the IP rate limiter.
	tooManyRequests = "Too many periodic requests. Please contact support@measurementlab.net."
)
    40  
// Signer defines how access tokens are signed.
type Signer interface {
	// Sign signs the given claims and returns the resulting token string.
	Sign(cl jwt.Claims) (string, error)
}
    45  
// Limiter defines how requests are rate limited by client IP and user agent.
type Limiter interface {
	// IsLimited reports the limit status for the given IP and user agent. The
	// Nearest handler treats a non-nil error as "not limited" (fail open).
	IsLimited(ip, ua string) (limits.LimitStatus, error)
}
    49  
// Client holds the dependencies and configuration shared by the HTTP handlers
// in this package.
type Client struct {
	// Signer signs the access tokens embedded in target URLs.
	Signer
	// project is the project name given to the constructors.
	project string
	// LocatorV2 finds the machines nearest to a client and exposes the
	// heartbeat status tracker.
	LocatorV2
	// ClientLocator looks up the geo location of the requesting client.
	ClientLocator
	// PrometheusClient issues queries to Prometheus.
	PrometheusClient
	// targetTmpl renders the host portion of every target URL.
	targetTmpl *template.Template
	// agentLimits holds request limits keyed by User-Agent.
	agentLimits limits.Agents
	// ipLimiter rate limits requests by IP and User-Agent. When nil, the
	// Nearest handler skips this check.
	ipLimiter Limiter
	// earlyExitClients is the set of client_name values whose target URLs get
	// the early exit parameter.
	earlyExitClients map[string]bool
}
    62  
// LocatorV2 defines how the Nearest handler requests machines nearest to the
// client.
type LocatorV2 interface {
	// Nearest returns target information for the given service near the given
	// latitude and longitude, subject to the given options.
	Nearest(service string, lat, lon float64, opts *heartbeat.NearestOptions) (*heartbeat.TargetInfo, error)
	heartbeat.StatusTracker
}
    69  
// ClientLocator defines the interface for looking up the client geo location.
type ClientLocator interface {
	// Locate returns the location of the client that issued the request.
	Locate(req *http.Request) (*clientgeo.Location, error)
}
    74  
// PrometheusClient defines the interface to query Prometheus.
type PrometheusClient interface {
	// Query runs the given query, evaluated at time ts.
	Query(ctx context.Context, query string, ts time.Time, opts ...prom.Option) (model.Value, prom.Warnings, error)
}
    79  
// paramOpts groups the inputs extraParams needs to build the extra query
// parameters added to target URLs.
type paramOpts struct {
	// raw holds the form values of the client request.
	raw url.Values
	// version is reported to targets as the locate_version parameter.
	version string
	// ranks maps a machine name to its metro rank.
	ranks map[string]int
	// svcParams maps a request parameter name to the probability threshold
	// (compared against rand.Float64) with which it is forwarded to targets.
	svcParams map[string]float64
}
    86  
    87  func init() {
    88  	log.SetFormatter(&log.JSONFormatter{})
    89  	log.SetLevel(log.InfoLevel)
    90  }
    91  
    92  // NewClient creates a new client.
    93  func NewClient(project string, private Signer, locatorV2 LocatorV2, client ClientLocator,
    94  	prom PrometheusClient, lmts limits.Agents, limiter Limiter, earlyExitClients []string) *Client {
    95  	// Convert slice to map for O(1) lookups
    96  	earlyExitMap := make(map[string]bool)
    97  	for _, client := range earlyExitClients {
    98  		earlyExitMap[client] = true
    99  	}
   100  	return &Client{
   101  		Signer:           private,
   102  		project:          project,
   103  		LocatorV2:        locatorV2,
   104  		ClientLocator:    client,
   105  		PrometheusClient: prom,
   106  		targetTmpl:       template.Must(template.New("name").Parse("{{.Hostname}}{{.Ports}}")),
   107  		agentLimits:      lmts,
   108  		ipLimiter:        limiter,
   109  		earlyExitClients: earlyExitMap,
   110  	}
   111  }
   112  
   113  // NewClientDirect creates a new client with a target template using only the target machine.
   114  func NewClientDirect(project string, private Signer, locatorV2 LocatorV2, client ClientLocator, prom PrometheusClient) *Client {
   115  	return &Client{
   116  		Signer:           private,
   117  		project:          project,
   118  		LocatorV2:        locatorV2,
   119  		ClientLocator:    client,
   120  		PrometheusClient: prom,
   121  		// Useful for the locatetest package when running a local server.
   122  		targetTmpl: template.Must(template.New("name").Parse("{{.Hostname}}{{.Ports}}")),
   123  	}
   124  }
   125  
   126  func (c *Client) extraParams(hostname string, index int, p paramOpts) url.Values {
   127  	v := url.Values{}
   128  
   129  	// Add client parameters.
   130  	for key := range p.raw {
   131  		if strings.HasPrefix(key, "client_") {
   132  			// note: we only use the first value.
   133  			v.Set(key, p.raw.Get(key))
   134  		}
   135  
   136  		val, ok := p.svcParams[key]
   137  		if ok && rand.Float64() < val {
   138  			v.Set(key, p.raw.Get(key))
   139  		}
   140  	}
   141  
   142  	// Add early_exit parameter for specified clients
   143  	clientName := p.raw.Get("client_name")
   144  	if clientName != "" && c.earlyExitClients[clientName] {
   145  		v.Set(static.EarlyExitParameter, static.EarlyExitDefaultValue)
   146  	}
   147  
   148  	// Add Locate Service version.
   149  	v.Set("locate_version", p.version)
   150  
   151  	// Add metro rank.
   152  	rank, ok := p.ranks[hostname]
   153  	if ok {
   154  		v.Set("metro_rank", strconv.Itoa(rank))
   155  	}
   156  
   157  	// Add result index.
   158  	v.Set("index", strconv.Itoa(index))
   159  
   160  	return v
   161  }
   162  
   163  // Nearest uses an implementation of the LocatorV2 interface to look up
   164  // nearest servers.
   165  func (c *Client) Nearest(rw http.ResponseWriter, req *http.Request) {
   166  	req.ParseForm()
   167  	result := v2.NearestResult{}
   168  	setHeaders(rw)
   169  
   170  	if c.limitRequest(time.Now().UTC(), req) {
   171  		result.Error = v2.NewError("client", tooManyRequests, http.StatusTooManyRequests)
   172  		writeResult(rw, result.Error.Status, &result)
   173  		metrics.RequestsTotal.WithLabelValues("nearest", "request limit", http.StatusText(result.Error.Status)).Inc()
   174  		return
   175  	}
   176  
   177  	// Check rate limit for IP and UA.
   178  	if c.ipLimiter != nil {
   179  		// Get the IP address from the request. X-Forwarded-For is guaranteed to
   180  		// be set by AppEngine.
   181  		ip := req.Header.Get("X-Forwarded-For")
   182  		ips := strings.Split(ip, ",")
   183  		if len(ips) > 0 {
   184  			ip = strings.TrimSpace(ips[0])
   185  		}
   186  		if ip != "" {
   187  			// An empty UA is technically possible.
   188  			ua := req.Header.Get("User-Agent")
   189  			status, err := c.ipLimiter.IsLimited(ip, ua)
   190  			if err != nil {
   191  				// Log error but don't block request (fail open).
   192  				// TODO: Add tests for this path.
   193  				log.Printf("Rate limiter error: %v", err)
   194  			} else if status.IsLimited {
   195  				// Log IP and UA and block the request.
   196  				result.Error = v2.NewError("client", tooManyRequests, http.StatusTooManyRequests)
   197  				metrics.RequestsTotal.WithLabelValues("nearest", "rate limit",
   198  					http.StatusText(result.Error.Status)).Inc()
   199  				// If the client provided a client_name, we want to know how many times
   200  				// that client_name was rate limited. This may be empty, which is fine.
   201  				clientName := req.Form.Get("client_name")
   202  				metrics.RateLimitedTotal.WithLabelValues(clientName, status.LimitType).Inc()
   203  
   204  				log.Printf("Rate limit (%s) exceeded for IP: %s, client: %s, UA: %s", ip,
   205  					status.LimitType, clientName, ua)
   206  				writeResult(rw, result.Error.Status, &result)
   207  				return
   208  			}
   209  		} else {
   210  			// This should never happen if Locate is deployed on AppEngine.
   211  			log.Println("Cannot find IP address for rate limiting.")
   212  		}
   213  	}
   214  
   215  	experiment, service := getExperimentAndService(req.URL.Path)
   216  
   217  	// Look up client location.
   218  	loc, err := c.checkClientLocation(rw, req)
   219  	if err != nil {
   220  		status := http.StatusServiceUnavailable
   221  		result.Error = v2.NewError("nearest", "Failed to lookup nearest machines", status)
   222  		writeResult(rw, result.Error.Status, &result)
   223  		metrics.RequestsTotal.WithLabelValues("nearest", "client location",
   224  			http.StatusText(result.Error.Status)).Inc()
   225  		return
   226  	}
   227  
   228  	// Parse client location.
   229  	lat, errLat := strconv.ParseFloat(loc.Latitude, 64)
   230  	lon, errLon := strconv.ParseFloat(loc.Longitude, 64)
   231  	if errLat != nil || errLon != nil {
   232  		result.Error = v2.NewError("client", errFailedToLookupClient.Error(), http.StatusInternalServerError)
   233  		writeResult(rw, result.Error.Status, &result)
   234  		metrics.RequestsTotal.WithLabelValues("nearest", "parse client location",
   235  			http.StatusText(result.Error.Status)).Inc()
   236  		return
   237  	}
   238  
   239  	// Find the nearest targets using the client parameters.
   240  	q := req.URL.Query()
   241  	t := q.Get("machine-type")
   242  	country := req.Header.Get("X-AppEngine-Country")
   243  	sites := q["site"]
   244  	org := q.Get("org")
   245  	strict := false
   246  	if qsStrict, err := strconv.ParseBool(q.Get("strict")); err == nil {
   247  		strict = qsStrict
   248  	}
   249  	// If strict, override the country from the AppEngine header with the one in
   250  	// the querystring.
   251  	if strict {
   252  		country = q.Get("country")
   253  	}
   254  	opts := &heartbeat.NearestOptions{Type: t, Country: country, Sites: sites, Org: org, Strict: strict}
   255  	targetInfo, err := c.LocatorV2.Nearest(service, lat, lon, opts)
   256  	if err != nil {
   257  		result.Error = v2.NewError("nearest", "Failed to lookup nearest machines", http.StatusInternalServerError)
   258  		writeResult(rw, result.Error.Status, &result)
   259  		metrics.RequestsTotal.WithLabelValues("nearest", "server location",
   260  			http.StatusText(result.Error.Status)).Inc()
   261  		return
   262  	}
   263  
   264  	pOpts := paramOpts{
   265  		raw:       req.Form,
   266  		version:   "v2",
   267  		ranks:     targetInfo.Ranks,
   268  		svcParams: static.ServiceParams,
   269  	}
   270  	// Populate target URLs and write out response.
   271  	c.populateURLs(targetInfo.Targets, targetInfo.URLs, experiment, pOpts)
   272  	result.Results = targetInfo.Targets
   273  	writeResult(rw, http.StatusOK, &result)
   274  	metrics.RequestsTotal.WithLabelValues("nearest", "success", http.StatusText(http.StatusOK)).Inc()
   275  }
   276  
   277  // Live is a minimal handler to indicate that the server is operating at all.
   278  func (c *Client) Live(rw http.ResponseWriter, req *http.Request) {
   279  	fmt.Fprintf(rw, "ok")
   280  }
   281  
   282  // Ready reports whether the server is working as expected and ready to serve requests.
   283  func (c *Client) Ready(rw http.ResponseWriter, req *http.Request) {
   284  	if c.LocatorV2.Ready() {
   285  		fmt.Fprintf(rw, "ok")
   286  	} else {
   287  		rw.WriteHeader(http.StatusInternalServerError)
   288  		fmt.Fprintf(rw, "not ready")
   289  	}
   290  }
   291  
   292  // Registrations returns information about registered machines. There are 3
   293  // supported query parameters:
   294  //
   295  // * format - defines the format of the returned JSON
   296  // * org - limits results to only records for the given organization
   297  // * exp - limits results to only records for the given experiment (e.g., ndt)
   298  func (c *Client) Registrations(rw http.ResponseWriter, req *http.Request) {
   299  	var err error
   300  	var result interface{}
   301  
   302  	q := req.URL.Query()
   303  	format := q.Get("format")
   304  
   305  	switch format {
   306  	default:
   307  		result, err = siteinfo.Machines(c.LocatorV2.Instances(), q)
   308  	}
   309  
   310  	if err != nil {
   311  		v2Error := v2.NewError("siteinfo", err.Error(), http.StatusInternalServerError)
   312  		writeResult(rw, http.StatusInternalServerError, v2Error)
   313  		return
   314  	}
   315  
   316  	writeResult(rw, http.StatusOK, result)
   317  }
   318  
   319  // checkClientLocation looks up the client location and copies the location
   320  // headers to the response writer.
   321  func (c *Client) checkClientLocation(rw http.ResponseWriter, req *http.Request) (*clientgeo.Location, error) {
   322  	// Lookup the client location using the client request.
   323  	loc, err := c.Locate(req)
   324  	if err != nil {
   325  		return nil, errFailedToLookupClient
   326  	}
   327  
   328  	// Copy location headers to response writer.
   329  	for key := range loc.Headers {
   330  		rw.Header().Set(key, loc.Headers.Get(key))
   331  	}
   332  
   333  	return loc, nil
   334  }
   335  
   336  // populateURLs populates each set of URLs using the target configuration.
   337  func (c *Client) populateURLs(targets []v2.Target, ports static.Ports, exp string, pOpts paramOpts) {
   338  	for i, target := range targets {
   339  		token := c.getAccessToken(target.Machine, exp)
   340  		params := c.extraParams(target.Machine, i, pOpts)
   341  		targets[i].URLs = c.getURLs(ports, target.Hostname, token, params)
   342  	}
   343  }
   344  
   345  // getAccessToken allocates a new access token using the given machine name as
   346  // the intended audience and the subject as the target service.
   347  func (c *Client) getAccessToken(machine, subject string) string {
   348  	// Create the token. The same access token is reused for every URL of a
   349  	// target port.
   350  	// A uuid is added to the claims so that each new token is unique.
   351  	cl := jwt.Claims{
   352  		Issuer:   static.IssuerLocate,
   353  		Subject:  subject,
   354  		Audience: jwt.Audience{machine},
   355  		Expiry:   jwt.NewNumericDate(time.Now().Add(time.Minute)),
   356  		ID:       uuid.NewString(),
   357  	}
   358  	token, err := c.Sign(cl)
   359  	// Sign errors can only happen due to a misconfiguration of the key.
   360  	// A good config will remain good.
   361  	rtx.PanicOnError(err, "signing claims has failed")
   362  	return token
   363  }
   364  
   365  // getURLs creates URLs for the named experiment, running on the named machine
   366  // for each given port. Every URL will include an `access_token=` parameter,
   367  // authorizing the measurement.
   368  func (c *Client) getURLs(ports static.Ports, hostname, token string, extra url.Values) map[string]string {
   369  	urls := map[string]string{}
   370  	// For each port config, prepare the target url with access_token and
   371  	// complete host field.
   372  	for _, target := range ports {
   373  		name := target.String()
   374  		params := url.Values{}
   375  		params.Set("access_token", token)
   376  		for key := range extra {
   377  			// note: we only use the first value.
   378  			params.Set(key, extra.Get(key))
   379  		}
   380  		target.RawQuery = params.Encode()
   381  
   382  		host := &bytes.Buffer{}
   383  		err := c.targetTmpl.Execute(host, map[string]string{
   384  			"Hostname": hostname,
   385  			"Ports":    target.Host, // from URL template, so typically just the ":port".
   386  		})
   387  		rtx.PanicOnError(err, "bad template evaluation")
   388  		target.Host = host.String()
   389  		urls[name] = target.String()
   390  	}
   391  	return urls
   392  }
   393  
   394  // limitRequest determines whether a client request should be rate-limited.
   395  func (c *Client) limitRequest(now time.Time, req *http.Request) bool {
   396  	agent := req.Header.Get("User-Agent")
   397  	l, ok := c.agentLimits[agent]
   398  	if !ok {
   399  		// No limit defined for user agent.
   400  		return false
   401  	}
   402  	return l.IsLimited(now)
   403  }
   404  
   405  // setHeaders sets the response headers for "nearest" requests.
   406  func setHeaders(rw http.ResponseWriter) {
   407  	// Set CORS policy to allow third-party websites to use returned resources.
   408  	rw.Header().Set("Content-Type", "application/json")
   409  	rw.Header().Set("Access-Control-Allow-Origin", "*")
   410  	// Prevent caching of result.
   411  	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
   412  	rw.Header().Set("Cache-Control", "no-store")
   413  }
   414  
   415  // writeResult marshals the result and writes the result to the response writer.
   416  func writeResult(rw http.ResponseWriter, status int, result interface{}) {
   417  	b, err := json.MarshalIndent(result, "", "  ")
   418  	// Errors are only possible when marshalling incompatible types, like functions.
   419  	rtx.PanicOnError(err, "Failed to format result")
   420  	rw.WriteHeader(status)
   421  	rw.Write(b)
   422  }
   423  
   424  // getExperimentAndService takes an http request path and extracts the last two
   425  // fields. For correct requests (e.g. "/v2/nearest/ndt/ndt5"), this will be the
   426  // experiment name (e.g. "ndt") and the datatype (e.g. "ndt5").
   427  func getExperimentAndService(p string) (string, string) {
   428  	datatype := path.Base(p)
   429  	experiment := path.Base(path.Dir(p))
   430  	return experiment, experiment + "/" + datatype
   431  }