github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/serviceregistration/checks/client.go (about) 1 package checks 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "net/url" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/hashicorp/go-cleanhttp" 16 "github.com/hashicorp/go-hclog" 17 "github.com/hashicorp/nomad/client/serviceregistration" 18 "github.com/hashicorp/nomad/nomad/structs" 19 "oss.indeed.com/go/libtime" 20 ) 21 22 const ( 23 // maxTimeoutHTTP is a fail-safe value for the HTTP client, ensuring a Nomad 24 // Client does not leak goroutines hanging on to unresponsive endpoints. 25 maxTimeoutHTTP = 10 * time.Minute 26 ) 27 28 // Checker executes a check given an allocation-specific context, and produces 29 // a resulting structs.CheckQueryResult 30 type Checker interface { 31 Do(context.Context, *QueryContext, *Query) *structs.CheckQueryResult 32 } 33 34 // New creates a new Checker capable of executing HTTP and TCP checks. 35 func New(log hclog.Logger) Checker { 36 httpClient := cleanhttp.DefaultPooledClient() 37 httpClient.Timeout = maxTimeoutHTTP 38 return &checker{ 39 log: log.Named("checks"), 40 httpClient: httpClient, 41 clock: libtime.SystemClock(), 42 } 43 } 44 45 type checker struct { 46 log hclog.Logger 47 clock libtime.Clock 48 httpClient *http.Client 49 } 50 51 func (c *checker) now() int64 { 52 return c.clock.Now().UTC().Unix() 53 } 54 55 // Do will execute the Query given the QueryContext and produce a structs.CheckQueryResult 56 func (c *checker) Do(ctx context.Context, qc *QueryContext, q *Query) *structs.CheckQueryResult { 57 var qr *structs.CheckQueryResult 58 59 timeout, cancel := context.WithTimeout(ctx, q.Timeout) 60 defer cancel() 61 62 switch q.Type { 63 case "http": 64 qr = c.checkHTTP(timeout, qc, q) 65 default: 66 qr = c.checkTCP(timeout, qc, q) 67 } 68 69 qr.ID = qc.ID 70 qr.Group = qc.Group 71 qr.Task = qc.Task 72 qr.Service = qc.Service 73 qr.Check = qc.Check 74 return qr 75 } 76 77 // resolve the address to use when executing Query given a QueryContext 78 func address(qc *QueryContext, q *Query) (string, error) { 79 mode := q.AddressMode 80 if mode == "" { // determine resolution for check address 81 if qc.CustomAddress != "" { 82 // if the service is using a custom address, enable the check to 83 // inherit that custom address 84 mode = structs.AddressModeAuto 85 } else { 86 // otherwise a check defaults to the host address 87 mode = structs.AddressModeHost 88 } 89 } 90 91 label := q.PortLabel 92 if label == "" { 93 label = qc.ServicePortLabel 94 } 95 96 status := qc.NetworkStatus.NetworkStatus() 97 addr, port, err := serviceregistration.GetAddress( 98 qc.CustomAddress, // custom address 99 mode, // check address mode 100 label, // port label 101 qc.Networks, // allocation networks 102 nil, // driver network (not supported) 103 qc.Ports, // ports 104 status, // allocation network status 105 ) 106 if err != nil { 107 return "", err 108 } 109 if port > 0 { 110 addr = net.JoinHostPort(addr, strconv.Itoa(port)) 111 } 112 return addr, nil 113 } 114 115 func (c *checker) checkTCP(ctx context.Context, qc *QueryContext, q *Query) *structs.CheckQueryResult { 116 qr := &structs.CheckQueryResult{ 117 Mode: q.Mode, 118 Timestamp: c.now(), 119 Status: structs.CheckPending, 120 } 121 122 addr, err := address(qc, q) 123 if err != nil { 124 qr.Output = err.Error() 125 qr.Status = structs.CheckFailure 126 return qr 127 } 128 129 if _, err = new(net.Dialer).DialContext(ctx, "tcp", addr); err != nil { 130 qr.Output = err.Error() 131 qr.Status = structs.CheckFailure 132 return qr 133 } 134 135 qr.Output = "nomad: tcp ok" 136 qr.Status = structs.CheckSuccess 137 return qr 138 } 139 140 func (c *checker) checkHTTP(ctx context.Context, qc *QueryContext, q *Query) *structs.CheckQueryResult { 141 qr := &structs.CheckQueryResult{ 142 Mode: q.Mode, 143 Timestamp: c.now(), 144 Status: structs.CheckPending, 145 } 146 147 addr, err := address(qc, q) 148 if err != nil { 149 qr.Output = err.Error() 150 qr.Status = structs.CheckFailure 151 return qr 152 } 153 154 u := (&url.URL{ 155 Scheme: q.Protocol, 156 Host: addr, 157 Path: q.Path, 158 }).String() 159 160 request, err := http.NewRequest(q.Method, u, nil) 161 if err != nil { 162 qr.Output = fmt.Sprintf("nomad: %s", err.Error()) 163 qr.Status = structs.CheckFailure 164 return qr 165 } 166 for header, values := range q.Headers { 167 for _, value := range values { 168 request.Header.Add(header, value) 169 } 170 } 171 172 request.Host = request.Header.Get("Host") 173 174 request.Body = io.NopCloser(strings.NewReader(q.Body)) 175 request = request.WithContext(ctx) 176 177 result, err := c.httpClient.Do(request) 178 if err != nil { 179 qr.Output = fmt.Sprintf("nomad: %s", err.Error()) 180 qr.Status = structs.CheckFailure 181 return qr 182 } 183 defer func() { 184 _ = result.Body.Close() 185 }() 186 187 // match the result status code to the http status code 188 qr.StatusCode = result.StatusCode 189 190 switch { 191 case result.StatusCode == 200: 192 qr.Status = structs.CheckSuccess 193 qr.Output = "nomad: http ok" 194 return qr 195 case result.StatusCode < 400: 196 qr.Status = structs.CheckSuccess 197 default: 198 qr.Status = structs.CheckFailure 199 } 200 201 // status code was not 200; read the response body and set that as the 202 // check result output content 203 qr.Output = limitRead(result.Body) 204 205 return qr 206 } 207 208 const ( 209 // outputSizeLimit is the maximum number of bytes to read and store of an http 210 // check output. Set to 3kb which fits in 1 page with room for other fields. 211 outputSizeLimit = 3 * 1024 212 ) 213 214 func limitRead(r io.Reader) string { 215 b := make([]byte, 0, outputSizeLimit) 216 output := bytes.NewBuffer(b) 217 limited := io.LimitReader(r, outputSizeLimit) 218 if _, err := io.Copy(output, limited); err != nil { 219 return fmt.Sprintf("nomad: %s", err.Error()) 220 } 221 return output.String() 222 }