github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/worker/worker.go

// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package worker provides an implementation of the bazel worker protocol.
//
// Tools may be written as a normal command line utility, except the passed
// run function may be invoked multiple times.
package worker

import (
	"bufio"
	"bytes"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	_ "net/http/pprof" // For profiling.

	"golang.org/x/sys/unix"
	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/proto"
	wpb "github.com/SagerNet/bazel/worker_protocol_go_proto"
)

var (
	persistentWorker  = flag.Bool("persistent_worker", false, "enable persistent worker.")
	workerDebug       = flag.Bool("worker_debug", false, "debug persistent workers.")
	maximumCacheUsage = flag.Int64("maximum_cache_usage", 1024*1024*1024, "maximum cache size.")
)

var (
	// inputFiles is the last set of input files.
	//
	// This is used for cache invalidation. The key is the *absolute* path
	// name, and the value is the digest in the current run.
	inputFiles = make(map[string]string)

	// activeCaches is the set of active caches.
	activeCaches = make(map[*Cache]struct{})

	// totalCacheUsage is the total usage of all caches.
	totalCacheUsage int64
)

// mustAbs returns the absolute path of a filename or dies.
func mustAbs(filename string) string {
	abs, err := filepath.Abs(filename)
	if err != nil {
		log.Fatalf("error getting absolute path: %v", err)
	}
	return abs
}

// updateInputFile creates an entry in inputFiles.
func updateInputFile(filename, digest string) {
	inputFiles[mustAbs(filename)] = digest
}

// Sizer returns a size.
type Sizer interface {
	Size() int64
}

// CacheBytes is an example of a Sizer.
type CacheBytes []byte

// Size implements Sizer.Size.
func (cb CacheBytes) Size() int64 {
	return int64(len(cb))
}

// Cache is a worker cache.
//
// Caches can be created via NewCache.
type Cache struct {
	name    string
	entries map[string]Sizer
	size    int64
	hits    int64
	misses  int64
}

// NewCache returns a new cache.
func NewCache(name string) *Cache {
	return &Cache{
		name: name,
	}
}
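// The following function is an editor-added sketch, not part of the original
// file: it illustrates how a tool built on this package might memoize an
// expensive computation keyed on its input files. The cache name "parse" and
// the names exampleParseCache/exampleCachedRead are hypothetical.
var exampleParseCache = NewCache("parse")

func exampleCachedRead(filename string) Sizer {
	return exampleParseCache.Lookup([]string{filename}, func() Sizer {
		// Invoked only on a cache miss; here we simply slurp the file.
		data, err := ioutil.ReadFile(filename)
		if err != nil {
			return CacheBytes(nil)
		}
		return CacheBytes(data)
	})
}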
// Lookup looks up an entry in the cache.
//
// It is a function of the given files.
func (c *Cache) Lookup(filenames []string, generate func() Sizer) Sizer {
	digests := make([]string, 0, len(filenames))
	for _, filename := range filenames {
		digest, ok := inputFiles[mustAbs(filename)]
		if !ok {
			// This is not a valid input. We may not be running as a
			// persistent worker in this case. If that's the case,
			// then the file's contents will not change across the
			// run, and we just use the filename itself.
			digest = filename
		}
		digests = append(digests, digest)
	}

	// Attempt the lookup.
	sort.Slice(digests, func(i, j int) bool {
		return digests[i] < digests[j]
	})
	cacheKey := strings.Join(digests, "+")
	if c.entries == nil {
		c.entries = make(map[string]Sizer)
		activeCaches[c] = struct{}{}
	}
	entry, ok := c.entries[cacheKey]
	if ok {
		c.hits++
		return entry
	}

	// Generate a new entry.
	entry = generate()
	c.misses++
	c.entries[cacheKey] = entry
	if entry != nil {
		sz := entry.Size()
		c.size += sz
		totalCacheUsage += sz
	}

	// Check the capacity of all caches. If it is greater than the maximum,
	// we flush everything but still return this entry.
	if totalCacheUsage > *maximumCacheUsage {
		for cache := range activeCaches {
			// Drop all entries.
			cache.size = 0
			cache.entries = nil
		}
		totalCacheUsage = 0 // Reset.
	}

	return entry
}

// allCacheStats returns stats for all caches.
func allCacheStats() string {
	var sb strings.Builder
	for cache := range activeCaches {
		ratio := float64(cache.hits) / float64(cache.hits+cache.misses)
		fmt.Fprintf(&sb,
			"% 10s: count: % 5d size: % 10d hits: % 7d misses: % 7d ratio: %2.2f\n",
			cache.name, len(cache.entries), cache.size, cache.hits, cache.misses, ratio)
	}
	if len(activeCaches) > 0 {
		fmt.Fprintf(&sb, "total: % 10d\n", totalCacheUsage)
	}
	return sb.String()
}

// LookupDigest returns a digest for the given file.
func LookupDigest(filename string) (string, bool) {
	digest, ok := inputFiles[filename]
	return digest, ok
}
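// Editor-added usage sketch, not part of the original file. A tool typically
// hands its entry point to Work; the import path below is assumed from this
// file's location:
//
//	package main
//
//	import "github.com/SagerNet/gvisor/tools/worker"
//
//	func main() {
//		worker.Work(func(args []string) int {
//			// Tool logic; may be invoked many times in one
//			// persistent worker process.
//			return 0
//		})
//	}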
// Work invokes the main function.
func Work(run func([]string) int) {
	flag.CommandLine.Parse(os.Args[1:])
	if !*persistentWorker {
		// Handle the argument file.
		args := flag.CommandLine.Args()
		if len(args) == 1 && len(args[0]) > 1 && args[0][0] == '@' {
			content, err := ioutil.ReadFile(args[0][1:])
			if err != nil {
				log.Fatalf("unable to parse args file: %v", err)
			}
			// Pull arguments from the file.
			args = strings.Split(string(content), "\n")
			flag.CommandLine.Parse(args)
			args = flag.CommandLine.Args()
		}
		os.Exit(run(args))
	}

	var listenHeader string // Emitted always.
	if *workerDebug {
		// Bind a server for profiling.
		listener, err := net.Listen("tcp", "localhost:0")
		if err != nil {
			log.Fatalf("unable to bind a server: %v", err)
		}
		// Construct the header for stats output, below.
		listenHeader = fmt.Sprintf("Listening @ http://localhost:%d\n", listener.Addr().(*net.TCPAddr).Port)
		go http.Serve(listener, nil)
	}

	// Move stdout. This is done to prevent anything else from accidentally
	// printing to stdout, which must contain only the valid WorkResponse
	// serialized protos.
	newOutput, err := unix.Dup(1)
	if err != nil {
		log.Fatalf("unable to move stdout: %v", err)
	}
	// Stderr may be closed or may be a copy of stdout. We make sure that
	// we have an output that is in a completely separate range.
	for newOutput <= 2 {
		newOutput, err = unix.Dup(newOutput)
		if err != nil {
			log.Fatalf("unable to move stdout: %v", err)
		}
	}

	// Best-effort: collect logs.
	rPipe, wPipe, err := os.Pipe()
	if err != nil {
		log.Fatalf("unable to create pipe: %v", err)
	}
	if err := unix.Dup2(int(wPipe.Fd()), 1); err != nil {
		log.Fatalf("error duping over stdout: %v", err)
	}
	if err := unix.Dup2(int(wPipe.Fd()), 2); err != nil {
		log.Fatalf("error duping over stderr: %v", err)
	}
	wPipe.Close()
	defer rPipe.Close()

	// Read requests from stdin.
	input := bufio.NewReader(os.NewFile(0, "input"))
	output := bufio.NewWriter(os.NewFile(uintptr(newOutput), "output"))
	for {
		szBuf, err := input.Peek(4)
		if err != nil {
			log.Fatalf("unable to read header: %v", err)
		}

		// Parse the size, and discard bits.
		sz, szBytes := protowire.ConsumeVarint(szBuf)
		if szBytes < 0 {
			szBytes = 0
		}
		if _, err := input.Discard(szBytes); err != nil {
			log.Fatalf("error discarding size: %v", err)
		}

		// Read a full message.
		msg := make([]byte, int(sz))
		if _, err := io.ReadFull(input, msg); err != nil {
			log.Fatalf("error reading worker request: %v", err)
		}
		var wreq wpb.WorkRequest
		if err := proto.Unmarshal(msg, &wreq); err != nil {
			log.Fatalf("error unmarshaling worker request: %v", err)
		}

		// Flush relevant caches.
		inputFiles = make(map[string]string)
		for _, input := range wreq.GetInputs() {
			updateInputFile(input.GetPath(), string(input.GetDigest()))
		}

		// Prepare logging.
		outputBuffer := bytes.NewBuffer(nil)
		outputBuffer.WriteString(listenHeader)
		log.SetOutput(outputBuffer)

		// Parse all arguments.
		flag.CommandLine.Parse(wreq.GetArguments())
		var exitCode int
		exitChan := make(chan int)
		go func() { exitChan <- run(flag.CommandLine.Args()) }()
		for running := true; running; {
			select {
			case exitCode = <-exitChan:
				running = false
			default:
			}
			// N.B. rPipe is given a read deadline of 1ms. We expect the
			// copy to return an error after 1ms, and we just keep
			// flushing this buffer while the task is running.
			rPipe.SetReadDeadline(time.Now().Add(time.Millisecond))
			outputBuffer.ReadFrom(rPipe)
		}

		if *workerDebug {
			// Attach all cache stats.
			outputBuffer.WriteString(allCacheStats())
		}

		// Send the response.
		var wresp wpb.WorkResponse
		wresp.ExitCode = int32(exitCode)
		wresp.Output = outputBuffer.String()
		rmsg, err := proto.Marshal(&wresp)
		if err != nil {
			log.Fatalf("error marshaling response: %v", err)
		}
		if _, err := output.Write(append(protowire.AppendVarint(nil, uint64(len(rmsg))), rmsg...)); err != nil {
			log.Fatalf("error sending worker response: %v", err)
		}
		if err := output.Flush(); err != nil {
			log.Fatalf("error flushing output: %v", err)
		}
	}
}
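// frameExampleResponse is an editor-added sketch, not part of the original
// file. It makes the wire format explicit: each WorkResponse is written as a
// protobuf varint length prefix followed by the serialized message, exactly
// as the response loop in Work emits on stdout.
func frameExampleResponse(exitCode int32, output string) ([]byte, error) {
	var wresp wpb.WorkResponse
	wresp.ExitCode = exitCode
	wresp.Output = output
	msg, err := proto.Marshal(&wresp)
	if err != nil {
		return nil, err
	}
	// Length prefix first, then the message bytes.
	return append(protowire.AppendVarint(nil, uint64(len(msg))), msg...), nil
}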