github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/cmd/runner/runner_test.go (about) 1 // Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License. 2 3 package main 4 5 // This file contains the implementation of tests related to starting python based work and 6 // running it to completion within the server. Work run is prepackaged within the source 7 // code repository and orchestrated by the testing within this file. 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "crypto" 14 "crypto/rand" 15 "encoding/base64" 16 "encoding/json" 17 "fmt" 18 "html/template" 19 "io" 20 "io/ioutil" 21 "net/http" 22 "net/url" 23 "os" 24 "path" 25 "path/filepath" 26 "regexp" 27 "sort" 28 "strconv" 29 "strings" 30 "testing" 31 "time" 32 33 "github.com/leaf-ai/studio-go-runner/internal/runner" 34 35 runnerReports "github.com/leaf-ai/studio-go-runner/internal/gen/dev.cognizant_dev.ai/genproto/studio-go-runner/reports/v1" 36 37 "google.golang.org/protobuf/encoding/prototext" 38 39 "golang.org/x/crypto/ed25519" 40 "golang.org/x/crypto/ssh" 41 42 "github.com/davecgh/go-spew/spew" 43 "github.com/go-stack/stack" 44 "github.com/jjeffery/kv" // MIT License 45 46 minio "github.com/minio/minio-go" 47 48 "github.com/mholt/archiver" 49 model "github.com/prometheus/client_model/go" 50 "github.com/rs/xid" 51 52 "github.com/makasim/amqpextra" 53 "github.com/streadway/amqp" 54 ) 55 56 var ( 57 // Extracts three floating point values from a tensorflow output line typical for the experiments 58 // found within the tf packages. 
A typical log line will appear as follows 59 // '60000/60000 [==============================] - 1s 23us/step - loss: 0.2432 - acc: 0.9313 - val_loss: 0.2316 - val_acc: 0.9355' 60 tfExtract = regexp.MustCompile(`(?mU).*loss:\s([-+]?[0-9]*\.[0-9]*)\s.*acc:\s([-+]?[0-9]*\.[0-9]*)\s.*val_loss:\s([-+]?[0-9]*\.[0-9]*)\s.*val_acc:\s([-+]?[0-9]*\.[0-9]*)$`) 61 ) 62 63 func TestATFExtractilargeon(t *testing.T) { 64 tfResultsExample := `60000/60000 [==============================] - 1s 23us/step - loss: 0.2432 - acc: 0.9313 - val_loss: 0.2316 - val_acc: 0.9355` 65 66 expectedOutput := []string{ 67 tfResultsExample, 68 "0.2432", 69 "0.9313", 70 "0.2316", 71 "0.9355", 72 } 73 74 matches := tfExtract.FindAllStringSubmatch(tfResultsExample, -1) 75 for i, match := range expectedOutput { 76 if matches[0][i] != match { 77 t.Fatal(kv.NewError("a tensorflow result not extracted").With("expected", match).With("captured_match", matches[0][i]).With("stack", stack.Trace().TrimRuntime())) 78 } 79 } 80 } 81 82 type ExperData struct { 83 RabbitMQUser string 84 RabbitMQPassword string 85 Bucket string 86 MinioAddress string 87 MinioUser string 88 MinioPassword string 89 GPUs []runner.GPUTrack 90 GPUSlots int 91 } 92 93 // downloadFile will download a url to a local file using streaming. 
94 // 95 func downloadFile(fn string, download string) (err kv.Error) { 96 97 // Create the file 98 out, errGo := os.Create(fn) 99 if errGo != nil { 100 return kv.Wrap(errGo).With("url", download).With("filename", fn).With("stack", stack.Trace().TrimRuntime()) 101 } 102 defer out.Close() 103 104 // Get the data 105 resp, errGo := http.Get(download) 106 if errGo != nil { 107 return kv.Wrap(errGo).With("url", download).With("filename", fn).With("stack", stack.Trace().TrimRuntime()) 108 } 109 defer resp.Body.Close() 110 111 // Write the body to file 112 _, errGo = io.Copy(out, resp.Body) 113 if errGo != nil { 114 return kv.Wrap(errGo).With("url", download).With("filename", fn).With("stack", stack.Trace().TrimRuntime()) 115 } 116 117 return nil 118 } 119 120 func downloadRMQCli(fn string) (err kv.Error) { 121 if err = downloadFile(fn, os.ExpandEnv("http://${RABBITMQ_SERVICE_SERVICE_HOST}:${RABBITMQ_SERVICE_SERVICE_PORT_RMQ_ADMIN}/cli/rabbitmqadmin")); err != nil { 122 return err 123 } 124 // Having downloaded the administration CLI tool set it to be executable 125 if errGo := os.Chmod(fn, 0777); errGo != nil { 126 return kv.Wrap(errGo).With("filename", fn).With("stack", stack.Trace().TrimRuntime()) 127 } 128 return nil 129 } 130 131 // setupRMQ will download the rabbitMQ administration tool from the k8s deployed rabbitMQ 132 // server and place it into the project bin directory setting it to executable in order 133 // that diagnostic commands can be run using the shell 134 // 135 func setupRMQAdmin() (err kv.Error) { 136 rmqAdmin := path.Join("/project", "bin") 137 fi, errGo := os.Stat(rmqAdmin) 138 if errGo != nil { 139 return kv.Wrap(errGo).With("dir", rmqAdmin).With("stack", stack.Trace().TrimRuntime()) 140 } 141 if !fi.IsDir() { 142 return kv.NewError("specified directory is not actually a directory").With("dir", rmqAdmin).With("stack", stack.Trace().TrimRuntime()) 143 } 144 145 // Look for the rabbitMQ Server and download the command line tools for use 146 // in 
diagnosing issues, and do this before changing into the test directory 147 rmqAdmin = filepath.Join(rmqAdmin, "rabbitmqadmin") 148 return downloadRMQCli(rmqAdmin) 149 } 150 151 func collectUploadFiles(dir string) (files []string, err kv.Error) { 152 153 errGo := filepath.Walk(".", 154 func(path string, info os.FileInfo, err error) error { 155 files = append(files, path) 156 return nil 157 }) 158 159 if errGo != nil { 160 return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 161 } 162 sort.Strings(files) 163 164 return files, nil 165 } 166 167 func uploadWorkspace(experiment *ExperData) (err kv.Error) { 168 169 wd, _ := os.Getwd() 170 logger.Debug("uploading", "dir", wd, "experiment", *experiment, "stack", stack.Trace().TrimRuntime()) 171 172 dir := "." 173 files, err := collectUploadFiles(dir) 174 if err != nil { 175 return err 176 } 177 if len(files) == 0 { 178 return kv.NewError("no files found").With("directory", dir).With("stack", stack.Trace().TrimRuntime()) 179 } 180 181 // Pack the files needed into an archive within a temporary directory 182 dir, errGo := ioutil.TempDir("", xid.New().String()) 183 if errGo != nil { 184 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 185 } 186 defer os.RemoveAll(dir) 187 188 archiveName := filepath.Join(dir, "workspace.tar") 189 190 if errGo = archiver.Tar.Make(archiveName, files); errGo != nil { 191 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 192 } 193 194 // Now we have the workspace for upload go ahead and contact the minio server 195 mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false) 196 if errGo != nil { 197 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 198 } 199 200 archive, errGo := os.Open(archiveName) 201 if errGo != nil { 202 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 203 } 204 defer archive.Close() 205 206 fileStat, errGo := archive.Stat() 207 if errGo != nil { 208 
return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 209 } 210 211 // Create the bucket that will be used by the experiment, and then place the workspace into it 212 if errGo = mc.MakeBucket(experiment.Bucket, ""); errGo != nil { 213 switch minio.ToErrorResponse(errGo).Code { 214 case "BucketAlreadyExists": 215 case "BucketAlreadyOwnedByYou": 216 default: 217 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 218 } 219 } 220 221 _, errGo = mc.PutObject(experiment.Bucket, "workspace.tar", archive, fileStat.Size(), 222 minio.PutObjectOptions{ 223 ContentType: "application/octet-stream", 224 }) 225 if errGo != nil { 226 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 227 } 228 return nil 229 } 230 231 func validateTFMinimal(ctx context.Context, experiment *ExperData) (err kv.Error) { 232 // Unpack the output archive within a temporary directory and use it for validation 233 dir, errGo := ioutil.TempDir("", xid.New().String()) 234 if errGo != nil { 235 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 236 } 237 defer os.RemoveAll(dir) 238 239 output := filepath.Join(dir, "output.tar") 240 if err = downloadOutput(ctx, experiment, output); err != nil { 241 return err 242 } 243 244 // Now examine the file for successfully running the python code 245 if errGo = archiver.Tar.Open(output, dir); errGo != nil { 246 return kv.Wrap(errGo).With("file", output).With("stack", stack.Trace().TrimRuntime()) 247 } 248 249 outFn := filepath.Join(dir, "output") 250 outFile, errGo := os.Open(outFn) 251 if errGo != nil { 252 return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime()) 253 } 254 255 supressDump := false 256 defer func() { 257 if !supressDump { 258 io.Copy(os.Stdout, outFile) 259 } 260 outFile.Close() 261 }() 262 263 // Typical values for these items inside the TF logging are as follows 264 // "loss: 0.2432 - acc: 0.9313 - val_loss: 0.2316 - val_acc: 0.9355" 265 acceptableVals := []float64{ 266 
0.35, 267 0.85, 268 0.35, 269 0.85, 270 } 271 272 matches := [][]string{} 273 scanner := bufio.NewScanner(outFile) 274 for scanner.Scan() { 275 matched := tfExtract.FindAllStringSubmatch(scanner.Text(), -1) 276 if len(matched) != 1 { 277 continue 278 } 279 if len(matched[0]) != 5 { 280 continue 281 } 282 matches = matched 283 } 284 if errGo = scanner.Err(); errGo != nil { 285 return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime()) 286 } 287 288 if len(matches) != 1 { 289 outFile.Seek(0, io.SeekStart) 290 io.Copy(os.Stdout, outFile) 291 return kv.NewError("unable to find any TF results in the log file").With("file", outFn).With("stack", stack.Trace().TrimRuntime()) 292 } 293 294 // Although the following values are not using epsilon style float adjustments because 295 // the test limits and values are abitrary anyway 296 297 // loss andf accuracy checks against the log data that was extracted using a regular expression 298 // and captures 299 loss, errGo := strconv.ParseFloat(matches[0][1], 64) 300 if errGo != nil { 301 return kv.Wrap(errGo).With("file", outFn).With("line", scanner.Text()).With("value", matches[0][1]).With("stack", stack.Trace().TrimRuntime()) 302 } 303 if loss > acceptableVals[1] { 304 return kv.NewError("loss is too large").With("file", outFn).With("line", scanner.Text()).With("value", loss).With("ceiling", acceptableVals[1]).With("stack", stack.Trace().TrimRuntime()) 305 } 306 loss, errGo = strconv.ParseFloat(matches[0][3], 64) 307 if errGo != nil { 308 return kv.Wrap(errGo).With("file", outFn).With("value", matches[0][3]).With("line", scanner.Text()).With("stack", stack.Trace().TrimRuntime()) 309 } 310 if loss > acceptableVals[3] { 311 return kv.NewError("validation loss is too large").With("file", outFn).With("line", scanner.Text()).With("value", loss).With("ceiling", acceptableVals[3]).With("stack", stack.Trace().TrimRuntime()) 312 } 313 // accuracy checks 314 accu, errGo := strconv.ParseFloat(matches[0][2], 64) 
315 if errGo != nil { 316 return kv.Wrap(errGo).With("file", outFn).With("value", matches[0][2]).With("line", scanner.Text()).With("stack", stack.Trace().TrimRuntime()) 317 } 318 if accu < acceptableVals[2] { 319 return kv.NewError("accuracy is too small").With("file", outFn).With("line", scanner.Text()).With("value", accu).With("ceiling", acceptableVals[2]).With("stack", stack.Trace().TrimRuntime()) 320 } 321 accu, errGo = strconv.ParseFloat(matches[0][4], 64) 322 if errGo != nil { 323 return kv.Wrap(errGo).With("file", outFn).With("value", matches[0][4]).With("line", scanner.Text()).With("stack", stack.Trace().TrimRuntime()) 324 } 325 if accu < acceptableVals[3] { 326 return kv.NewError("validation accuracy is too small").With("file", outFn).With("line", scanner.Text()).With("value", accu).With("ceiling", acceptableVals[3]).With("stack", stack.Trace().TrimRuntime()) 327 } 328 329 logger.Info(matches[0][0], "stack", stack.Trace().TrimRuntime()) 330 supressDump = true 331 332 return nil 333 } 334 335 func lsMetadata(ctx context.Context, experiment *ExperData) (names []string, err kv.Error) { 336 names = []string{} 337 338 // Now we have the workspace for upload go ahead and contact the minio server 339 mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false) 340 if errGo != nil { 341 return names, kv.Wrap(errGo).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime()) 342 } 343 // Create a done channel to control 'ListObjects' go routine. 344 doneCh := make(chan struct{}) 345 346 // Indicate to our routine to exit cleanly upon return. 
347 defer close(doneCh) 348 349 isRecursive := true 350 prefix := "metadata/" 351 objectCh := mc.ListObjects(experiment.Bucket, prefix, isRecursive, doneCh) 352 for object := range objectCh { 353 if object.Err != nil { 354 return names, kv.Wrap(object.Err).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime()) 355 } 356 names = append(names, fmt.Sprint(object.Key)) 357 } 358 return names, nil 359 } 360 361 func downloadMetadata(ctx context.Context, experiment *ExperData, outputDir string) (err kv.Error) { 362 // Now we have the workspace for upload go ahead and contact the minio server 363 mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false) 364 if errGo != nil { 365 return kv.Wrap(errGo).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime()) 366 } 367 // Create a done channel to control 'ListObjects' go routine. 368 doneCh := make(chan struct{}) 369 370 // Indicate to our routine to exit cleanly upon return. 
371 defer close(doneCh) 372 373 names := []string{} 374 375 isRecursive := true 376 prefix := "metadata/" 377 objectCh := mc.ListObjects(experiment.Bucket, prefix, isRecursive, doneCh) 378 for object := range objectCh { 379 if object.Err != nil { 380 return kv.Wrap(object.Err).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime()) 381 } 382 names = append(names, filepath.Base(object.Key)) 383 } 384 385 for _, name := range names { 386 key := prefix + name 387 object, errGo := mc.GetObject(experiment.Bucket, key, minio.GetObjectOptions{}) 388 if errGo != nil { 389 return kv.Wrap(errGo).With("address", experiment.MinioAddress, "bucket", experiment.Bucket, "name", name).With("stack", stack.Trace().TrimRuntime()) 390 } 391 localName := filepath.Join(outputDir, filepath.Base(name)) 392 localFile, errGo := os.Create(localName) 393 if errGo != nil { 394 return kv.Wrap(errGo).With("address", experiment.MinioAddress, "bucket", experiment.Bucket, "key", key, "filename", localName).With("stack", stack.Trace().TrimRuntime()) 395 } 396 if _, errGo = io.Copy(localFile, object); errGo != nil { 397 return kv.Wrap(errGo).With("address", experiment.MinioAddress, "bucket", experiment.Bucket, "key", key, "filename", localName).With("stack", stack.Trace().TrimRuntime()) 398 } 399 } 400 return nil 401 } 402 403 func downloadOutput(ctx context.Context, experiment *ExperData, output string) (err kv.Error) { 404 405 archive, errGo := os.Create(output) 406 if errGo != nil { 407 return kv.Wrap(errGo).With("output", output).With("stack", stack.Trace().TrimRuntime()) 408 } 409 defer archive.Close() 410 411 // Now we have the workspace for upload go ahead and contact the minio server 412 mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false) 413 if errGo != nil { 414 return kv.Wrap(errGo).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime()) 415 } 416 417 object, errGo := 
mc.GetObjectWithContext(ctx, experiment.Bucket, "output.tar", minio.GetObjectOptions{}) 418 if errGo != nil { 419 return kv.Wrap(errGo).With("output", output).With("stack", stack.Trace().TrimRuntime()) 420 } 421 422 if _, errGo = io.Copy(archive, object); errGo != nil { 423 return kv.Wrap(errGo).With("output", output).With("stack", stack.Trace().TrimRuntime()) 424 } 425 426 return nil 427 } 428 429 type relocateTemp func() (err kv.Error) 430 431 type relocate struct { 432 Original string 433 Pop []relocateTemp 434 } 435 436 func (r *relocate) Close() (err kv.Error) { 437 if r == nil { 438 return nil 439 } 440 // Iterate the list of call backs in reverse order when exiting 441 // the stack of things that were done as a LIFO 442 for i := len(r.Pop) - 1; i >= 0; i-- { 443 if err = r.Pop[i](); err != nil { 444 return err 445 } 446 } 447 return nil 448 } 449 450 func relocateToTemp(dir string) (callback relocate, err kv.Error) { 451 452 wd, errGo := os.Getwd() 453 if errGo != nil { 454 return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 455 } 456 dir, errGo = filepath.Abs(dir) 457 if errGo != nil { 458 return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 459 } 460 461 if rel, _ := filepath.Rel(wd, dir); rel == "." 
{ 462 return callback, kv.NewError("the relocation directory is the same directory as the target").With("dir", dir).With("current_dir", wd).With("stack", stack.Trace().TrimRuntime()) 463 } 464 465 if errGo = os.Chdir(dir); errGo != nil { 466 return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 467 } 468 469 callback = relocate{ 470 Original: wd, 471 Pop: []relocateTemp{func() (err kv.Error) { 472 if errGo := os.Chdir(wd); errGo != nil { 473 return kv.Wrap(errGo).With("dir", wd).With("stack", stack.Trace().TrimRuntime()) 474 } 475 return nil 476 }}, 477 } 478 479 return callback, nil 480 } 481 482 func relocateToTransitory() (callback relocate, err kv.Error) { 483 484 dir, errGo := ioutil.TempDir("", xid.New().String()) 485 if errGo != nil { 486 return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 487 } 488 489 if callback, err = relocateToTemp(dir); err != nil { 490 return callback, err 491 } 492 493 callback.Pop = append(callback.Pop, func() (err kv.Error) { 494 // Move to an intermediate directory to allow the RemoveAll to occur 495 if errGo := os.Chdir(os.TempDir()); errGo != nil { 496 return kv.Wrap(errGo, "unable to retreat from the directory being deleted").With("dir", dir).With("stack", stack.Trace().TrimRuntime()) 497 } 498 if errGo := os.RemoveAll(dir); errGo != nil { 499 return kv.Wrap(errGo, "unable to retreat from the directory being deleted").With("dir", dir).With("stack", stack.Trace().TrimRuntime()) 500 } 501 return nil 502 }) 503 504 return callback, nil 505 } 506 507 func TestRelocation(t *testing.T) { 508 509 // Keep a record of the directory where we are currently located 510 wd, errGo := os.Getwd() 511 if errGo != nil { 512 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 513 } 514 // Create a test directory 515 dir, errGo := ioutil.TempDir("", xid.New().String()) 516 if errGo != nil { 517 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 518 } 519 defer 
os.RemoveAll(dir) 520 521 func() { 522 // Relocate to our new directory and then use the construct of a function 523 // to pop back out of the test directory to ensure we are in the right location 524 reloc, err := relocateToTemp(dir) 525 if err != nil { 526 t.Fatal(err) 527 } 528 defer reloc.Close() 529 }() 530 531 // find out where we are and make sure it is where we expect 532 newWD, errGo := os.Getwd() 533 if errGo != nil { 534 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 535 } 536 if wd != newWD { 537 t.Fatal(kv.NewError("relocation could not be reversed").With("origin", wd).With("recovered_to", newWD).With("temp_dir", dir).With("stack", stack.Trace().TrimRuntime())) 538 } 539 } 540 541 func TestNewRelocation(t *testing.T) { 542 543 // Keep a record of the directory where we are currently located 544 wd, errGo := os.Getwd() 545 if errGo != nil { 546 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 547 } 548 549 // Working directory location that is generated by the functions under test 550 tmpDir := "" 551 552 func() { 553 // Relocate to a new directory which has had a temporary name generated on 554 // out behalf as a working area 555 reloc, err := relocateToTransitory() 556 if err != nil { 557 t.Fatal(err) 558 } 559 // Make sure we are sitting in another directory at this point and place a test 560 // file in it so that later we can check that is got cleared 561 tmpDir, errGo = os.Getwd() 562 fn := filepath.Join(tmpDir, "EmptyFile") 563 fl, errGo := os.Create(fn) 564 if errGo != nil { 565 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 566 } 567 msg := "test file that should be gathered up and deleted at the end of the Transitory dir testing" 568 if _, errGo = fl.WriteString(msg); errGo != nil { 569 t.Fatal(kv.Wrap(errGo).With("filename", fn).With("stack", stack.Trace().TrimRuntime())) 570 } 571 fl.Close() 572 573 defer reloc.Close() 574 }() 575 576 // find out where we are and make sure it is where 
we expect 577 newWD, errGo := os.Getwd() 578 if errGo != nil { 579 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 580 } 581 // Make sure this was not a NOP 582 if wd != newWD { 583 t.Fatal(kv.NewError("relocation could not be reversed").With("origin", wd).With("recovered_to", newWD).With("temp_dir", tmpDir).With("stack", stack.Trace().TrimRuntime())) 584 } 585 586 // Make sure our working directory was cleaned up 587 if _, errGo := os.Stat(tmpDir); !os.IsNotExist(errGo) { 588 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 589 } 590 } 591 592 // prepareExperiment reads an experiment template from the current working directory and 593 // then uses it to prepare the json payload that will be sent as a runner request 594 // data structure to a go runner 595 // 596 func prepareExperiment(gpus int, ignoreK8s bool) (experiment *ExperData, r *runner.Request, err kv.Error) { 597 if !ignoreK8s { 598 if err = setupRMQAdmin(); err != nil { 599 return nil, nil, err 600 } 601 } 602 603 // Parse from the rabbitMQ Settings the username and password that will be available to the templated 604 // request 605 rmqURL, errGo := url.Parse(os.ExpandEnv(*amqpURL)) 606 if errGo != nil { 607 return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 608 } 609 610 slots := 0 611 gpusToUse := []runner.GPUTrack{} 612 if gpus != 0 { 613 // Templates will also have access to details about the GPU cards, upto a max of three 614 // so we find the gpu cards and if found load their capacity and allocation data into the 615 // template data source. 
These are used for live testing so use any live cards from the runner 616 // 617 invent, err := runner.GPUInventory() 618 if err != nil { 619 return nil, nil, err 620 } 621 if len(invent) < gpus { 622 return nil, nil, kv.NewError("not enough gpu cards for a test").With("needed", gpus).With("actual", len(invent)).With("stack", stack.Trace().TrimRuntime()) 623 } 624 625 // slots will be the total number of slots needed to grab the number of cards specified 626 // by the caller 627 if gpus > 1 { 628 sort.Slice(invent, func(i, j int) bool { return invent[i].FreeSlots < invent[j].FreeSlots }) 629 630 // Get the largest n (gpus) cards that have free slots 631 for i := 0; i != len(invent); i++ { 632 if len(gpusToUse) >= gpus { 633 break 634 } 635 if invent[i].FreeSlots <= 0 || invent[i].EccFailure != nil { 636 continue 637 } 638 639 slots += int(invent[i].FreeSlots) 640 gpusToUse = append(gpusToUse, invent[i]) 641 } 642 if len(gpusToUse) < gpus { 643 return nil, nil, kv.NewError("not enough available gpu cards for a test").With("needed", gpus).With("actual", len(gpusToUse)).With("stack", stack.Trace().TrimRuntime()) 644 } 645 } 646 } 647 // Find as many cards as defined by the caller and include the slots needed to claim them which means 648 // we need the two largest cards to force multiple claims if needed. 
If the number desired is 1 or 0 649 // then we dont do anything as the experiment template will control what we get 650 651 // Place test files into the serving location for our minio server 652 pass, _ := rmqURL.User.Password() 653 experiment = &ExperData{ 654 RabbitMQUser: rmqURL.User.Username(), 655 RabbitMQPassword: pass, 656 Bucket: xid.New().String(), 657 MinioAddress: runner.MinioTest.Address, 658 MinioUser: runner.MinioTest.AccessKeyId, 659 MinioPassword: runner.MinioTest.SecretAccessKeyId, 660 GPUs: gpusToUse, 661 GPUSlots: slots, 662 } 663 664 // Read a template for the payload that will be sent to run the experiment 665 payload, errGo := ioutil.ReadFile("experiment_template.json") 666 if errGo != nil { 667 return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 668 } 669 tmpl, errGo := template.New("TestBasicRun").Parse(string(payload[:])) 670 if errGo != nil { 671 return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 672 } 673 output := &bytes.Buffer{} 674 if errGo = tmpl.Execute(output, experiment); errGo != nil { 675 return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 676 } 677 678 // Take the string template for the experiment and unmarshall it so that it can be 679 // updated with live test data 680 if r, err = runner.UnmarshalRequest(output.Bytes()); err != nil { 681 return nil, nil, err 682 } 683 684 // If we are not using gpus then purge out the GPU sections of the request template 685 if gpus == 0 { 686 r.Experiment.Resource.Gpus = 0 687 r.Experiment.Resource.GpuMem = "" 688 } 689 690 // Construct a json payload that uses the current wall clock time and also 691 // refers to a locally embedded minio server 692 r.Experiment.TimeAdded = float64(time.Now().Unix()) 693 r.Experiment.TimeLastCheckpoint = nil 694 695 return experiment, r, nil 696 } 697 698 // projectStats will take a collection of metrics, typically retrieved from a local prometheus 699 // source and scan these for details 
relating to a specific project and experiment 700 // 701 func projectStats(metrics map[string]*model.MetricFamily, qName string, qType string, project string, experiment string) (running int, finished int, err kv.Error) { 702 for family, metric := range metrics { 703 switch metric.GetType() { 704 case model.MetricType_GAUGE: 705 case model.MetricType_COUNTER: 706 default: 707 continue 708 } 709 if strings.HasPrefix(family, "runner_project_") { 710 err = func() (err kv.Error) { 711 vecs := metric.GetMetric() 712 for _, vec := range vecs { 713 func() { 714 for _, label := range vec.GetLabel() { 715 switch label.GetName() { 716 case "experiment": 717 if label.GetValue() != experiment && len(experiment) != 0 { 718 logger.Trace("mismatched", "experiment", experiment, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime()) 719 return 720 } 721 case "host": 722 if label.GetValue() != host { 723 logger.Trace("mismatched", "host", host, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime()) 724 return 725 } 726 case "project": 727 if label.GetValue() != project { 728 logger.Trace("mismatched", "project", project, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime()) 729 return 730 } 731 case "queue_type": 732 if label.GetValue() != qType { 733 logger.Trace("mismatched", "qType", qType, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime()) 734 return 735 } 736 case "queue_name": 737 if !strings.HasSuffix(label.GetValue(), qName) { 738 logger.Trace("mismatched", "qName", qName, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime()) 739 logger.Trace(spew.Sdump(vecs)) 740 return 741 } 742 default: 743 return 744 } 745 } 746 747 logger.Trace("matched prometheus metric", "family", family, "vec", fmt.Sprint(*vec), "stack", stack.Trace().TrimRuntime()) 748 749 // Based on the name of the gauge we will add together quantities, this 750 // is done because the experiment might have been left out 751 // of the inputs and the caller 
wanted a total for a project 752 switch family { 753 case "runner_project_running": 754 running += int(vec.GetGauge().GetValue()) 755 case "runner_project_completed": 756 finished += int(vec.GetCounter().GetValue()) 757 default: 758 logger.Info("unexpected", "family", family) 759 } 760 }() 761 } 762 return nil 763 }() 764 if err != nil { 765 return 0, 0, err 766 } 767 } 768 } 769 770 return running, finished, nil 771 } 772 773 type waitFunc func(ctx context.Context, qName string, queueType string, r *runner.Request, prometheusPort int) (err kv.Error) 774 775 // waitForRun will check for an experiment to run using the prometheus metrics to 776 // track the progress of the experiment on a regular basis 777 // 778 func waitForRun(ctx context.Context, qName string, queueType string, r *runner.Request, prometheusPort int) (err kv.Error) { 779 // Wait for prometheus to show the task as having been ran and completed 780 pClient := NewPrometheusClient(fmt.Sprintf("http://localhost:%d/metrics", prometheusPort)) 781 782 interval := time.Duration(0) 783 784 // Run around checking the prometheus counters for our experiment seeing when the internal 785 // project tracking says everything has completed, only then go out and get the experiment 786 // results 787 // 788 for { 789 select { 790 case <-time.After(interval): 791 metrics, err := pClient.Fetch("runner_project_") 792 if err != nil { 793 return err 794 } 795 796 runningCnt, finishedCnt, err := projectStats(metrics, qName, queueType, r.Config.Database.ProjectId, r.Experiment.Key) 797 if err != nil { 798 return err 799 } 800 801 // Wait for prometheus to show the task stopped for our specific queue, host, project and experiment ID 802 if runningCnt == 0 && finishedCnt == 1 { 803 return nil 804 } 805 interval = time.Duration(15 * time.Second) 806 } 807 } 808 } 809 810 func createResponseRMQ(qName string, encrypt bool) (err kv.Error) { 811 812 rmq, err := newRMQ(encrypt) 813 if err != nil { 814 return err 815 } 816 817 if err 
= rmq.QueueDeclare(qName); err != nil {
		return err
	}

	return nil
}

