go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/tq/tqtesting/loopback.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tqtesting
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"io"
    21  	"net"
    22  	"net/http"
    23  	"net/url"
    24  	"strconv"
    25  
    26  	taskspb "cloud.google.com/go/cloudtasks/apiv2/cloudtaskspb"
    27  
    28  	"go.chromium.org/luci/common/logging"
    29  )
    30  
    31  // LoopbackHTTPExecutor is an Executor that executes tasks by sending HTTP
    32  // requests to the server with TQ module serving at the given (usually loopback)
    33  // address.
    34  //
    35  // Used exclusively when running TQ locally.
    36  type LoopbackHTTPExecutor struct {
    37  	// ServerAddr is where the server is listening for requests.
    38  	ServerAddr net.Addr
    39  }
    40  
    41  // Execute dispatches the task to the HTTP handler in a dedicated goroutine.
    42  //
    43  // Marks the task as failed if the response status code is outside of
    44  // range [200-299].
    45  func (e *LoopbackHTTPExecutor) Execute(ctx context.Context, t *Task, done func(retry bool)) {
    46  	if t.Message != nil {
    47  		done(false)
    48  		panic("Executing PubSub tasks is not supported yet") // break tests loudly
    49  	}
    50  
    51  	success := false
    52  	defer func() {
    53  		done(!success)
    54  	}()
    55  
    56  	if e.ServerAddr == nil {
    57  		logging.Errorf(ctx, "LoopbackHTTPExecutor is not configured. Is the server exposing main HTTP port?")
    58  		return
    59  	}
    60  
    61  	var method taskspb.HttpMethod
    62  	var requestURL string
    63  	var headers map[string]string
    64  	var body []byte
    65  
    66  	switch mt := t.Task.MessageType.(type) {
    67  	case *taskspb.Task_HttpRequest:
    68  		method = mt.HttpRequest.HttpMethod
    69  		requestURL = mt.HttpRequest.Url
    70  		headers = mt.HttpRequest.Headers
    71  		body = mt.HttpRequest.Body
    72  	case *taskspb.Task_AppEngineHttpRequest:
    73  		method = mt.AppEngineHttpRequest.HttpMethod
    74  		requestURL = mt.AppEngineHttpRequest.RelativeUri
    75  		headers = mt.AppEngineHttpRequest.Headers
    76  		body = mt.AppEngineHttpRequest.Body
    77  	default:
    78  		logging.Errorf(ctx, "Bad task, no payload: %q", t.Task)
    79  		return
    80  	}
    81  
    82  	parsedURL, err := url.Parse(requestURL)
    83  	if err != nil {
    84  		logging.Errorf(ctx, "Bad task URL %q", requestURL)
    85  		return
    86  	}
    87  	host := parsedURL.Host
    88  
    89  	// Make the URL relative to the localhost server at the requested port.
    90  	parsedURL.Scheme = "http"
    91  	parsedURL.Host = e.ServerAddr.String() // this is "<host>:<port>"
    92  	requestURL = parsedURL.String()
    93  
    94  	req, err := http.NewRequestWithContext(ctx, method.String(), requestURL, bytes.NewReader(body))
    95  	if err != nil {
    96  		logging.Errorf(ctx, "Could not construct HTTP request: %s", err)
    97  		return
    98  	}
    99  	req.Host = host // sets "Host" request header
   100  	for k, v := range headers {
   101  		req.Header.Set(k, v)
   102  	}
   103  
   104  	// See https://cloud.google.com/tasks/docs/creating-http-target-tasks#handler
   105  	// We emulate only headers we actually use.
   106  	req.Header.Set("X-CloudTasks-TaskExecutionCount", strconv.Itoa(t.Attempts-1))
   107  	if t.Attempts > 1 {
   108  		req.Header.Set("X-CloudTasks-TaskRetryReason", "task handler failed")
   109  	}
   110  
   111  	resp, err := http.DefaultClient.Do(req)
   112  	if err != nil {
   113  		logging.Errorf(ctx, "Failed to send HTTP request: %s", err)
   114  		return
   115  	}
   116  	defer resp.Body.Close()
   117  	// Read the body fully to be able to reuse the connection.
   118  	if _, err = io.Copy(io.Discard, resp.Body); err != nil {
   119  		logging.Errorf(ctx, "Failed to read the response: %s", err)
   120  		return
   121  	}
   122  
   123  	success = resp.StatusCode >= 200 && resp.StatusCode <= 299
   124  }