github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/cmd/runner/runner_test.go (about)

     1  // Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
     2  
     3  package main
     4  
     5  // This file contains the implementation of tests related to starting python based work and
     6  // running it to completion within the server.  Work run is prepackaged within the source
     7  // code repository and orchestrated by the testing within this file.
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"crypto"
    14  	"crypto/rand"
    15  	"encoding/base64"
    16  	"encoding/json"
    17  	"fmt"
    18  	"html/template"
    19  	"io"
    20  	"io/ioutil"
    21  	"net/http"
    22  	"net/url"
    23  	"os"
    24  	"path"
    25  	"path/filepath"
    26  	"regexp"
    27  	"sort"
    28  	"strconv"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/leaf-ai/studio-go-runner/internal/runner"
    34  
    35  	runnerReports "github.com/leaf-ai/studio-go-runner/internal/gen/dev.cognizant_dev.ai/genproto/studio-go-runner/reports/v1"
    36  
    37  	"google.golang.org/protobuf/encoding/prototext"
    38  
    39  	"golang.org/x/crypto/ed25519"
    40  	"golang.org/x/crypto/ssh"
    41  
    42  	"github.com/davecgh/go-spew/spew"
    43  	"github.com/go-stack/stack"
    44  	"github.com/jjeffery/kv" // MIT License
    45  
    46  	minio "github.com/minio/minio-go"
    47  
    48  	"github.com/mholt/archiver"
    49  	model "github.com/prometheus/client_model/go"
    50  	"github.com/rs/xid"
    51  
    52  	"github.com/makasim/amqpextra"
    53  	"github.com/streadway/amqp"
    54  )
    55  
    56  var (
    57  	// Extracts three floating point values from a tensorflow output line typical for the experiments
    58  	// found within the tf packages.  A typical log line will appear as follows
    59  	// '60000/60000 [==============================] - 1s 23us/step - loss: 0.2432 - acc: 0.9313 - val_loss: 0.2316 - val_acc: 0.9355'
    60  	tfExtract = regexp.MustCompile(`(?mU).*loss:\s([-+]?[0-9]*\.[0-9]*)\s.*acc:\s([-+]?[0-9]*\.[0-9]*)\s.*val_loss:\s([-+]?[0-9]*\.[0-9]*)\s.*val_acc:\s([-+]?[0-9]*\.[0-9]*)$`)
    61  )
    62  
    63  func TestATFExtractilargeon(t *testing.T) {
    64  	tfResultsExample := `60000/60000 [==============================] - 1s 23us/step - loss: 0.2432 - acc: 0.9313 - val_loss: 0.2316 - val_acc: 0.9355`
    65  
    66  	expectedOutput := []string{
    67  		tfResultsExample,
    68  		"0.2432",
    69  		"0.9313",
    70  		"0.2316",
    71  		"0.9355",
    72  	}
    73  
    74  	matches := tfExtract.FindAllStringSubmatch(tfResultsExample, -1)
    75  	for i, match := range expectedOutput {
    76  		if matches[0][i] != match {
    77  			t.Fatal(kv.NewError("a tensorflow result not extracted").With("expected", match).With("captured_match", matches[0][i]).With("stack", stack.Trace().TrimRuntime()))
    78  		}
    79  	}
    80  }
    81  
    82  type ExperData struct {
    83  	RabbitMQUser     string
    84  	RabbitMQPassword string
    85  	Bucket           string
    86  	MinioAddress     string
    87  	MinioUser        string
    88  	MinioPassword    string
    89  	GPUs             []runner.GPUTrack
    90  	GPUSlots         int
    91  }
    92  
    93  // downloadFile will download a url to a local file using streaming.
    94  //
    95  func downloadFile(fn string, download string) (err kv.Error) {
    96  
    97  	// Create the file
    98  	out, errGo := os.Create(fn)
    99  	if errGo != nil {
   100  		return kv.Wrap(errGo).With("url", download).With("filename", fn).With("stack", stack.Trace().TrimRuntime())
   101  	}
   102  	defer out.Close()
   103  
   104  	// Get the data
   105  	resp, errGo := http.Get(download)
   106  	if errGo != nil {
   107  		return kv.Wrap(errGo).With("url", download).With("filename", fn).With("stack", stack.Trace().TrimRuntime())
   108  	}
   109  	defer resp.Body.Close()
   110  
   111  	// Write the body to file
   112  	_, errGo = io.Copy(out, resp.Body)
   113  	if errGo != nil {
   114  		return kv.Wrap(errGo).With("url", download).With("filename", fn).With("stack", stack.Trace().TrimRuntime())
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func downloadRMQCli(fn string) (err kv.Error) {
   121  	if err = downloadFile(fn, os.ExpandEnv("http://${RABBITMQ_SERVICE_SERVICE_HOST}:${RABBITMQ_SERVICE_SERVICE_PORT_RMQ_ADMIN}/cli/rabbitmqadmin")); err != nil {
   122  		return err
   123  	}
   124  	// Having downloaded the administration CLI tool set it to be executable
   125  	if errGo := os.Chmod(fn, 0777); errGo != nil {
   126  		return kv.Wrap(errGo).With("filename", fn).With("stack", stack.Trace().TrimRuntime())
   127  	}
   128  	return nil
   129  }
   130  
   131  // setupRMQ will download the rabbitMQ administration tool from the k8s deployed rabbitMQ
   132  // server and place it into the project bin directory setting it to executable in order
   133  // that diagnostic commands can be run using the shell
   134  //
   135  func setupRMQAdmin() (err kv.Error) {
   136  	rmqAdmin := path.Join("/project", "bin")
   137  	fi, errGo := os.Stat(rmqAdmin)
   138  	if errGo != nil {
   139  		return kv.Wrap(errGo).With("dir", rmqAdmin).With("stack", stack.Trace().TrimRuntime())
   140  	}
   141  	if !fi.IsDir() {
   142  		return kv.NewError("specified directory is not actually a directory").With("dir", rmqAdmin).With("stack", stack.Trace().TrimRuntime())
   143  	}
   144  
   145  	// Look for the rabbitMQ Server and download the command line tools for use
   146  	// in diagnosing issues, and do this before changing into the test directory
   147  	rmqAdmin = filepath.Join(rmqAdmin, "rabbitmqadmin")
   148  	return downloadRMQCli(rmqAdmin)
   149  }
   150  
   151  func collectUploadFiles(dir string) (files []string, err kv.Error) {
   152  
   153  	errGo := filepath.Walk(".",
   154  		func(path string, info os.FileInfo, err error) error {
   155  			files = append(files, path)
   156  			return nil
   157  		})
   158  
   159  	if errGo != nil {
   160  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   161  	}
   162  	sort.Strings(files)
   163  
   164  	return files, nil
   165  }
   166  
   167  func uploadWorkspace(experiment *ExperData) (err kv.Error) {
   168  
   169  	wd, _ := os.Getwd()
   170  	logger.Debug("uploading", "dir", wd, "experiment", *experiment, "stack", stack.Trace().TrimRuntime())
   171  
   172  	dir := "."
   173  	files, err := collectUploadFiles(dir)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if len(files) == 0 {
   178  		return kv.NewError("no files found").With("directory", dir).With("stack", stack.Trace().TrimRuntime())
   179  	}
   180  
   181  	// Pack the files needed into an archive within a temporary directory
   182  	dir, errGo := ioutil.TempDir("", xid.New().String())
   183  	if errGo != nil {
   184  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   185  	}
   186  	defer os.RemoveAll(dir)
   187  
   188  	archiveName := filepath.Join(dir, "workspace.tar")
   189  
   190  	if errGo = archiver.Tar.Make(archiveName, files); errGo != nil {
   191  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   192  	}
   193  
   194  	// Now we have the workspace for upload go ahead and contact the minio server
   195  	mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false)
   196  	if errGo != nil {
   197  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   198  	}
   199  
   200  	archive, errGo := os.Open(archiveName)
   201  	if errGo != nil {
   202  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   203  	}
   204  	defer archive.Close()
   205  
   206  	fileStat, errGo := archive.Stat()
   207  	if errGo != nil {
   208  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   209  	}
   210  
   211  	// Create the bucket that will be used by the experiment, and then place the workspace into it
   212  	if errGo = mc.MakeBucket(experiment.Bucket, ""); errGo != nil {
   213  		switch minio.ToErrorResponse(errGo).Code {
   214  		case "BucketAlreadyExists":
   215  		case "BucketAlreadyOwnedByYou":
   216  		default:
   217  			return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   218  		}
   219  	}
   220  
   221  	_, errGo = mc.PutObject(experiment.Bucket, "workspace.tar", archive, fileStat.Size(),
   222  		minio.PutObjectOptions{
   223  			ContentType: "application/octet-stream",
   224  		})
   225  	if errGo != nil {
   226  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   227  	}
   228  	return nil
   229  }
   230  
   231  func validateTFMinimal(ctx context.Context, experiment *ExperData) (err kv.Error) {
   232  	// Unpack the output archive within a temporary directory and use it for validation
   233  	dir, errGo := ioutil.TempDir("", xid.New().String())
   234  	if errGo != nil {
   235  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   236  	}
   237  	defer os.RemoveAll(dir)
   238  
   239  	output := filepath.Join(dir, "output.tar")
   240  	if err = downloadOutput(ctx, experiment, output); err != nil {
   241  		return err
   242  	}
   243  
   244  	// Now examine the file for successfully running the python code
   245  	if errGo = archiver.Tar.Open(output, dir); errGo != nil {
   246  		return kv.Wrap(errGo).With("file", output).With("stack", stack.Trace().TrimRuntime())
   247  	}
   248  
   249  	outFn := filepath.Join(dir, "output")
   250  	outFile, errGo := os.Open(outFn)
   251  	if errGo != nil {
   252  		return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime())
   253  	}
   254  
   255  	supressDump := false
   256  	defer func() {
   257  		if !supressDump {
   258  			io.Copy(os.Stdout, outFile)
   259  		}
   260  		outFile.Close()
   261  	}()
   262  
   263  	// Typical values for these items inside the TF logging are as follows
   264  	// "loss: 0.2432 - acc: 0.9313 - val_loss: 0.2316 - val_acc: 0.9355"
   265  	acceptableVals := []float64{
   266  		0.35,
   267  		0.85,
   268  		0.35,
   269  		0.85,
   270  	}
   271  
   272  	matches := [][]string{}
   273  	scanner := bufio.NewScanner(outFile)
   274  	for scanner.Scan() {
   275  		matched := tfExtract.FindAllStringSubmatch(scanner.Text(), -1)
   276  		if len(matched) != 1 {
   277  			continue
   278  		}
   279  		if len(matched[0]) != 5 {
   280  			continue
   281  		}
   282  		matches = matched
   283  	}
   284  	if errGo = scanner.Err(); errGo != nil {
   285  		return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime())
   286  	}
   287  
   288  	if len(matches) != 1 {
   289  		outFile.Seek(0, io.SeekStart)
   290  		io.Copy(os.Stdout, outFile)
   291  		return kv.NewError("unable to find any TF results in the log file").With("file", outFn).With("stack", stack.Trace().TrimRuntime())
   292  	}
   293  
   294  	// Although the following values are not using epsilon style float adjustments because
   295  	// the test limits and values are abitrary anyway
   296  
   297  	// loss andf accuracy checks against the log data that was extracted using a regular expression
   298  	// and captures
   299  	loss, errGo := strconv.ParseFloat(matches[0][1], 64)
   300  	if errGo != nil {
   301  		return kv.Wrap(errGo).With("file", outFn).With("line", scanner.Text()).With("value", matches[0][1]).With("stack", stack.Trace().TrimRuntime())
   302  	}
   303  	if loss > acceptableVals[1] {
   304  		return kv.NewError("loss is too large").With("file", outFn).With("line", scanner.Text()).With("value", loss).With("ceiling", acceptableVals[1]).With("stack", stack.Trace().TrimRuntime())
   305  	}
   306  	loss, errGo = strconv.ParseFloat(matches[0][3], 64)
   307  	if errGo != nil {
   308  		return kv.Wrap(errGo).With("file", outFn).With("value", matches[0][3]).With("line", scanner.Text()).With("stack", stack.Trace().TrimRuntime())
   309  	}
   310  	if loss > acceptableVals[3] {
   311  		return kv.NewError("validation loss is too large").With("file", outFn).With("line", scanner.Text()).With("value", loss).With("ceiling", acceptableVals[3]).With("stack", stack.Trace().TrimRuntime())
   312  	}
   313  	// accuracy checks
   314  	accu, errGo := strconv.ParseFloat(matches[0][2], 64)
   315  	if errGo != nil {
   316  		return kv.Wrap(errGo).With("file", outFn).With("value", matches[0][2]).With("line", scanner.Text()).With("stack", stack.Trace().TrimRuntime())
   317  	}
   318  	if accu < acceptableVals[2] {
   319  		return kv.NewError("accuracy is too small").With("file", outFn).With("line", scanner.Text()).With("value", accu).With("ceiling", acceptableVals[2]).With("stack", stack.Trace().TrimRuntime())
   320  	}
   321  	accu, errGo = strconv.ParseFloat(matches[0][4], 64)
   322  	if errGo != nil {
   323  		return kv.Wrap(errGo).With("file", outFn).With("value", matches[0][4]).With("line", scanner.Text()).With("stack", stack.Trace().TrimRuntime())
   324  	}
   325  	if accu < acceptableVals[3] {
   326  		return kv.NewError("validation accuracy is too small").With("file", outFn).With("line", scanner.Text()).With("value", accu).With("ceiling", acceptableVals[3]).With("stack", stack.Trace().TrimRuntime())
   327  	}
   328  
   329  	logger.Info(matches[0][0], "stack", stack.Trace().TrimRuntime())
   330  	supressDump = true
   331  
   332  	return nil
   333  }
   334  
   335  func lsMetadata(ctx context.Context, experiment *ExperData) (names []string, err kv.Error) {
   336  	names = []string{}
   337  
   338  	// Now we have the workspace for upload go ahead and contact the minio server
   339  	mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false)
   340  	if errGo != nil {
   341  		return names, kv.Wrap(errGo).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime())
   342  	}
   343  	// Create a done channel to control 'ListObjects' go routine.
   344  	doneCh := make(chan struct{})
   345  
   346  	// Indicate to our routine to exit cleanly upon return.
   347  	defer close(doneCh)
   348  
   349  	isRecursive := true
   350  	prefix := "metadata/"
   351  	objectCh := mc.ListObjects(experiment.Bucket, prefix, isRecursive, doneCh)
   352  	for object := range objectCh {
   353  		if object.Err != nil {
   354  			return names, kv.Wrap(object.Err).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime())
   355  		}
   356  		names = append(names, fmt.Sprint(object.Key))
   357  	}
   358  	return names, nil
   359  }
   360  
   361  func downloadMetadata(ctx context.Context, experiment *ExperData, outputDir string) (err kv.Error) {
   362  	// Now we have the workspace for upload go ahead and contact the minio server
   363  	mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false)
   364  	if errGo != nil {
   365  		return kv.Wrap(errGo).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime())
   366  	}
   367  	// Create a done channel to control 'ListObjects' go routine.
   368  	doneCh := make(chan struct{})
   369  
   370  	// Indicate to our routine to exit cleanly upon return.
   371  	defer close(doneCh)
   372  
   373  	names := []string{}
   374  
   375  	isRecursive := true
   376  	prefix := "metadata/"
   377  	objectCh := mc.ListObjects(experiment.Bucket, prefix, isRecursive, doneCh)
   378  	for object := range objectCh {
   379  		if object.Err != nil {
   380  			return kv.Wrap(object.Err).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime())
   381  		}
   382  		names = append(names, filepath.Base(object.Key))
   383  	}
   384  
   385  	for _, name := range names {
   386  		key := prefix + name
   387  		object, errGo := mc.GetObject(experiment.Bucket, key, minio.GetObjectOptions{})
   388  		if errGo != nil {
   389  			return kv.Wrap(errGo).With("address", experiment.MinioAddress, "bucket", experiment.Bucket, "name", name).With("stack", stack.Trace().TrimRuntime())
   390  		}
   391  		localName := filepath.Join(outputDir, filepath.Base(name))
   392  		localFile, errGo := os.Create(localName)
   393  		if errGo != nil {
   394  			return kv.Wrap(errGo).With("address", experiment.MinioAddress, "bucket", experiment.Bucket, "key", key, "filename", localName).With("stack", stack.Trace().TrimRuntime())
   395  		}
   396  		if _, errGo = io.Copy(localFile, object); errGo != nil {
   397  			return kv.Wrap(errGo).With("address", experiment.MinioAddress, "bucket", experiment.Bucket, "key", key, "filename", localName).With("stack", stack.Trace().TrimRuntime())
   398  		}
   399  	}
   400  	return nil
   401  }
   402  
   403  func downloadOutput(ctx context.Context, experiment *ExperData, output string) (err kv.Error) {
   404  
   405  	archive, errGo := os.Create(output)
   406  	if errGo != nil {
   407  		return kv.Wrap(errGo).With("output", output).With("stack", stack.Trace().TrimRuntime())
   408  	}
   409  	defer archive.Close()
   410  
   411  	// Now we have the workspace for upload go ahead and contact the minio server
   412  	mc, errGo := minio.New(experiment.MinioAddress, experiment.MinioUser, experiment.MinioPassword, false)
   413  	if errGo != nil {
   414  		return kv.Wrap(errGo).With("address", experiment.MinioAddress).With("stack", stack.Trace().TrimRuntime())
   415  	}
   416  
   417  	object, errGo := mc.GetObjectWithContext(ctx, experiment.Bucket, "output.tar", minio.GetObjectOptions{})
   418  	if errGo != nil {
   419  		return kv.Wrap(errGo).With("output", output).With("stack", stack.Trace().TrimRuntime())
   420  	}
   421  
   422  	if _, errGo = io.Copy(archive, object); errGo != nil {
   423  		return kv.Wrap(errGo).With("output", output).With("stack", stack.Trace().TrimRuntime())
   424  	}
   425  
   426  	return nil
   427  }
   428  
   429  type relocateTemp func() (err kv.Error)
   430  
   431  type relocate struct {
   432  	Original string
   433  	Pop      []relocateTemp
   434  }
   435  
   436  func (r *relocate) Close() (err kv.Error) {
   437  	if r == nil {
   438  		return nil
   439  	}
   440  	// Iterate the list of call backs in reverse order when exiting
   441  	// the stack of things that were done as a LIFO
   442  	for i := len(r.Pop) - 1; i >= 0; i-- {
   443  		if err = r.Pop[i](); err != nil {
   444  			return err
   445  		}
   446  	}
   447  	return nil
   448  }
   449  
   450  func relocateToTemp(dir string) (callback relocate, err kv.Error) {
   451  
   452  	wd, errGo := os.Getwd()
   453  	if errGo != nil {
   454  		return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   455  	}
   456  	dir, errGo = filepath.Abs(dir)
   457  	if errGo != nil {
   458  		return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   459  	}
   460  
   461  	if rel, _ := filepath.Rel(wd, dir); rel == "." {
   462  		return callback, kv.NewError("the relocation directory is the same directory as the target").With("dir", dir).With("current_dir", wd).With("stack", stack.Trace().TrimRuntime())
   463  	}
   464  
   465  	if errGo = os.Chdir(dir); errGo != nil {
   466  		return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   467  	}
   468  
   469  	callback = relocate{
   470  		Original: wd,
   471  		Pop: []relocateTemp{func() (err kv.Error) {
   472  			if errGo := os.Chdir(wd); errGo != nil {
   473  				return kv.Wrap(errGo).With("dir", wd).With("stack", stack.Trace().TrimRuntime())
   474  			}
   475  			return nil
   476  		}},
   477  	}
   478  
   479  	return callback, nil
   480  }
   481  
   482  func relocateToTransitory() (callback relocate, err kv.Error) {
   483  
   484  	dir, errGo := ioutil.TempDir("", xid.New().String())
   485  	if errGo != nil {
   486  		return callback, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   487  	}
   488  
   489  	if callback, err = relocateToTemp(dir); err != nil {
   490  		return callback, err
   491  	}
   492  
   493  	callback.Pop = append(callback.Pop, func() (err kv.Error) {
   494  		// Move to an intermediate directory to allow the RemoveAll to occur
   495  		if errGo := os.Chdir(os.TempDir()); errGo != nil {
   496  			return kv.Wrap(errGo, "unable to retreat from the directory being deleted").With("dir", dir).With("stack", stack.Trace().TrimRuntime())
   497  		}
   498  		if errGo := os.RemoveAll(dir); errGo != nil {
   499  			return kv.Wrap(errGo, "unable to retreat from the directory being deleted").With("dir", dir).With("stack", stack.Trace().TrimRuntime())
   500  		}
   501  		return nil
   502  	})
   503  
   504  	return callback, nil
   505  }
   506  
   507  func TestRelocation(t *testing.T) {
   508  
   509  	// Keep a record of the directory where we are currently located
   510  	wd, errGo := os.Getwd()
   511  	if errGo != nil {
   512  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   513  	}
   514  	// Create a test directory
   515  	dir, errGo := ioutil.TempDir("", xid.New().String())
   516  	if errGo != nil {
   517  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   518  	}
   519  	defer os.RemoveAll(dir)
   520  
   521  	func() {
   522  		// Relocate to our new directory and then use the construct of a function
   523  		// to pop back out of the test directory to ensure we are in the right location
   524  		reloc, err := relocateToTemp(dir)
   525  		if err != nil {
   526  			t.Fatal(err)
   527  		}
   528  		defer reloc.Close()
   529  	}()
   530  
   531  	// find out where we are and make sure it is where we expect
   532  	newWD, errGo := os.Getwd()
   533  	if errGo != nil {
   534  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   535  	}
   536  	if wd != newWD {
   537  		t.Fatal(kv.NewError("relocation could not be reversed").With("origin", wd).With("recovered_to", newWD).With("temp_dir", dir).With("stack", stack.Trace().TrimRuntime()))
   538  	}
   539  }
   540  
   541  func TestNewRelocation(t *testing.T) {
   542  
   543  	// Keep a record of the directory where we are currently located
   544  	wd, errGo := os.Getwd()
   545  	if errGo != nil {
   546  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   547  	}
   548  
   549  	// Working directory location that is generated by the functions under test
   550  	tmpDir := ""
   551  
   552  	func() {
   553  		// Relocate to a new directory which has had a temporary name generated on
   554  		// out behalf as a working area
   555  		reloc, err := relocateToTransitory()
   556  		if err != nil {
   557  			t.Fatal(err)
   558  		}
   559  		// Make sure we are sitting in another directory at this point and place a test
   560  		// file in it so that later we can check that is got cleared
   561  		tmpDir, errGo = os.Getwd()
   562  		fn := filepath.Join(tmpDir, "EmptyFile")
   563  		fl, errGo := os.Create(fn)
   564  		if errGo != nil {
   565  			t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   566  		}
   567  		msg := "test file that should be gathered up and deleted at the end of the Transitory dir testing"
   568  		if _, errGo = fl.WriteString(msg); errGo != nil {
   569  			t.Fatal(kv.Wrap(errGo).With("filename", fn).With("stack", stack.Trace().TrimRuntime()))
   570  		}
   571  		fl.Close()
   572  
   573  		defer reloc.Close()
   574  	}()
   575  
   576  	// find out where we are and make sure it is where we expect
   577  	newWD, errGo := os.Getwd()
   578  	if errGo != nil {
   579  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   580  	}
   581  	// Make sure this was not a NOP
   582  	if wd != newWD {
   583  		t.Fatal(kv.NewError("relocation could not be reversed").With("origin", wd).With("recovered_to", newWD).With("temp_dir", tmpDir).With("stack", stack.Trace().TrimRuntime()))
   584  	}
   585  
   586  	// Make sure our working directory was cleaned up
   587  	if _, errGo := os.Stat(tmpDir); !os.IsNotExist(errGo) {
   588  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
   589  	}
   590  }
   591  
   592  // prepareExperiment reads an experiment template from the current working directory and
   593  // then uses it to prepare the json payload that will be sent as a runner request
   594  // data structure to a go runner
   595  //
   596  func prepareExperiment(gpus int, ignoreK8s bool) (experiment *ExperData, r *runner.Request, err kv.Error) {
   597  	if !ignoreK8s {
   598  		if err = setupRMQAdmin(); err != nil {
   599  			return nil, nil, err
   600  		}
   601  	}
   602  
   603  	// Parse from the rabbitMQ Settings the username and password that will be available to the templated
   604  	// request
   605  	rmqURL, errGo := url.Parse(os.ExpandEnv(*amqpURL))
   606  	if errGo != nil {
   607  		return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   608  	}
   609  
   610  	slots := 0
   611  	gpusToUse := []runner.GPUTrack{}
   612  	if gpus != 0 {
   613  		// Templates will also have access to details about the GPU cards, upto a max of three
   614  		// so we find the gpu cards and if found load their capacity and allocation data into the
   615  		// template data source.  These are used for live testing so use any live cards from the runner
   616  		//
   617  		invent, err := runner.GPUInventory()
   618  		if err != nil {
   619  			return nil, nil, err
   620  		}
   621  		if len(invent) < gpus {
   622  			return nil, nil, kv.NewError("not enough gpu cards for a test").With("needed", gpus).With("actual", len(invent)).With("stack", stack.Trace().TrimRuntime())
   623  		}
   624  
   625  		// slots will be the total number of slots needed to grab the number of cards specified
   626  		// by the caller
   627  		if gpus > 1 {
   628  			sort.Slice(invent, func(i, j int) bool { return invent[i].FreeSlots < invent[j].FreeSlots })
   629  
   630  			// Get the largest n (gpus) cards that have free slots
   631  			for i := 0; i != len(invent); i++ {
   632  				if len(gpusToUse) >= gpus {
   633  					break
   634  				}
   635  				if invent[i].FreeSlots <= 0 || invent[i].EccFailure != nil {
   636  					continue
   637  				}
   638  
   639  				slots += int(invent[i].FreeSlots)
   640  				gpusToUse = append(gpusToUse, invent[i])
   641  			}
   642  			if len(gpusToUse) < gpus {
   643  				return nil, nil, kv.NewError("not enough available gpu cards for a test").With("needed", gpus).With("actual", len(gpusToUse)).With("stack", stack.Trace().TrimRuntime())
   644  			}
   645  		}
   646  	}
   647  	// Find as many cards as defined by the caller and include the slots needed to claim them which means
   648  	// we need the two largest cards to force multiple claims if needed.  If the  number desired is 1 or 0
   649  	// then we dont do anything as the experiment template will control what we get
   650  
   651  	// Place test files into the serving location for our minio server
   652  	pass, _ := rmqURL.User.Password()
   653  	experiment = &ExperData{
   654  		RabbitMQUser:     rmqURL.User.Username(),
   655  		RabbitMQPassword: pass,
   656  		Bucket:           xid.New().String(),
   657  		MinioAddress:     runner.MinioTest.Address,
   658  		MinioUser:        runner.MinioTest.AccessKeyId,
   659  		MinioPassword:    runner.MinioTest.SecretAccessKeyId,
   660  		GPUs:             gpusToUse,
   661  		GPUSlots:         slots,
   662  	}
   663  
   664  	// Read a template for the payload that will be sent to run the experiment
   665  	payload, errGo := ioutil.ReadFile("experiment_template.json")
   666  	if errGo != nil {
   667  		return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   668  	}
   669  	tmpl, errGo := template.New("TestBasicRun").Parse(string(payload[:]))
   670  	if errGo != nil {
   671  		return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   672  	}
   673  	output := &bytes.Buffer{}
   674  	if errGo = tmpl.Execute(output, experiment); errGo != nil {
   675  		return nil, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   676  	}
   677  
   678  	// Take the string template for the experiment and unmarshall it so that it can be
   679  	// updated with live test data
   680  	if r, err = runner.UnmarshalRequest(output.Bytes()); err != nil {
   681  		return nil, nil, err
   682  	}
   683  
   684  	// If we are not using gpus then purge out the GPU sections of the request template
   685  	if gpus == 0 {
   686  		r.Experiment.Resource.Gpus = 0
   687  		r.Experiment.Resource.GpuMem = ""
   688  	}
   689  
   690  	// Construct a json payload that uses the current wall clock time and also
   691  	// refers to a locally embedded minio server
   692  	r.Experiment.TimeAdded = float64(time.Now().Unix())
   693  	r.Experiment.TimeLastCheckpoint = nil
   694  
   695  	return experiment, r, nil
   696  }
   697  
   698  // projectStats will take a collection of metrics, typically retrieved from a local prometheus
   699  // source and scan these for details relating to a specific project and experiment
   700  //
   701  func projectStats(metrics map[string]*model.MetricFamily, qName string, qType string, project string, experiment string) (running int, finished int, err kv.Error) {
   702  	for family, metric := range metrics {
   703  		switch metric.GetType() {
   704  		case model.MetricType_GAUGE:
   705  		case model.MetricType_COUNTER:
   706  		default:
   707  			continue
   708  		}
   709  		if strings.HasPrefix(family, "runner_project_") {
   710  			err = func() (err kv.Error) {
   711  				vecs := metric.GetMetric()
   712  				for _, vec := range vecs {
   713  					func() {
   714  						for _, label := range vec.GetLabel() {
   715  							switch label.GetName() {
   716  							case "experiment":
   717  								if label.GetValue() != experiment && len(experiment) != 0 {
   718  									logger.Trace("mismatched", "experiment", experiment, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime())
   719  									return
   720  								}
   721  							case "host":
   722  								if label.GetValue() != host {
   723  									logger.Trace("mismatched", "host", host, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime())
   724  									return
   725  								}
   726  							case "project":
   727  								if label.GetValue() != project {
   728  									logger.Trace("mismatched", "project", project, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime())
   729  									return
   730  								}
   731  							case "queue_type":
   732  								if label.GetValue() != qType {
   733  									logger.Trace("mismatched", "qType", qType, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime())
   734  									return
   735  								}
   736  							case "queue_name":
   737  								if !strings.HasSuffix(label.GetValue(), qName) {
   738  									logger.Trace("mismatched", "qName", qName, "value", label.GetValue(), "stack", stack.Trace().TrimRuntime())
   739  									logger.Trace(spew.Sdump(vecs))
   740  									return
   741  								}
   742  							default:
   743  								return
   744  							}
   745  						}
   746  
   747  						logger.Trace("matched prometheus metric", "family", family, "vec", fmt.Sprint(*vec), "stack", stack.Trace().TrimRuntime())
   748  
   749  						// Based on the name of the gauge we will add together quantities, this
   750  						// is done because the experiment might have been left out
   751  						// of the inputs and the caller wanted a total for a project
   752  						switch family {
   753  						case "runner_project_running":
   754  							running += int(vec.GetGauge().GetValue())
   755  						case "runner_project_completed":
   756  							finished += int(vec.GetCounter().GetValue())
   757  						default:
   758  							logger.Info("unexpected", "family", family)
   759  						}
   760  					}()
   761  				}
   762  				return nil
   763  			}()
   764  			if err != nil {
   765  				return 0, 0, err
   766  			}
   767  		}
   768  	}
   769  
   770  	return running, finished, nil
   771  }
   772  
