sigs.k8s.io/gateway-api@v1.0.0/conformance/utils/roundtripper/roundtripper.go (about) 1 /* 2 Copyright 2022 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package roundtripper 18 19 import ( 20 "context" 21 "crypto/tls" 22 "crypto/x509" 23 "encoding/json" 24 "errors" 25 "fmt" 26 "io" 27 "net" 28 "net/http" 29 "net/http/httputil" 30 "net/url" 31 "regexp" 32 33 "golang.org/x/net/http2" 34 35 "sigs.k8s.io/gateway-api/conformance/utils/config" 36 ) 37 38 const ( 39 H2CPriorKnowledgeProtocol = "H2C_PRIOR_KNOWLEDGE" 40 ) 41 42 // RoundTripper is an interface used to make requests within conformance tests. 43 // This can be overridden with custom implementations whenever necessary. 44 type RoundTripper interface { 45 CaptureRoundTrip(Request) (*CapturedRequest, *CapturedResponse, error) 46 } 47 48 // Request is the primary input for making a request. 49 type Request struct { 50 URL url.URL 51 Host string 52 Protocol string 53 Method string 54 Headers map[string][]string 55 UnfollowRedirect bool 56 CertPem []byte 57 KeyPem []byte 58 Server string 59 } 60 61 // String returns a printable version of Request for logging. Note that the 62 // CertPem and KeyPem are truncated. 
63 func (r Request) String() string { 64 return fmt.Sprintf("{URL: %+v, Host: %v, Protocol: %v, Method: %v, Headers: %v, UnfollowRedirect: %v, Server: %v, CertPem: <truncated>, KeyPem: <truncated>}", 65 r.URL, 66 r.Host, 67 r.Protocol, 68 r.Method, 69 r.Headers, 70 r.UnfollowRedirect, 71 r.Server, 72 ) 73 } 74 75 // CapturedRequest contains request metadata captured from an echoserver 76 // response. 77 type CapturedRequest struct { 78 Path string `json:"path"` 79 Host string `json:"host"` 80 Method string `json:"method"` 81 Protocol string `json:"proto"` 82 Headers map[string][]string `json:"headers"` 83 84 Namespace string `json:"namespace"` 85 Pod string `json:"pod"` 86 } 87 88 // RedirectRequest contains a follow up request metadata captured from a redirect 89 // response. 90 type RedirectRequest struct { 91 Scheme string 92 Host string 93 Port string 94 Path string 95 } 96 97 // CapturedResponse contains response metadata. 98 type CapturedResponse struct { 99 StatusCode int 100 ContentLength int64 101 Protocol string 102 Headers map[string][]string 103 RedirectRequest *RedirectRequest 104 } 105 106 // DefaultRoundTripper is the default implementation of a RoundTripper. It will 107 // be used if a custom implementation is not specified. 108 type DefaultRoundTripper struct { 109 Debug bool 110 TimeoutConfig config.TimeoutConfig 111 CustomDialContext func(context.Context, string, string) (net.Conn, error) 112 } 113 114 func (d *DefaultRoundTripper) httpTransport(request Request) (http.RoundTripper, error) { 115 transport := &http.Transport{ 116 DialContext: d.CustomDialContext, 117 // We disable keep-alives so that we don't leak established TCP connections. 118 // Leaking TCP connections is bad because we could eventually hit the 119 // threshold of maximum number of open TCP connections to a specific 120 // destination. Keep-alives are not presently utilized so disabling this has 121 // no adverse affect. 122 // 123 // Ref. 
https://github.com/kubernetes-sigs/gateway-api/issues/2357 124 DisableKeepAlives: true, 125 } 126 if request.Server != "" && len(request.CertPem) != 0 && len(request.KeyPem) != 0 { 127 tlsConfig, err := tlsClientConfig(request.Server, request.CertPem, request.KeyPem) 128 if err != nil { 129 return nil, err 130 } 131 transport.TLSClientConfig = tlsConfig 132 } 133 134 return transport, nil 135 } 136 137 func (d *DefaultRoundTripper) h2cPriorKnowledgeTransport(request Request) (http.RoundTripper, error) { 138 if request.Server != "" && len(request.CertPem) != 0 && len(request.KeyPem) != 0 { 139 return nil, errors.New("request has configured cert and key but h2 prior knowledge is not encrypted") 140 } 141 142 transport := &http2.Transport{ 143 AllowHTTP: true, 144 DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { 145 var d net.Dialer 146 return d.DialContext(ctx, network, addr) 147 }, 148 } 149 150 return transport, nil 151 } 152 153 // CaptureRoundTrip makes a request with the provided parameters and returns the 154 // captured request and response from echoserver. An error will be returned if 155 // there is an error running the function but not if an HTTP error status code 156 // is received. 
157 func (d *DefaultRoundTripper) CaptureRoundTrip(request Request) (*CapturedRequest, *CapturedResponse, error) { 158 var transport http.RoundTripper 159 var err error 160 161 switch request.Protocol { 162 case H2CPriorKnowledgeProtocol: 163 transport, err = d.h2cPriorKnowledgeTransport(request) 164 default: 165 transport, err = d.httpTransport(request) 166 } 167 168 if err != nil { 169 return nil, nil, err 170 } 171 172 return d.defaultRoundTrip(request, transport) 173 } 174 175 func (d *DefaultRoundTripper) defaultRoundTrip(request Request, transport http.RoundTripper) (*CapturedRequest, *CapturedResponse, error) { 176 client := &http.Client{} 177 178 if request.UnfollowRedirect { 179 client.CheckRedirect = func(req *http.Request, via []*http.Request) error { 180 return http.ErrUseLastResponse 181 } 182 } 183 184 client.Transport = transport 185 186 method := "GET" 187 if request.Method != "" { 188 method = request.Method 189 } 190 ctx, cancel := context.WithTimeout(context.Background(), d.TimeoutConfig.RequestTimeout) 191 defer cancel() 192 req, err := http.NewRequestWithContext(ctx, method, request.URL.String(), nil) 193 if err != nil { 194 return nil, nil, err 195 } 196 197 if request.Host != "" { 198 req.Host = request.Host 199 } 200 201 if request.Headers != nil { 202 for name, value := range request.Headers { 203 req.Header.Set(name, value[0]) 204 } 205 } 206 207 if d.Debug { 208 var dump []byte 209 dump, err = httputil.DumpRequestOut(req, true) 210 if err != nil { 211 return nil, nil, err 212 } 213 214 fmt.Printf("Sending Request:\n%s\n\n", formatDump(dump, "< ")) 215 } 216 217 resp, err := client.Do(req) 218 if err != nil { 219 return nil, nil, err 220 } 221 defer resp.Body.Close() 222 223 if d.Debug { 224 var dump []byte 225 dump, err = httputil.DumpResponse(resp, true) 226 if err != nil { 227 return nil, nil, err 228 } 229 230 fmt.Printf("Received Response:\n%s\n\n", formatDump(dump, "< ")) 231 } 232 233 cReq := &CapturedRequest{} 234 235 body, err := 
io.ReadAll(resp.Body) 236 if err != nil { 237 return nil, nil, err 238 } 239 240 // we cannot assume the response is JSON 241 if resp.Header.Get("Content-type") == "application/json" { 242 err = json.Unmarshal(body, cReq) 243 if err != nil { 244 return nil, nil, fmt.Errorf("unexpected error reading response: %w", err) 245 } 246 } else { 247 cReq.Method = method // assume it made the right request if the service being called isn't echoing 248 } 249 250 cRes := &CapturedResponse{ 251 StatusCode: resp.StatusCode, 252 ContentLength: resp.ContentLength, 253 Protocol: resp.Proto, 254 Headers: resp.Header, 255 } 256 257 if IsRedirect(resp.StatusCode) { 258 redirectURL, err := resp.Location() 259 if err != nil { 260 return nil, nil, err 261 } 262 cRes.RedirectRequest = &RedirectRequest{ 263 Scheme: redirectURL.Scheme, 264 Host: redirectURL.Hostname(), 265 Port: redirectURL.Port(), 266 Path: redirectURL.Path, 267 } 268 } 269 270 return cReq, cRes, nil 271 } 272 273 func tlsClientConfig(server string, certPem []byte, keyPem []byte) (*tls.Config, error) { 274 // Create a certificate from the provided cert and key 275 cert, err := tls.X509KeyPair(certPem, keyPem) 276 if err != nil { 277 return nil, fmt.Errorf("unexpected error creating cert: %w", err) 278 } 279 280 // Add the provided cert as a trusted CA 281 certPool := x509.NewCertPool() 282 if !certPool.AppendCertsFromPEM(certPem) { 283 return nil, fmt.Errorf("unexpected error adding trusted CA: %w", err) 284 } 285 286 if server == "" { 287 return nil, fmt.Errorf("unexpected error, server name required for TLS") 288 } 289 290 // Create the tls Config for this provided host, cert, and trusted CA 291 // Disable G402: TLS MinVersion too low. (gosec) 292 // #nosec G402 293 return &tls.Config{ 294 Certificates: []tls.Certificate{cert}, 295 ServerName: server, 296 RootCAs: certPool, 297 }, nil 298 } 299 300 // IsRedirect returns true if a given status code is a redirect code. 
func IsRedirect(statusCode int) bool {
	redirectCodes := [...]int{
		http.StatusMultipleChoices,
		http.StatusMovedPermanently,
		http.StatusFound,
		http.StatusSeeOther,
		http.StatusNotModified,
		http.StatusUseProxy,
		http.StatusTemporaryRedirect,
		http.StatusPermanentRedirect,
	}
	for _, code := range redirectCodes {
		if code == statusCode {
			return true
		}
	}
	return false
}

// IsTimeoutError reports whether statusCode is one of the HTTP timeout status
// codes: 408 Request Timeout or 504 Gateway Timeout.
func IsTimeoutError(statusCode int) bool {
	return statusCode == http.StatusRequestTimeout ||
		statusCode == http.StatusGatewayTimeout
}

// startLineRegex matches the empty position at the start of every line.
var startLineRegex = regexp.MustCompile(`(?m)^`)

// formatDump returns data as a string with prefix inserted at the start of
// each line.
func formatDump(data []byte, prefix string) string {
	return string(startLineRegex.ReplaceAllLiteral(data, []byte(prefix)))
}