// deleteResponseRMQ removes the queue named qName from the RabbitMQ server identified by
// the amqpURL test flag. The queueType and routingKey parameters are accepted so that the
// signature mirrors publishToRMQ but they are not used.
func deleteResponseRMQ(qName string, queueType string, routingKey string) (err kv.Error) {
	rmq, err := newRMQ(false)
	if err != nil {
		return err
	}

	if err = rmq.QueueDestroy(qName); err != nil {
		return err
	}

	return nil
}

// newRMQ creates a RabbitMQ client using the amqpURL test flag after environment variable
// expansion. The URL must carry user credentials; they are removed from the URL and handed
// to runner.NewRabbitMQ separately.
//
// When encrypted is true a failure to obtain the message wrapper from getWrapper is
// returned to the caller. When encrypted is false that error is ignored and whatever
// wrapper getWrapper produced is passed through unchanged.
func newRMQ(encrypted bool) (rmq *runner.RabbitMQ, err kv.Error) {
	qURL, errGo := url.Parse(os.ExpandEnv(*amqpURL))
	if errGo != nil {
		return nil, kv.Wrap(errGo).With("url", *amqpURL).With("stack", stack.Trace().TrimRuntime())
	}
	if qURL.User == nil {
		return nil, kv.NewError("missing credentials in url").With("url", *amqpURL).With("stack", stack.Trace().TrimRuntime())
	}
	creds := qURL.User.String()

	w, err := getWrapper()
	if encrypted && err != nil {
		return nil, err
	}

	// The credentials are supplied separately so strip them from the URL
	qURL.User = nil
	return runner.NewRabbitMQ(qURL.String(), creds, w)
}