// waitFunc is the signature of the function runStudioTest calls, after the experiment
// request has been published, to block until the experiment is considered done.  It is
// handed the queue name and type, the request that was sent, and the port to use when
// querying prometheus.  A non-nil error aborts the test before validation is attempted.
type waitFunc func(ctx context.Context, qName string, queueType string, r *runner.Request, prometheusPort int) (err kv.Error)
   774  
   775  // waitForRun will check for an experiment to run using the prometheus metrics to
   776  // track the progress of the experiment on a regular basis
   777  //
   778  func waitForRun(ctx context.Context, qName string, queueType string, r *runner.Request, prometheusPort int) (err kv.Error) {
   779  	// Wait for prometheus to show the task as having been ran and completed
   780  	pClient := NewPrometheusClient(fmt.Sprintf("http://localhost:%d/metrics", prometheusPort))
   781  
   782  	interval := time.Duration(0)
   783  
   784  	// Run around checking the prometheus counters for our experiment seeing when the internal
   785  	// project tracking says everything has completed, only then go out and get the experiment
   786  	// results
   787  	//
   788  	for {
   789  		select {
   790  		case <-time.After(interval):
   791  			metrics, err := pClient.Fetch("runner_project_")
   792  			if err != nil {
   793  				return err
   794  			}
   795  
   796  			runningCnt, finishedCnt, err := projectStats(metrics, qName, queueType, r.Config.Database.ProjectId, r.Experiment.Key)
   797  			if err != nil {
   798  				return err
   799  			}
   800  
   801  			// Wait for prometheus to show the task stopped for our specific queue, host, project and experiment ID
   802  			if runningCnt == 0 && finishedCnt == 1 {
   803  				return nil
   804  			}
   805  			interval = time.Duration(15 * time.Second)
   806  		}
   807  	}
   808  }
   809  
   810  func createResponseRMQ(qName string, encrypt bool) (err kv.Error) {
   811  
   812  	rmq, err := newRMQ(encrypt)
   813  	if err != nil {
   814  		return err
   815  	}
   816  
   817  	if err = rmq.QueueDeclare(qName); err != nil {
   818  		return err
   819  	}
   820  
   821  	return nil
   822  }
   823  
   824  func deleteResponseRMQ(qName string, queueType string, routingKey string) (err kv.Error) {
   825  	rmq, err := newRMQ(false)
   826  	if err != nil {
   827  		return err
   828  	}
   829  
   830  	if err = rmq.QueueDestroy(qName); err != nil {
   831  		return err
   832  	}
   833  
   834  	return nil
   835  }
   836  
   837  func newRMQ(encrypted bool) (rmq *runner.RabbitMQ, err kv.Error) {
   838  	creds := ""
   839  
   840  	qURL, errGo := url.Parse(os.ExpandEnv(*amqpURL))
   841  	if errGo != nil {
   842  		return nil, kv.Wrap(errGo).With("url", *amqpURL).With("stack", stack.Trace().TrimRuntime())
   843  	}
   844  	if qURL.User != nil {
   845  		creds = qURL.User.String()
   846  	} else {
   847  		return nil, kv.NewError("missing credentials in url").With("url", *amqpURL).With("stack", stack.Trace().TrimRuntime())
   848  	}
   849  
   850  	w, err := getWrapper()
   851  	if encrypted {
   852  		if err != nil {
   853  			return nil, err
   854  		}
   855  	}
   856  
   857  	qURL.User = nil
   858  	return runner.NewRabbitMQ(qURL.String(), creds, w)
   859  }
   860  
   861  func marshallToRMQ(rmq *runner.RabbitMQ, qName string, r *runner.Request) (b []byte, err kv.Error) {
   862  	if rmq == nil {
   863  		return nil, kv.NewError("rmq uninitialized").With("stack", stack.Trace().TrimRuntime())
   864  	}
   865  
   866  	if !rmq.IsEncrypted() {
   867  		buf, errGo := json.MarshalIndent(r, "", "  ")
   868  		if errGo != nil {
   869  			return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   870  		}
   871  		return buf, nil
   872  	}
   873  	// To sign a message use a generated signing public key
   874  
   875  	sigs := runner.GetSignatures()
   876  	sigDir := sigs.Dir()
   877  
   878  	if len(sigDir) == 0 {
   879  		return nil, kv.NewError("signatures directory not ready").With("stack", stack.Trace().TrimRuntime())
   880  	}
   881  
   882  	pubKey, prvKey, errGo := ed25519.GenerateKey(rand.Reader)
   883  	if errGo != nil {
   884  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   885  	}
   886  	sshKey, errGo := ssh.NewPublicKey(pubKey)
   887  	if errGo != nil {
   888  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   889  	}
   890  
   891  	// Write the public key
   892  	keyFile := filepath.Join(sigDir, qName)
   893  	if errGo = ioutil.WriteFile(keyFile, ssh.MarshalAuthorizedKey(sshKey), 0600); errGo != nil {
   894  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   895  	}
   896  
   897  	// Now wait for the signature package to signal that the keys
   898  	// have been refreshed and our new file was there
   899  	<-runner.GetSignaturesRefresh().Done()
   900  
   901  	w, err := runner.KubernetesWrapper(*msgEncryptDirOpt)
   902  	if err != nil {
   903  		if runner.IsAliveK8s() != nil {
   904  			return nil, err
   905  		}
   906  	}
   907  
   908  	envelope, err := w.Envelope(r)
   909  	if err != nil {
   910  		return nil, err
   911  	}
   912  
   913  	envelope.Message.Fingerprint = ssh.FingerprintSHA256(sshKey)
   914  
   915  	sig, errGo := prvKey.Sign(rand.Reader, []byte(envelope.Message.Payload), crypto.Hash(0))
   916  	if errGo != nil {
   917  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   918  	}
   919  	logger.Debug("signing produced", "sig", spew.Sdump(sig))
   920  	// Encode the base signature into two fields with binary length fromatted
   921  	// using the SSH RFC method
   922  	envelope.Message.Signature = base64.StdEncoding.EncodeToString(sig)
   923  
   924  	if b, errGo = json.MarshalIndent(envelope, "", "  "); errGo != nil {
   925  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
   926  	}
   927  	return b, nil
   928  }
   929  
   930  // publishToRMQ will marshall a go structure containing experiment parameters and
   931  // environment information and then send it to the rabbitMQ server this server is configured
   932  // to listen to
   933  //
   934  func publishToRMQ(qName string, queueType string, routingKey string, r *runner.Request, encrypted bool) (err kv.Error) {
   935  	rmq, err := newRMQ(encrypted)
   936  	if err != nil {
   937  		return err
   938  	}
   939  
   940  	if err = rmq.QueueDeclare(qName); err != nil {
   941  		return err
   942  	}
   943  
   944  	b, err := marshallToRMQ(rmq, qName, r)
   945  
   946  	// Send the payload to rabbitMQ
   947  	return rmq.Publish(routingKey, "application/json", b)
   948  }
   949  
   950  func watchResponseQueue(ctx context.Context, qName string, encrypted bool) (msgQ chan *runnerReports.Report, err kv.Error) {
   951  	deliveryC := make(chan *runnerReports.Report)
   952  
   953  	rmq, err := newRMQ(encrypted)
   954  	if err != nil {
   955  		return nil, err
   956  	}
   957  
   958  	conn := amqpextra.Dial([]string{rmq.URL() + "%2f"})
   959  	consumer := conn.Consumer(
   960  		qName,
   961  		amqpextra.WorkerFunc(func(ctx context.Context, msg amqp.Delivery) interface{} {
   962  			// process message
   963  
   964  			report := &runnerReports.Report{}
   965  			if err := prototext.Unmarshal([]byte(msg.ContentEncoding), report); err != nil {
   966  				return err
   967  			}
   968  
   969  			if report != nil {
   970  				logger.Info("report received", "report", spew.Sdump(*report))
   971  			}
   972  
   973  			select {
   974  			case deliveryC <- report:
   975  			case <-time.After(5 * time.Second):
   976  				msg.Ack(false)
   977  				return nil
   978  			}
   979  
   980  			msg.Ack(true)
   981  
   982  			return nil
   983  		}),
   984  	)
   985  	consumer.SetWorkerNum(1)
   986  	consumer.SetContext(ctx)
   987  
   988  	return deliveryC, nil
   989  }
   990  
   991  func pullReports(ctx context.Context, msgC <-chan *runnerReports.Report) {
   992  	for {
   993  		select {
   994  		case msg := <-msgC:
   995  			if msg == nil {
   996  				return
   997  			}
   998  		case <-ctx.Done():
   999  			return
  1000  		}
  1001  	}
  1002  }
  1003  
