github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/serviceregistration/checks/client.go (about)

     1  package checks
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/hashicorp/go-cleanhttp"
    16  	"github.com/hashicorp/go-hclog"
    17  	"github.com/hashicorp/nomad/client/serviceregistration"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  	"oss.indeed.com/go/libtime"
    20  )
    21  
    22  const (
    23  	// maxTimeoutHTTP is a fail-safe value for the HTTP client, ensuring a Nomad
    24  	// Client does not leak goroutines hanging on to unresponsive endpoints.
    25  	maxTimeoutHTTP = 10 * time.Minute
    26  )
    27  
    28  // Checker executes a check given an allocation-specific context, and produces
    29  // a resulting structs.CheckQueryResult
    30  type Checker interface {
    31  	Do(context.Context, *QueryContext, *Query) *structs.CheckQueryResult
    32  }
    33  
    34  // New creates a new Checker capable of executing HTTP and TCP checks.
    35  func New(log hclog.Logger) Checker {
    36  	httpClient := cleanhttp.DefaultPooledClient()
    37  	httpClient.Timeout = maxTimeoutHTTP
    38  	return &checker{
    39  		log:        log.Named("checks"),
    40  		httpClient: httpClient,
    41  		clock:      libtime.SystemClock(),
    42  	}
    43  }
    44  
    45  type checker struct {
    46  	log        hclog.Logger
    47  	clock      libtime.Clock
    48  	httpClient *http.Client
    49  }
    50  
    51  func (c *checker) now() int64 {
    52  	return c.clock.Now().UTC().Unix()
    53  }
    54  
    55  // Do will execute the Query given the QueryContext and produce a structs.CheckQueryResult
    56  func (c *checker) Do(ctx context.Context, qc *QueryContext, q *Query) *structs.CheckQueryResult {
    57  	var qr *structs.CheckQueryResult
    58  
    59  	timeout, cancel := context.WithTimeout(ctx, q.Timeout)
    60  	defer cancel()
    61  
    62  	switch q.Type {
    63  	case "http":
    64  		qr = c.checkHTTP(timeout, qc, q)
    65  	default:
    66  		qr = c.checkTCP(timeout, qc, q)
    67  	}
    68  
    69  	qr.ID = qc.ID
    70  	qr.Group = qc.Group
    71  	qr.Task = qc.Task
    72  	qr.Service = qc.Service
    73  	qr.Check = qc.Check
    74  	return qr
    75  }
    76  
    77  // resolve the address to use when executing Query given a QueryContext
    78  func address(qc *QueryContext, q *Query) (string, error) {
    79  	mode := q.AddressMode
    80  	if mode == "" { // determine resolution for check address
    81  		if qc.CustomAddress != "" {
    82  			// if the service is using a custom address, enable the check to
    83  			// inherit that custom address
    84  			mode = structs.AddressModeAuto
    85  		} else {
    86  			// otherwise a check defaults to the host address
    87  			mode = structs.AddressModeHost
    88  		}
    89  	}
    90  
    91  	label := q.PortLabel
    92  	if label == "" {
    93  		label = qc.ServicePortLabel
    94  	}
    95  
    96  	status := qc.NetworkStatus.NetworkStatus()
    97  	addr, port, err := serviceregistration.GetAddress(
    98  		qc.CustomAddress, // custom address
    99  		mode,             // check address mode
   100  		label,            // port label
   101  		qc.Networks,      // allocation networks
   102  		nil,              // driver network (not supported)
   103  		qc.Ports,         // ports
   104  		status,           // allocation network status
   105  	)
   106  	if err != nil {
   107  		return "", err
   108  	}
   109  	if port > 0 {
   110  		addr = net.JoinHostPort(addr, strconv.Itoa(port))
   111  	}
   112  	return addr, nil
   113  }
   114  
   115  func (c *checker) checkTCP(ctx context.Context, qc *QueryContext, q *Query) *structs.CheckQueryResult {
   116  	qr := &structs.CheckQueryResult{
   117  		Mode:      q.Mode,
   118  		Timestamp: c.now(),
   119  		Status:    structs.CheckPending,
   120  	}
   121  
   122  	addr, err := address(qc, q)
   123  	if err != nil {
   124  		qr.Output = err.Error()
   125  		qr.Status = structs.CheckFailure
   126  		return qr
   127  	}
   128  
   129  	if _, err = new(net.Dialer).DialContext(ctx, "tcp", addr); err != nil {
   130  		qr.Output = err.Error()
   131  		qr.Status = structs.CheckFailure
   132  		return qr
   133  	}
   134  
   135  	qr.Output = "nomad: tcp ok"
   136  	qr.Status = structs.CheckSuccess
   137  	return qr
   138  }
   139  
   140  func (c *checker) checkHTTP(ctx context.Context, qc *QueryContext, q *Query) *structs.CheckQueryResult {
   141  	qr := &structs.CheckQueryResult{
   142  		Mode:      q.Mode,
   143  		Timestamp: c.now(),
   144  		Status:    structs.CheckPending,
   145  	}
   146  
   147  	addr, err := address(qc, q)
   148  	if err != nil {
   149  		qr.Output = err.Error()
   150  		qr.Status = structs.CheckFailure
   151  		return qr
   152  	}
   153  
   154  	u := (&url.URL{
   155  		Scheme: q.Protocol,
   156  		Host:   addr,
   157  		Path:   q.Path,
   158  	}).String()
   159  
   160  	request, err := http.NewRequest(q.Method, u, nil)
   161  	if err != nil {
   162  		qr.Output = fmt.Sprintf("nomad: %s", err.Error())
   163  		qr.Status = structs.CheckFailure
   164  		return qr
   165  	}
   166  	for header, values := range q.Headers {
   167  		for _, value := range values {
   168  			request.Header.Add(header, value)
   169  		}
   170  	}
   171  
   172  	request.Host = request.Header.Get("Host")
   173  
   174  	request.Body = io.NopCloser(strings.NewReader(q.Body))
   175  	request = request.WithContext(ctx)
   176  
   177  	result, err := c.httpClient.Do(request)
   178  	if err != nil {
   179  		qr.Output = fmt.Sprintf("nomad: %s", err.Error())
   180  		qr.Status = structs.CheckFailure
   181  		return qr
   182  	}
   183  	defer func() {
   184  		_ = result.Body.Close()
   185  	}()
   186  
   187  	// match the result status code to the http status code
   188  	qr.StatusCode = result.StatusCode
   189  
   190  	switch {
   191  	case result.StatusCode == 200:
   192  		qr.Status = structs.CheckSuccess
   193  		qr.Output = "nomad: http ok"
   194  		return qr
   195  	case result.StatusCode < 400:
   196  		qr.Status = structs.CheckSuccess
   197  	default:
   198  		qr.Status = structs.CheckFailure
   199  	}
   200  
   201  	// status code was not 200; read the response body and set that as the
   202  	// check result output content
   203  	qr.Output = limitRead(result.Body)
   204  
   205  	return qr
   206  }
   207  
   208  const (
   209  	// outputSizeLimit is the maximum number of bytes to read and store of an http
   210  	// check output. Set to 3kb which fits in 1 page with room for other fields.
   211  	outputSizeLimit = 3 * 1024
   212  )
   213  
   214  func limitRead(r io.Reader) string {
   215  	b := make([]byte, 0, outputSizeLimit)
   216  	output := bytes.NewBuffer(b)
   217  	limited := io.LimitReader(r, outputSizeLimit)
   218  	if _, err := io.Copy(output, limited); err != nil {
   219  		return fmt.Sprintf("nomad: %s", err.Error())
   220  	}
   221  	return output.String()
   222  }