// marshallToRMQ converts the experiment request r into the payload that will be published
// on the queue named qName.
//
// For an unencrypted rmq the request is rendered as indented JSON. For an encrypted rmq a
// fresh ed25519 key pair is generated and the public half is written, in SSH authorized key
// format, into the signatures directory using qName as the file name. The function then
// waits for the signatures package to report a refresh. After that the request is placed
// into an envelope by the Kubernetes message wrapper and the envelope payload is signed with
// the private key. The envelope, with its key fingerprint and signature filled in, is
// returned as indented JSON.
func marshallToRMQ(rmq *runner.RabbitMQ, qName string, r *runner.Request) (b []byte, err kv.Error) {
	if rmq == nil {
		return nil, kv.NewError("rmq uninitialized").With("stack", stack.Trace().TrimRuntime())
	}

	if !rmq.IsEncrypted() {
		buf, errGo := json.MarshalIndent(r, "", " ")
		if errGo != nil {
			return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
		}
		return buf, nil
	}

	// To sign a message use a generated signing key pair whose public half is published
	// via the signatures directory
	sigs := runner.GetSignatures()
	sigDir := sigs.Dir()

	if len(sigDir) == 0 {
		return nil, kv.NewError("signatures directory not ready").With("stack", stack.Trace().TrimRuntime())
	}

	pubKey, prvKey, errGo := ed25519.GenerateKey(rand.Reader)
	if errGo != nil {
		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
	}
	sshKey, errGo := ssh.NewPublicKey(pubKey)
	if errGo != nil {
		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
	}

	// Write the public key
	keyFile := filepath.Join(sigDir, qName)
	if errGo = ioutil.WriteFile(keyFile, ssh.MarshalAuthorizedKey(sshKey), 0600); errGo != nil {
		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
	}

	// Now wait for the signature package to signal that the keys
	// have been refreshed and our new file was there
	<-runner.GetSignaturesRefresh().Done()

	w, err := runner.KubernetesWrapper(*msgEncryptDirOpt)
	if err != nil {
		if runner.IsAliveK8s() != nil {
			return nil, err
		}
	}
	// A wrapper failure is tolerated above while kubernetes is alive. Never continue
	// without a usable wrapper though, which would otherwise surface as a nil dereference
	// rather than as the underlying error.
	if w == nil {
		if err != nil {
			return nil, err
		}
		return nil, kv.NewError("message wrapper unavailable").With("stack", stack.Trace().TrimRuntime())
	}

	envelope, err := w.Envelope(r)
	if err != nil {
		return nil, err
	}

	envelope.Message.Fingerprint = ssh.FingerprintSHA256(sshKey)

	sig, errGo := prvKey.Sign(rand.Reader, []byte(envelope.Message.Payload), crypto.Hash(0))
	if errGo != nil {
		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
	}
	logger.Debug("signing produced", "sig", spew.Sdump(sig))
	// The raw ed25519 signature is carried in the envelope as standard base64 text
	envelope.Message.Signature = base64.StdEncoding.EncodeToString(sig)

	if b, errGo = json.MarshalIndent(envelope, "", " "); errGo != nil {
		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
	}
	return b, nil
}