// validationFunc is the signature of the function runStudioTest calls as its final step,
// once the waiter has returned without error, to check the outcome of the experiment.
// The error it returns becomes the result of runStudioTest.
type validationFunc func(ctx context.Context, experiment *ExperData) (err kv.Error)
  1005  
// runStudioTest will run a python based experiment and will then present the result to
// a caller supplied validation function.
//
// workDir is handed to relocateToTemp before the experiment is prepared.  gpus and
// ignoreK8s are handed to prepareExperiment, ignoreK8s additionally skips the Kubernetes
// liveness check.  useEncryption is applied to the response queue, the response watcher
// and to the published request.  waiter is called once the request has been published
// and validation is called after waiter returns without error, with its result being
// returned to the caller.
//
// Any failure along the way is returned immediately.  The response queue, the minio
// bucket used by the experiment, and the relocation of the working directory are
// released using deferred calls whose own errors are not reported.
//
func runStudioTest(ctx context.Context, workDir string, gpus int, ignoreK8s bool, useEncryption bool, waiter waitFunc, validation validationFunc) (err kv.Error) {

	if !ignoreK8s {
		if err = runner.IsAliveK8s(); err != nil {
			return err
		}
	}

	// Allow the minio liveness check at most one minute
	timeoutAlive, aliveCancel := context.WithTimeout(ctx, time.Minute)
	defer aliveCancel()

	// Check that the minio local server has initialized before continuing
	if alive, err := runner.MinioTest.IsAlive(timeoutAlive); !alive || err != nil {
		if err != nil {
			return err
		}
		return kv.NewError("The minio test server is not available to run this test").With("stack", stack.Trace().TrimRuntime())
	}
	logger.Debug("alive checked", "addr", runner.MinioTest.Address)

	// NOTE(review): presumably this changes into a temporary copy of workDir and the
	// deferred Close changes back again - confirm against relocateToTemp
	returnToWD, err := relocateToTemp(workDir)
	if err != nil {
		return err
	}
	defer returnToWD.Close()

	logger.Debug("test relocated", "workDir", workDir)

	experiment, r, err := prepareExperiment(gpus, ignoreK8s)
	if err != nil {
		return err
	}

	logger.Debug("experiment prepared")

	// Having constructed the payload identify the files within the test template
	// directory and save them into a workspace tar archive then
	// generate a tar file of the entire workspace directory and upload
	// to the minio server that the runner will pull from
	if err = uploadWorkspace(experiment); err != nil {
		return err
	}

	logger.Debug("experiment uploaded")

	// Cleanup the bucket only after the validation function that was supplied has finished
	defer runner.MinioTest.RemoveBucketAll(experiment.Bucket)

	// Generate queue names that will be used for this test case, xid is used to
	// give every run its own queue
	queueType := "rmq"
	qName := queueType + "_Multipart_" + xid.New().String()
	routingKey := "StudioML." + qName

	// Create and listen to the response queue which will receive messages
	// from the worker
	if err = createResponseRMQ(qName+"_response", useEncryption); err != nil {
		return err
	}
	defer deleteResponseRMQ(qName+"_response", queueType, routingKey)

	// The response watcher has its own context rooted at context.Background(), it is
	// stopped by the deferred cancel when this function returns rather than by ctx
	responseCtx, cancelResponse := context.WithCancel(context.Background())
	defer cancelResponse()

	msgC, err := watchResponseQueue(responseCtx, string(qName+"_response"), useEncryption)
	if err != nil {
		return err
	}

	// Reports are drained in the background so the watcher is never left blocked
	go pullReports(responseCtx, msgC)

	logger.Debug("test initiated", "queue", qName, "stack", stack.Trace().TrimRuntime())

	// Now that the file needed is present on the minio server send the
	// experiment specification message to the worker using a new queue

	if err = publishToRMQ(qName, queueType, routingKey, r, useEncryption); err != nil {
		return err
	}

	logger.Debug("test waiting", "queue", qName, "stack", stack.Trace().TrimRuntime())

	// prometheusPort is declared elsewhere in this package
	if err = waiter(ctx, qName, queueType, r, prometheusPort); err != nil {
		return err
	}

	// Query minio for the resulting output and compare it with the expected
	return validation(ctx, experiment)
}
  1097  
  1098  // TestÄE2EExperimentRun is a function used to exercise the core ability of the runner to successfully
  1099  // complete a single experiment.  The name of the test uses a Latin A with Diaresis to order this
  1100  // test after others that are simpler in nature.
  1101  //
  1102  // This test take a minute or two but is left to run in the short version of testing because
  1103  // it exercises the entire system under test end to end for experiments running in the python
  1104  // environment
  1105  //
  1106  func TestÄE2ECPUExperimentRun(t *testing.T) {
  1107  	E2EExperimentRun(t, 0)
  1108  }
  1109  
  1110  func TestÄE2EGPUExperimentRun(t *testing.T) {
  1111  	if !*runner.UseGPU {
  1112  		logger.Warn("TestÄE2EExperimentRun not run")
  1113  		t.Skip("GPUs disabled for testing")
  1114  	}
  1115  	E2EExperimentRun(t, 1)
  1116  
  1117  }
  1118  
  1119  func E2EExperimentRun(t *testing.T, gpusNeeded int) {
  1120  
  1121  	if !*useK8s {
  1122  		t.Skip("kubernetes specific testing disabled")
  1123  	}
  1124  
  1125  	gpuCount := runner.GPUCount()
  1126  	if gpusNeeded > gpuCount {
  1127  		t.Skipf("insufficient GPUs %d, needed %d", gpuCount, gpusNeeded)
  1128  	}
  1129  
  1130  	cases := []struct {
  1131  		useEncrypt bool
  1132  	}{
  1133  		{useEncrypt: true},
  1134  		{useEncrypt: false},
  1135  	}
  1136  
  1137  	for _, aCase := range cases {
  1138  		wd, errGo := os.Getwd()
  1139  		if errGo != nil {
  1140  			t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
  1141  		}
  1142  		// Navigate to the assets directory being used for this experiment
  1143  		workDir, errGo := filepath.Abs(filepath.Join(wd, "..", "..", "assets", "tf_minimal"))
  1144  		if errGo != nil {
  1145  			t.Fatal(errGo)
  1146  		}
  1147  
  1148  		if err := runStudioTest(context.Background(), workDir, gpusNeeded, false, aCase.useEncrypt, waitForRun, validateTFMinimal); err != nil {
  1149  			t.Fatal(err)
  1150  		}
  1151  
  1152  		// Make sure we returned to the directory we expected
  1153  		newWD, errGo := os.Getwd()
  1154  		if errGo != nil {
  1155  			t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
  1156  		}
  1157  		if newWD != wd {
  1158  			t.Fatal(kv.NewError("finished in an unexpected directory").With("expected_dir", wd).With("actual_dir", newWD).With("stack", stack.Trace().TrimRuntime()))
  1159  		}
  1160  	}
  1161  }
  1162  
  1163  func validatePytorchMultiGPU(ctx context.Context, experiment *ExperData) (err kv.Error) {
  1164  	// Unpack the output archive within a temporary directory and use it for validation
  1165  	dir, errGo := ioutil.TempDir("", xid.New().String())
  1166  	if errGo != nil {
  1167  		return kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime())
  1168  	}
  1169  	defer os.RemoveAll(dir)
  1170  
  1171  	output := filepath.Join(dir, "output.tar")
  1172  	if err = downloadOutput(ctx, experiment, output); err != nil {
  1173  		return err
  1174  	}
  1175  
  1176  	// Now examine the file for successfully running the python code
  1177  	if errGo = archiver.Tar.Open(output, dir); errGo != nil {
  1178  		return kv.Wrap(errGo).With("file", output).With("stack", stack.Trace().TrimRuntime())
  1179  	}
  1180  
  1181  	outFn := filepath.Join(dir, "output")
  1182  	outFile, errGo := os.Open(outFn)
  1183  	if errGo != nil {
  1184  		return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime())
  1185  	}
  1186  
  1187  	supressDump := false
  1188  	defer func() {
  1189  		if !supressDump {
  1190  			io.Copy(os.Stdout, outFile)
  1191  		}
  1192  		outFile.Close()
  1193  	}()
  1194  
  1195  	validateString := fmt.Sprintf("(\"Let's use\", %dL, 'GPUs!')", len(experiment.GPUs))
  1196  	err = kv.NewError("multiple gpu logging not found").With("log", validateString).With("stack", stack.Trace().TrimRuntime())
  1197  
  1198  	scanner := bufio.NewScanner(outFile)
  1199  	for scanner.Scan() {
  1200  		if strings.Contains(scanner.Text(), validateString) {
  1201  			supressDump = true
  1202  			err = nil
  1203  			break
  1204  		}
  1205  	}
  1206  	if errGo = scanner.Err(); errGo != nil {
  1207  		return kv.Wrap(errGo).With("file", outFn).With("stack", stack.Trace().TrimRuntime())
  1208  	}
  1209  
  1210  	return err
  1211  }
  1212  
  1213  // TestÄE2EPytorchMGPURun is a function used to exercise the multi GPU ability of the runner to successfully
  1214  // complete a single pytorch multi GPU experiment.  The name of the test uses a Latin A with Diaresis to order this
  1215  // test after others that are simpler in nature.
  1216  //
  1217  // This test take a minute or two but is left to run in the short version of testing because
  1218  // it exercises the entire system under test end to end for experiments running in the python
  1219  // environment
  1220  //
  1221  func TestÄE2EPytorchMGPURun(t *testing.T) {
  1222  
  1223  	if !*useK8s {
  1224  		t.Skip("kubernetes specific testing disabled")
  1225  	}
  1226  
  1227  	if !*runner.UseGPU {
  1228  		logger.Warn("TestÄE2EPytorchMGPURun not run")
  1229  		t.Skip("GPUs disabled for testing")
  1230  	}
  1231  
  1232  	wd, errGo := os.Getwd()
  1233  	if errGo != nil {
  1234  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
  1235  	}
  1236  
  1237  	gpusNeeded := 2
  1238  	gpuCount := runner.GPUCount()
  1239  	if gpusNeeded > gpuCount {
  1240  		t.Skipf("insufficient GPUs %d, needed %d", gpuCount, gpusNeeded)
  1241  	}
  1242  
  1243  	// Navigate to the assets directory being used for this experiment
  1244  	workDir, errGo := filepath.Abs(filepath.Join(wd, "..", "..", "assets", "pytorch_mgpu"))
  1245  	if errGo != nil {
  1246  		t.Fatal(errGo)
  1247  	}
  1248  
  1249  	if err := runStudioTest(context.Background(), workDir, 2, false, false, waitForRun, validatePytorchMultiGPU); err != nil {
  1250  		t.Fatal(err)
  1251  	}
  1252  
  1253  	// Make sure we returned to the directory we expected
  1254  	newWD, errGo := os.Getwd()
  1255  	if errGo != nil {
  1256  		t.Fatal(kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()))
  1257  	}
  1258  	if newWD != wd {
  1259  		t.Fatal(kv.NewError("finished in an unexpected directory").With("expected_dir", wd).With("actual_dir", newWD).With("stack", stack.Trace().TrimRuntime()))
  1260  	}
  1261  }