github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/fakeexec/loopback.go (about)

     1  // Copyright 2021 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package fakeexec
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net"
    12  	"os"
    13  
    14  	"google.golang.org/grpc"
    15  
    16  	"go.chromium.org/tast/core/internal/protocol"
    17  	"go.chromium.org/tast/core/shutil"
    18  )
    19  
    20  // ProcFunc is a callback passed to CreateLoopback to fully control the behavior of
    21  // a loopback process.
    22  type ProcFunc func(args []string, stdin io.Reader, stdout, stderr io.WriteCloser) int
    23  
    24  // Exec implements LoopbackExecService.
    25  func (p ProcFunc) Exec(srv protocol.LoopbackExecService_ExecServer) error {
    26  	// Receive ExecInit.
    27  	req, err := srv.Recv()
    28  	if err != nil {
    29  		return err
    30  	}
    31  	init := req.GetType().(*protocol.ExecRequest_Init).Init
    32  
    33  	stdin := &execIn{srv: srv}
    34  	stdout := &execOut{
    35  		srv: srv,
    36  		ctor: func(ev *protocol.PipeEvent) *protocol.ExecResponse {
    37  			return &protocol.ExecResponse{Type: &protocol.ExecResponse_Stdout{Stdout: ev}}
    38  		}}
    39  	stderr := &execOut{
    40  		srv: srv,
    41  		ctor: func(ev *protocol.PipeEvent) *protocol.ExecResponse {
    42  			return &protocol.ExecResponse{Type: &protocol.ExecResponse_Stderr{Stderr: ev}}
    43  		}}
    44  
    45  	// Run the callback.
    46  	code := p(init.GetArgs(), stdin, stdout, stderr)
    47  
    48  	// Send ExitEvent.
    49  	return srv.Send(&protocol.ExecResponse{Type: &protocol.ExecResponse_Exit{Exit: &protocol.ExitEvent{Code: int32(code)}}})
    50  }
    51  
    52  // Loopback represents a loopback executable file.
    53  type Loopback struct {
    54  	srv  *grpc.Server
    55  	path string
    56  }
    57  
    58  // CreateLoopback creates a new file called a loopback executable.
    59  //
    60  // When a loopback executable file is executed, its process connects to the
    61  // current unit test process by gRPC to call proc remotely. The process behaves
    62  // exactly as specified by proc. Since proc is called within the current unit
    63  // test process, unit tests and subprocesses can interact easily with Go
    64  // constructs, e.g. shared memory or channels.
    65  //
    66  // A drawback is that proc can only emulate args, stdio and exit code. If you
    67  // need to do anything else, e.g. catching signals, use NewAuxMain instead.
    68  //
    69  // Once you're done with a loopback executable, call Loopback.Close to release
    70  // associated resources.
    71  func CreateLoopback(path string, proc ProcFunc) (lo *Loopback, retErr error) {
    72  	// Listen on a local port.
    73  	lis, err := net.Listen("tcp", "localhost:0")
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  	port := lis.Addr().(*net.TCPAddr).Port
    78  	defer func() {
    79  		if retErr != nil {
    80  			lis.Close()
    81  		}
    82  	}()
    83  
    84  	// Create a loopback executable file.
    85  	script, err := buildScript(port)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	if err := ioutil.WriteFile(path, script, 0755); err != nil {
    90  		return nil, err
    91  	}
    92  	defer func() {
    93  		if retErr != nil {
    94  			os.Remove(path)
    95  		}
    96  	}()
    97  
    98  	// Make sure the file has executable bit.
    99  	if err := os.Chmod(path, 0755); err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	// Finally start a gRPC server.
   104  	srv := grpc.NewServer()
   105  	protocol.RegisterLoopbackExecServiceServer(srv, proc)
   106  	go srv.Serve(lis)
   107  	return &Loopback{srv: srv, path: path}, nil
   108  }
   109  
   110  func buildScript(port int) ([]byte, error) {
   111  	exe, err := os.Executable()
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	script := fmt.Sprintf(`#!/bin/sh
   117  %s=%d exec %s "$@"
   118  `, portEnvName, port, shutil.Escape(exe))
   119  	return []byte(script), nil
   120  }
   121  
   122  // Close removes the loopback executable file and releases its associated
   123  // resources.
   124  func (s *Loopback) Close() error {
   125  	s.srv.GracefulStop()
   126  	return os.Remove(s.path)
   127  }
   128  
   129  // execIn implements io.Reader which reads from the loopback process stdin.
   130  type execIn struct {
   131  	srv    protocol.LoopbackExecService_ExecServer
   132  	buf    []byte
   133  	closed bool
   134  }
   135  
   136  func (s *execIn) Read(p []byte) (n int, err error) {
   137  	for {
   138  		// Return buffered data.
   139  		if len(s.buf) > 0 {
   140  			n = copy(p, s.buf)
   141  			s.buf = s.buf[n:]
   142  			return n, nil
   143  		}
   144  		if s.closed {
   145  			return 0, io.EOF
   146  		}
   147  
   148  		// Buffer is empty, wait for new data.
   149  		req, err := s.srv.Recv()
   150  		if err != nil {
   151  			return 0, err
   152  		}
   153  
   154  		// Fill the buffer and continue.
   155  		ev := req.GetType().(*protocol.ExecRequest_Stdin).Stdin
   156  		s.buf = ev.GetData()
   157  		if ev.GetClose() {
   158  			s.closed = true
   159  		}
   160  	}
   161  }
   162  
   163  // execOut implements io.WriteCloser which writes to the loopback process stdout
   164  // or stderr.
   165  type execOut struct {
   166  	srv  protocol.LoopbackExecService_ExecServer
   167  	ctor func(*protocol.PipeEvent) *protocol.ExecResponse
   168  }
   169  
   170  func (s *execOut) Write(p []byte) (n int, err error) {
   171  	if err := s.srv.Send(s.ctor(&protocol.PipeEvent{Data: p})); err != nil {
   172  		return 0, err
   173  	}
   174  	return len(p), nil
   175  }
   176  
   177  func (s *execOut) Close() error {
   178  	return s.srv.Send(s.ctor(&protocol.PipeEvent{Close: true}))
   179  }