// publishToRMQ will marshall a go structure containing experiment parameters and
// environment information and then send it to the rabbitMQ server this server is configured
// to listen to.
//
// The queue qName is declared before publishing and the message is published using
// routingKey. queueType is currently unused. Any failure while preparing the payload is
// returned and nothing is published.
func publishToRMQ(qName string, queueType string, routingKey string, r *runner.Request, encrypted bool) (err kv.Error) {
	rmq, err := newRMQ(encrypted)
	if err != nil {
		return err
	}

	if err = rmq.QueueDeclare(qName); err != nil {
		return err
	}

	b, err := marshallToRMQ(rmq, qName, r)
	if err != nil {
		// Do not publish an empty or partial payload if marshalling failed
		return err
	}

	// Send the payload to rabbitMQ
	return rmq.Publish(routingKey, "application/json", b)
}

// watchResponseQueue configures an amqpextra consumer on the queue qName. Each message body
// is decoded as a protobuf text format runnerReports.Report and offered on the returned
// channel.
//
// A report that no reader accepts within 5 seconds is acknowledged and dropped. The returned
// channel is never closed, so readers should stop when ctx is done.
//
// NOTE(review): no Run/Start method is invoked on the consumer here. Confirm that the
// amqpextra version in use begins consuming without one, otherwise no reports are delivered.
func watchResponseQueue(ctx context.Context, qName string, encrypted bool) (msgQ chan *runnerReports.Report, err kv.Error) {
	deliveryC := make(chan *runnerReports.Report)

	rmq, err := newRMQ(encrypted)
	if err != nil {
		return nil, err
	}

	conn := amqpextra.Dial([]string{rmq.URL() + "%2f"})
	consumer := conn.Consumer(
		qName,
		amqpextra.WorkerFunc(func(ctx context.Context, msg amqp.Delivery) interface{} {
			// The report travels in the message body; ContentEncoding is only a MIME
			// header describing the body and does not hold the report itself
			report := &runnerReports.Report{}
			if errGo := prototext.Unmarshal(msg.Body, report); errGo != nil {
				return errGo
			}

			// Pass the pointer so that the protobuf message is not copied
			logger.Info("report received", "report", spew.Sdump(report))

			select {
			case deliveryC <- report:
			case <-time.After(5 * time.Second):
				msg.Ack(false)
				return nil
			}

			msg.Ack(true)

			return nil
		}),
	)
	consumer.SetWorkerNum(1)
	consumer.SetContext(ctx)

	return deliveryC, nil
}

