github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/worker/worker.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package worker provides an implementation of the bazel worker protocol.
    16  //
    17  // Tools may be written as a normal command line utility, except the passed
    18  // run function may be invoked multiple times.
    19  package worker
    20  
    21  import (
    22  	"bufio"
    23  	"bytes"
    24  	"flag"
    25  	"fmt"
    26  	"io"
    27  	"io/ioutil"
    28  	"log"
    29  	"net"
    30  	"net/http"
    31  	"os"
    32  	"path/filepath"
    33  	"sort"
    34  	"strings"
    35  	"time"
    36  
    37  	_ "net/http/pprof" // For profiling.
    38  
    39  	"golang.org/x/sys/unix"
    40  	"google.golang.org/protobuf/encoding/protowire"
    41  	"google.golang.org/protobuf/proto"
    42  	wpb "github.com/SagerNet/bazel/worker_protocol_go_proto"
    43  )
    44  
// Command-line flags controlling worker behavior.
var (
	// persistentWorker selects bazel persistent-worker mode: read
	// WorkRequest protos from stdin in a loop instead of running once.
	persistentWorker  = flag.Bool("persistent_worker", false, "enable persistent worker.")
	// workerDebug starts a local pprof HTTP server and appends cache
	// statistics to each response's output.
	workerDebug       = flag.Bool("worker_debug", false, "debug persistent workers.")
	// maximumCacheUsage bounds total size across all caches; when
	// exceeded, every cache is flushed (see Cache.Lookup).
	maximumCacheUsage = flag.Int64("maximum_cache_usage", 1024*1024*1024, "maximum cache size.")
)
    50  
var (
	// inputFiles is the last set of input files.
	//
	// This is used for cache invalidation. The key is the *absolute* path
	// name, and the value is the digest in the current run. It is rebuilt
	// from each WorkRequest's declared inputs (see Work).
	inputFiles = make(map[string]string)

	// activeCaches is the set of active caches, i.e. caches that hold at
	// least one entry; used to flush everything when totalCacheUsage
	// exceeds *maximumCacheUsage.
	activeCaches = make(map[*Cache]struct{})

	// totalCacheUsage is the total usage of all caches, in bytes as
	// reported by each entry's Sizer.Size.
	totalCacheUsage int64
)
    64  
    65  // mustAbs returns the absolute path of a filename or dies.
    66  func mustAbs(filename string) string {
    67  	abs, err := filepath.Abs(filename)
    68  	if err != nil {
    69  		log.Fatalf("error getting absolute path: %v", err)
    70  	}
    71  	return abs
    72  }
    73  
// updateInputFile creates an entry in inputFiles, keyed by the
// absolute path of filename.
func updateInputFile(filename, digest string) {
	inputFiles[mustAbs(filename)] = digest
}
    78  
// Sizer returns a size.
//
// Cache entries implement Sizer so that total memory usage can be
// accounted against maximumCacheUsage (see Cache.Lookup).
type Sizer interface {
	// Size returns the entry's size in bytes.
	Size() int64
}
    83  
    84  // CacheBytes is an example of a Sizer.
    85  type CacheBytes []byte
    86  
    87  // Size implements Sizer.Size.
    88  func (cb CacheBytes) Size() int64 {
    89  	return int64(len(cb))
    90  }
    91  
