go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/exec/internal/execmockserver/client.go (about)

     1  // Copyright 2023 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package execmockserver
    16  
    17  import (
    18  	"flag"
    19  	"fmt"
    20  	"net/rpc"
    21  	"os"
    22  	"reflect"
    23  	"runtime/debug"
    24  	"strconv"
    25  	"strings"
    26  
    27  	"go.chromium.org/luci/common/errors"
    28  )
    29  
// client is the mocked-subprocess side of the execmock protocol: it talks to
// the ExecMockServer over net/rpc on behalf of a single invocation.
type client struct {
	client       *rpc.Client // RPC connection to the ExecMockServer.
	invocationID uint64      // Sent with every RPC to identify this invocation.
}
    34  
    35  func (c *client) getIvocationInput() *InvocationInput {
    36  	var ret InvocationInput
    37  	if err := c.client.Call("ExecMockServer.GetInvocationInput", c.invocationID, &ret); err != nil {
    38  		panic(err)
    39  	}
    40  	return &ret
    41  }
    42  
    43  func (c *client) setInvocationOutput(rslt, err reflect.Value, panicStack string) {
    44  	out := &InvocationOutput{
    45  		InvocationID: c.invocationID,
    46  		RunnerOutput: rslt.Interface(),
    47  		RunnerPanic:  panicStack,
    48  	}
    49  
    50  	if !err.IsNil() {
    51  		out.RunnerError = err.Interface().(error).Error()
    52  	}
    53  
    54  	var ignore int
    55  	if err := c.client.Call("ExecMockServer.SetInvocationOutput", out, &ignore); err != nil {
    56  		panic(err)
    57  	}
    58  }
    59  
    60  // ClientIntercept will look for the LUCI_EXECMOCK_CTX environment variable, and, if
    61  // found, invoke `cb` with an initialized Client as well as the input
    62  // corresponding to this invocation.
    63  //
    64  // The callback will execute the mock function with the decoded input, and then
    65  // possibly call Client.SetInvocationOutput.
    66  //
    67  // See go.chromium.org/luci/common/execmock
    68  func ClientIntercept(runnerRegistry map[uint64]reflect.Value) (exitcode int, intercepted bool) {
    69  	endpoint := os.Getenv(execServeEnvvar)
    70  	if endpoint == "" {
    71  		return 0, false
    72  	}
    73  	intercepted = true
    74  
    75  	// We reset flag.CommandLine to make it easier to use `flag` from a runner
    76  	// function without accidentally picking up the default flags registered by
    77  	// `go test`.
    78  	oldCLI := flag.CommandLine
    79  	defer func() { flag.CommandLine = oldCLI }()
    80  	flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
    81  
    82  	// This is the mock subprocess invocation.
    83  
    84  	// hide the envvar from the RunnerFunction for cleanliness purposes.
    85  	if err := os.Unsetenv(execServeEnvvar); err != nil {
    86  		panic(err)
    87  	}
    88  
    89  	tokens := strings.Split(endpoint, "|")
    90  	if len(tokens) != 2 {
    91  		panic(errors.Reason("%s: expected two tokens, got %q", execServeEnvvar, endpoint).Err())
    92  	}
    93  	hostname, invocationIDstr := tokens[0], tokens[1]
    94  	invocationID, err := strconv.ParseUint(invocationIDstr, 10, 64)
    95  	if err != nil {
    96  		panic(err)
    97  	}
    98  	rpcClient, err := rpc.DialHTTP("tcp", hostname)
    99  	if err != nil {
   100  		panic(err)
   101  	}
   102  
   103  	emClient := &client{rpcClient, invocationID}
   104  	input := emClient.getIvocationInput()
   105  	runnerFn := runnerRegistry[input.RunnerID]
   106  	if !runnerFn.IsValid() {
   107  		panic(fmt.Sprintf("unknown runner id %d", input.RunnerID))
   108  	}
   109  
   110  	inputType := runnerFn.Type().In(0)
   111  
   112  	// GoB will always decode this as flat, so if the runner is expecting *T or
   113  	// v will just be type T.
   114  	v := reflect.ValueOf(input.RunnerInput)
   115  	if v.Type() != inputType && inputType.Kind() == reflect.Pointer {
   116  		newV := reflect.New(v.Type())
   117  		newV.Elem().Set(v)
   118  		v = newV
   119  	}
   120  
   121  	defer func() {
   122  		if thing := recover(); thing != nil {
   123  			stack := string(debug.Stack())
   124  			exitcode = 1
   125  			if err, ok := thing.(error); ok {
   126  				emClient.setInvocationOutput(
   127  					reflect.New(runnerFn.Type().Out(0)).Elem(), reflect.ValueOf(err), stack)
   128  			} else {
   129  				emClient.setInvocationOutput(
   130  					reflect.New(runnerFn.Type().Out(0)).Elem(), reflect.ValueOf(errors.Reason("%s", thing).Err()), stack)
   131  			}
   132  		}
   133  	}()
   134  	results := runnerFn.Call([]reflect.Value{v})
   135  	emClient.setInvocationOutput(results[0], results[2], "")
   136  
   137  	return int(results[1].Int()), true
   138  }