// pullReports drains msgC, discarding every report. It returns when ctx is done or when
// a nil report is received.
func pullReports(ctx context.Context, msgC <-chan *runnerReports.Report) {
	for {
		select {
		case msg := <-msgC:
			if msg == nil {
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

// validationFunc inspects the results of a completed experiment and returns an error
// if they are not as expected.
type validationFunc func(ctx context.Context, experiment *ExperData) (err kv.Error)

// runStudioTest will run a python based experiment and will then present the result to
// a caller supplied validation function.
//
// Kubernetes liveness is checked first unless ignoreK8s is set. The minio test server must
// report itself alive within one minute. The test then relocates to a temporary copy of
// workDir and prepares and uploads the experiment. It creates and watches a response queue
// and publishes the request, encrypted when useEncryption is set, to a freshly named queue.
// Finally it blocks in waiter and then hands the experiment to validation.
func runStudioTest(ctx context.Context, workDir string, gpus int, ignoreK8s bool, useEncryption bool, waiter waitFunc, validation validationFunc) (err kv.Error) {

	if !ignoreK8s {
		if err = runner.IsAliveK8s(); err != nil {
			return err
		}
	}

	timeoutAlive, aliveCancel := context.WithTimeout(ctx, time.Minute)
	defer aliveCancel()

	// Check that the minio local server has initialized before continuing
	if alive, err := runner.MinioTest.IsAlive(timeoutAlive); !alive || err != nil {
		if err != nil {
			return err
		}
		return kv.NewError("The minio test server is not available to run this test").With("stack", stack.Trace().TrimRuntime())
	}
	logger.Debug("alive checked", "addr", runner.MinioTest.Address)

	returnToWD, err := relocateToTemp(workDir)
	if err != nil {
		return err
	}
	defer returnToWD.Close()

	logger.Debug("test relocated", "workDir", workDir)

	experiment, r, err := prepareExperiment(gpus, ignoreK8s)
	if err != nil {
		return err
	}

	logger.Debug("experiment prepared")

	// Having constructed the payload identify the files within the test template
	// directory and save them into a workspace tar archive then
	// generate a tar file of the entire workspace directory and upload
	// to the minio server that the runner will pull from
	if err = uploadWorkspace(experiment); err != nil {
		return err
	}

	logger.Debug("experiment uploaded")

	// Cleanup the bucket only after the validation function that was supplied has finished
	defer runner.MinioTest.RemoveBucketAll(experiment.Bucket)

	// Generate queue names that will be used for this test case
	queueType := "rmq"
	qName := queueType + "_Multipart_" + xid.New().String()
	routingKey := "StudioML." + qName
	responseQName := qName + "_response"

	// Create and listen to the response queue which will receive messages
	// from the worker
	if err = createResponseRMQ(responseQName, useEncryption); err != nil {
		return err
	}
	defer deleteResponseRMQ(responseQName, queueType, routingKey)

	// Derive the response context from the caller's context so that cancelling the test
	// also tears down the response consumer and the report drain goroutine
	responseCtx, cancelResponse := context.WithCancel(ctx)
	defer cancelResponse()

	msgC, err := watchResponseQueue(responseCtx, responseQName, useEncryption)
	if err != nil {
		return err
	}

	go pullReports(responseCtx, msgC)

	logger.Debug("test initiated", "queue", qName, "stack", stack.Trace().TrimRuntime())

	// Now that the file needed is present on the minio server send the
	// experiment specification message to the worker using a new queue
	if err = publishToRMQ(qName, queueType, routingKey, r, useEncryption); err != nil {
		return err
	}

	logger.Debug("test waiting", "queue", qName, "stack", stack.Trace().TrimRuntime())

	if err = waiter(ctx, qName, queueType, r, prometheusPort); err != nil {
		return err
	}

	// Query minio for the resulting output and compare it with the expected
	return validation(ctx, experiment)
}

// TestÄE2ECPUExperimentRun is used to exercise the core ability of the runner to successfully
// complete a single experiment. The name of the test uses a Latin A with Diaeresis to order this
// test after others that are simpler in nature.
1101 // 1102 // This test take a minute or two but is left to run in the short version of testing because 1103 // it exercises the entire system under test end to end for experiments running in the python 1104 // environment 1105 // 1106 func TestÄE2ECPUExperimentRun(t *testing.T) { 1107 E2EExperimentRun(t, 0) 1108 } 1109 1110 func TestÄE2EGPUExperimentRun(t *testing.T) { 1111 if !*runner.UseGPU { 1112 logger.Warn("TestÄE2EExperimentRun not run") 1113 t.Skip("GPUs disabled for testing") 1114 } 1115 E2EExperimentRun(t, 1) 1116 1117 } 1118 1119 func E2EExperimentRun(t *testing.T, gpusNeeded int) { 1120 1121 if !*useK8s { 1122 t.Skip("kubernetes specific testing disabled") 1123 } 1124 1125 gpuCount := runner.GPUCount() 1126 if gpusNeeded > gpuCount { 1127 t.Skipf("insufficient GPUs %d, needed %d", gpuCount, gpusNeeded) 1128 } 1129 1130 cases := []struct { 1131 useEncrypt bool 1132 }{ 1133 {useEncrypt: true}, 1134 {useEncrypt: false}, 1135 } 1136 1137 for _, aCase := range cases { 1138 wd, errGo := os.Getwd() 1139 if errGo != nil { 1140 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 1141 } 1142 // Navigate to the assets directory being used for this experiment 1143 workDir, errGo := filepath.Abs(filepath.Join(wd, "..", "..", "assets", "tf_minimal")) 1144 if errGo != nil { 1145 t.Fatal(errGo) 1146 } 1147 1148 if err := runStudioTest(context.Background(), workDir, gpusNeeded, false, aCase.useEncrypt, waitForRun, validateTFMinimal); err != nil { 1149 t.Fatal(err) 1150 } 1151 1152 // Make sure we returned to the directory we expected 1153 newWD, errGo := os.Getwd() 1154 if errGo != nil { 1155 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 1156 } 1157 if newWD != wd { 1158 t.Fatal(kv.NewError("finished in an unexpected directory").With("expected_dir", wd).With("actual_dir", newWD).With("stack", stack.Trace().TrimRuntime())) 1159 } 1160 } 1161 } 1162 1163 func validatePytorchMultiGPU(ctx context.Context, experiment *ExperData) (err 
kv.Error) { 1164 // Unpack the output archive within a temporary directory and use it for validation 1165 dir, errGo := ioutil.TempDir("", xid.New().String()) 1166 if errGo != nil { 1167 return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()) 1168 } 1169 defer os.RemoveAll(dir) 1170 1171 output := filepath.Join(dir, "output.tar") 1172 if err = downloadOutput(ctx, experiment, output); err != nil { 1173 return err 1174 } 1175 1176 // Now examine the file for successfully running the python code 1177 if errGo = archiver.Tar.Open(output, dir); errGo != nil { 1178 return kv.Wrap(errGo).With("file", output).With("stack", stack.Trace().TrimRuntime()) 1179 } 1180 1181 outFn := filepath.Join(dir, "output") 1182 outFile, errGo := os.Open(outFn) 1183 if errGo != nil { 1184 return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime()) 1185 } 1186 1187 supressDump := false 1188 defer func() { 1189 if !supressDump { 1190 io.Copy(os.Stdout, outFile) 1191 } 1192 outFile.Close() 1193 }() 1194 1195 validateString := fmt.Sprintf("(\"Let's use\", %dL, 'GPUs!')", len(experiment.GPUs)) 1196 err = kv.NewError("multiple gpu logging not found").With("log", validateString).With("stack", stack.Trace().TrimRuntime()) 1197 1198 scanner := bufio.NewScanner(outFile) 1199 for scanner.Scan() { 1200 if strings.Contains(scanner.Text(), validateString) { 1201 supressDump = true 1202 err = nil 1203 break 1204 } 1205 } 1206 if errGo = scanner.Err(); errGo != nil { 1207 return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime()) 1208 } 1209 1210 return err 1211 } 1212 1213 // TestÄE2EPytorchMGPURun is a function used to exercise the multi GPU ability of the runner to successfully 1214 // complete a single pytorch multi GPU experiment. The name of the test uses a Latin A with Diaresis to order this 1215 // test after others that are simpler in nature. 
1216 // 1217 // This test take a minute or two but is left to run in the short version of testing because 1218 // it exercises the entire system under test end to end for experiments running in the python 1219 // environment 1220 // 1221 func TestÄE2EPytorchMGPURun(t *testing.T) { 1222 1223 if !*useK8s { 1224 t.Skip("kubernetes specific testing disabled") 1225 } 1226 1227 if !*runner.UseGPU { 1228 logger.Warn("TestÄE2EPytorchMGPURun not run") 1229 t.Skip("GPUs disabled for testing") 1230 } 1231 1232 wd, errGo := os.Getwd() 1233 if errGo != nil { 1234 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 1235 } 1236 1237 gpusNeeded := 2 1238 gpuCount := runner.GPUCount() 1239 if gpusNeeded > gpuCount { 1240 t.Skipf("insufficient GPUs %d, needed %d", gpuCount, gpusNeeded) 1241 } 1242 1243 // Navigate to the assets directory being used for this experiment 1244 workDir, errGo := filepath.Abs(filepath.Join(wd, "..", "..", "assets", "pytorch_mgpu")) 1245 if errGo != nil { 1246 t.Fatal(errGo) 1247 } 1248 1249 if err := runStudioTest(context.Background(), workDir, 2, false, false, waitForRun, validatePytorchMultiGPU); err != nil { 1250 t.Fatal(err) 1251 } 1252 1253 // Make sure we returned to the directory we expected 1254 newWD, errGo := os.Getwd() 1255 if errGo != nil { 1256 t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())) 1257 } 1258 if newWD != wd { 1259 t.Fatal(kv.NewError("finished in an unexpected directory").With("expected_dir", wd).With("actual_dir", newWD).With("stack", stack.Trace().TrimRuntime())) 1260 } 1261 }