// Cache is a worker cache.
//
// They can be created via NewCache.
type Cache struct {
	name    string           // diagnostic name, reported by allCacheStats.
	entries map[string]Sizer // keyed by joined sorted digests; nil until first Lookup.
	size    int64            // sum of entry sizes, in bytes.
	hits    int64            // number of Lookup hits.
	misses  int64            // number of Lookup misses.
}
   102  
   103  // NewCache returns a new cache.
   104  func NewCache(name string) *Cache {
   105  	return &Cache{
   106  		name: name,
   107  	}
   108  }
   109  
   110  // Lookup looks up an entry in the cache.
   111  //
   112  // It is a function of the given files.
   113  func (c *Cache) Lookup(filenames []string, generate func() Sizer) Sizer {
   114  	digests := make([]string, 0, len(filenames))
   115  	for _, filename := range filenames {
   116  		digest, ok := inputFiles[mustAbs(filename)]
   117  		if !ok {
   118  			// This is not a valid input. We may not be running as
   119  			// persistent worker in this cache. If that's the case,
   120  			// then the file's contents will not change across the
   121  			// run, and we just use the filename itself.
   122  			digest = filename
   123  		}
   124  		digests = append(digests, digest)
   125  	}
   126  
   127  	// Attempt the lookup.
   128  	sort.Slice(digests, func(i, j int) bool {
   129  		return digests[i] < digests[j]
   130  	})
   131  	cacheKey := strings.Join(digests, "+")
   132  	if c.entries == nil {
   133  		c.entries = make(map[string]Sizer)
   134  		activeCaches[c] = struct{}{}
   135  	}
   136  	entry, ok := c.entries[cacheKey]
   137  	if ok {
   138  		c.hits++
   139  		return entry
   140  	}
   141  
   142  	// Generate a new entry.
   143  	entry = generate()
   144  	c.misses++
   145  	c.entries[cacheKey] = entry
   146  	if entry != nil {
   147  		sz := entry.Size()
   148  		c.size += sz
   149  		totalCacheUsage += sz
   150  	}
   151  
   152  	// Check the capacity of all caches. If it greater than the maximum, we
   153  	// flush everything but still return this entry.
   154  	if totalCacheUsage > *maximumCacheUsage {
   155  		for entry, _ := range activeCaches {
   156  			// Drop all entries.
   157  			entry.size = 0
   158  			entry.entries = nil
   159  		}
   160  		totalCacheUsage = 0 // Reset.
   161  	}
   162  
   163  	return entry
   164  }
   165  
   166  // allCacheStats returns stats for all caches.
   167  func allCacheStats() string {
   168  	var sb strings.Builder
   169  	for entry, _ := range activeCaches {
   170  		ratio := float64(entry.hits) / float64(entry.hits+entry.misses)
   171  		fmt.Fprintf(&sb,
   172  			"% 10s: count: % 5d  size: % 10d  hits: % 7d  misses: % 7d  ratio: %2.2f\n",
   173  			entry.name, len(entry.entries), entry.size, entry.hits, entry.misses, ratio)
   174  	}
   175  	if len(activeCaches) > 0 {
   176  		fmt.Fprintf(&sb, "total: % 10d\n", totalCacheUsage)
   177  	}
   178  	return sb.String()
   179  }
   180  
   181  // LookupDigest returns a digest for the given file.
   182  func LookupDigest(filename string) (string, bool) {
   183  	digest, ok := inputFiles[filename]
   184  	return digest, ok
   185  }
   186  
   187  // Work invokes the main function.
   188  func Work(run func([]string) int) {
   189  	flag.CommandLine.Parse(os.Args[1:])
   190  	if !*persistentWorker {
   191  		// Handle the argument file.
   192  		args := flag.CommandLine.Args()
   193  		if len(args) == 1 && len(args[0]) > 1 && args[0][0] == '@' {
   194  			content, err := ioutil.ReadFile(args[0][1:])
   195  			if err != nil {
   196  				log.Fatalf("unable to parse args file: %v", err)
   197  			}
   198  			// Pull arguments from the file.
   199  			args = strings.Split(string(content), "\n")
   200  			flag.CommandLine.Parse(args)
   201  			args = flag.CommandLine.Args()
   202  		}
   203  		os.Exit(run(args))
   204  	}
   205  
   206  	var listenHeader string // Emitted always.
   207  	if *workerDebug {
   208  		// Bind a server for profiling.
   209  		listener, err := net.Listen("tcp", "localhost:0")
   210  		if err != nil {
   211  			log.Fatalf("unable to bind a server: %v", err)
   212  		}
   213  		// Construct the header for stats output, below.
   214  		listenHeader = fmt.Sprintf("Listening @ http://localhost:%d\n", listener.Addr().(*net.TCPAddr).Port)
   215  		go http.Serve(listener, nil)
   216  	}
   217  
   218  	// Move stdout. This is done to prevent anything else from accidentally
   219  	// printing to stdout, which must contain only the valid WorkerResponse
   220  	// serialized protos.
   221  	newOutput, err := unix.Dup(1)
   222  	if err != nil {
   223  		log.Fatalf("unable to move stdout: %v", err)
   224  	}
   225  	// Stderr may be closed or may be a copy of stdout. We make sure that
   226  	// we have an output that is in a completely separate range.
   227  	for newOutput <= 2 {
   228  		newOutput, err = unix.Dup(newOutput)
   229  		if err != nil {
   230  			log.Fatalf("unable to move stdout: %v", err)
   231  		}
   232  	}
   233  
   234  	// Best-effort: collect logs.
   235  	rPipe, wPipe, err := os.Pipe()
   236  	if err != nil {
   237  		log.Fatalf("unable to create pipe: %v", err)
   238  	}
   239  	if err := unix.Dup2(int(wPipe.Fd()), 1); err != nil {
   240  		log.Fatalf("error duping over stdout: %v", err)
   241  	}
   242  	if err := unix.Dup2(int(wPipe.Fd()), 2); err != nil {
   243  		log.Fatalf("error duping over stderr: %v", err)
   244  	}
   245  	wPipe.Close()
   246  	defer rPipe.Close()
   247  
   248  	// Read requests from stdin.
   249  	input := bufio.NewReader(os.NewFile(0, "input"))
   250  	output := bufio.NewWriter(os.NewFile(uintptr(newOutput), "output"))
   251  	for {
   252  		szBuf, err := input.Peek(4)
   253  		if err != nil {
   254  			log.Fatalf("unabel to read header: %v", err)
   255  		}
   256  
   257  		// Parse the size, and discard bits.
   258  		sz, szBytes := protowire.ConsumeVarint(szBuf)
   259  		if szBytes < 0 {
   260  			szBytes = 0
   261  		}
   262  		if _, err := input.Discard(szBytes); err != nil {
   263  			log.Fatalf("error discarding size: %v", err)
   264  		}
   265  
   266  		// Read a full message.
   267  		msg := make([]byte, int(sz))
   268  		if _, err := io.ReadFull(input, msg); err != nil {
   269  			log.Fatalf("error reading worker request: %v", err)
   270  		}
   271  		var wreq wpb.WorkRequest
   272  		if err := proto.Unmarshal(msg, &wreq); err != nil {
   273  			log.Fatalf("error unmarshaling worker request: %v", err)
   274  		}
   275  
   276  		// Flush relevant caches.
   277  		inputFiles = make(map[string]string)
   278  		for _, input := range wreq.GetInputs() {
   279  			updateInputFile(input.GetPath(), string(input.GetDigest()))
   280  		}
   281  
   282  		// Prepare logging.
   283  		outputBuffer := bytes.NewBuffer(nil)
   284  		outputBuffer.WriteString(listenHeader)
   285  		log.SetOutput(outputBuffer)
   286  
   287  		// Parse all arguments.
   288  		flag.CommandLine.Parse(wreq.GetArguments())
   289  		var exitCode int
   290  		exitChan := make(chan int)
   291  		go func() { exitChan <- run(flag.CommandLine.Args()) }()
   292  		for running := true; running; {
   293  			select {
   294  			case exitCode = <-exitChan:
   295  				running = false
   296  			default:
   297  			}
   298  			// N.B. rPipe is given a read deadline of 1ms. We expect
   299  			// this to turn a copy error after 1ms, and we just keep
   300  			// flushing this buffer while the task is running.
   301  			rPipe.SetReadDeadline(time.Now().Add(time.Millisecond))
   302  			outputBuffer.ReadFrom(rPipe)
   303  		}
   304  
   305  		if *workerDebug {
   306  			// Attach all cache stats.
   307  			outputBuffer.WriteString(allCacheStats())
   308  		}
   309  
   310  		// Send the response.
   311  		var wresp wpb.WorkResponse
   312  		wresp.ExitCode = int32(exitCode)
   313  		wresp.Output = string(outputBuffer.Bytes())
   314  		rmsg, err := proto.Marshal(&wresp)
   315  		if err != nil {
   316  			log.Fatalf("error marshaling response: %v", err)
   317  		}
   318  		if _, err := output.Write(append(protowire.AppendVarint(nil, uint64(len(rmsg))), rmsg...)); err != nil {
   319  			log.Fatalf("error sending worker response: %v", err)
   320  		}
   321  		if err := output.Flush(); err != nil {
   322  			log.Fatalf("error flushing output: %v", err)
   323  		}
   324  	}
   325  }