github.com/pachyderm/pachyderm@v1.13.4/src/server/worker/driver/driver_test.go (about)

     1  package driver
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"os"
    10  	"path/filepath"
    11  	"reflect"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/gogo/protobuf/types"
    17  	"github.com/prometheus/client_golang/prometheus"
    18  	prometheus_proto "github.com/prometheus/client_model/go"
    19  	"gopkg.in/go-playground/webhooks.v5/github"
    20  
    21  	"github.com/pachyderm/pachyderm/src/client"
    22  	"github.com/pachyderm/pachyderm/src/client/enterprise"
    23  	"github.com/pachyderm/pachyderm/src/client/pfs"
    24  	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
    25  	"github.com/pachyderm/pachyderm/src/client/pkg/require"
    26  	"github.com/pachyderm/pachyderm/src/client/pps"
    27  	"github.com/pachyderm/pachyderm/src/server/pkg/testpachd"
    28  	tu "github.com/pachyderm/pachyderm/src/server/pkg/testutil"
    29  	"github.com/pachyderm/pachyderm/src/server/worker/common"
    30  	"github.com/pachyderm/pachyderm/src/server/worker/logs"
    31  )
    32  
    33  var inputRepo = "inputRepo"
    34  var inputGitRepo = "https://github.com/pachyderm/test-artifacts.git"
    35  var inputGitRepoFake = "https://github.com/pachyderm/test-artifacts-fake.git"
    36  
    37  func testPipelineInfo() *pps.PipelineInfo {
    38  	return &pps.PipelineInfo{
    39  		Pipeline: client.NewPipeline("testPipeline"),
    40  		Transform: &pps.Transform{
    41  			Cmd: []string{"cp", filepath.Join("pfs", inputRepo, "file"), "pfs/out/file"},
    42  		},
    43  		ParallelismSpec: &pps.ParallelismSpec{
    44  			Constant: 1,
    45  		},
    46  		ResourceRequests: &pps.ResourceSpec{
    47  			Memory: "100M",
    48  			Cpu:    0.5,
    49  		},
    50  		Input: client.NewPFSInput(inputRepo, "/*"),
    51  	}
    52  }
    53  
    54  type testEnv struct {
    55  	testpachd.MockEnv
    56  	driver *driver
    57  }
    58  
    59  func withTestEnv(cb func(*testEnv)) error {
    60  	return testpachd.WithMockEnv(func(mockEnv *testpachd.MockEnv) (err error) {
    61  		env := &testEnv{MockEnv: *mockEnv}
    62  
    63  		// Mock out the enterprise.GetState call that happens during driver construction
    64  		env.MockPachd.Enterprise.GetState.Use(func(context.Context, *enterprise.GetStateRequest) (*enterprise.GetStateResponse, error) {
    65  			return &enterprise.GetStateResponse{State: enterprise.State_NONE}, nil
    66  		})
    67  
    68  		var d Driver
    69  		d, err = NewDriver(
    70  			testPipelineInfo(),
    71  			env.PachClient,
    72  			env.EtcdClient,
    73  			tu.UniqueString("driverTest"),
    74  			filepath.Clean(filepath.Join(env.Directory, "hashtrees")),
    75  			filepath.Clean(filepath.Join(env.Directory, "pfs")),
    76  			"namespace",
    77  		)
    78  		if err != nil {
    79  			return err
    80  		}
    81  		d = d.WithContext(env.Context)
    82  		env.driver = d.(*driver)
    83  		env.driver.pipelineInfo.Transform.WorkingDir = env.Directory
    84  
    85  		cb(env)
    86  
    87  		return nil
    88  	})
    89  }
    90  
    91  // collectLogs provides the given callback with a mock TaggedLogger object which
    92  // will be used to collect all the logs and return them. This is pretty naive
    93  // and just splits log statements based on newlines because when running user
    94  // code, it is just used as an io.Writer and doesn't know when one message ends
    95  // and the next begins.
    96  func collectLogs(cb func(logs.TaggedLogger)) []string {
    97  	logger := logs.NewMockLogger()
    98  	buffer := &bytes.Buffer{}
    99  	logger.Writer = buffer
   100  	logger.Job = "job-id"
   101  
   102  	cb(logger)
   103  
   104  	logStmts := strings.Split(buffer.String(), "\n")
   105  	if len(logStmts) > 0 && logStmts[len(logStmts)-1] == "" {
   106  		return logStmts[0 : len(logStmts)-1]
   107  	}
   108  	return logStmts
   109  }
   110  
   111  // requireLogs wraps collectLogs and ensures that certain log statements were
   112  // made. These are specified as regular expressions in the patterns parameter,
   113  // and each pattern must match at least one log line. The patterns are run
   114  // separately against each log line, not against the entire output. If the
   115  // patterns parameter is nil, we require that there are no log statements.
   116  func requireLogs(t *testing.T, patterns []string, cb func(logs.TaggedLogger)) {
   117  	logStmts := collectLogs(cb)
   118  
   119  	if patterns == nil {
   120  		require.Equal(t, 0, len(logStmts), "callback should not have logged anything")
   121  	} else {
   122  		for _, pattern := range patterns {
   123  			require.OneOfMatches(t, pattern, logStmts, "callback did not log the expected message")
   124  		}
   125  	}
   126  }
   127  
   128  func requireMetric(t *testing.T, metric prometheus.Collector, labels []string, cb func(prometheus_proto.Metric)) {
   129  	reg := prometheus.NewRegistry()
   130  	require.NoError(t, reg.Register(metric))
   131  
   132  	stats, err := reg.Gather()
   133  	require.NoError(t, err)
   134  
   135  	// Add a placeholder for the state label even if it isn't used
   136  	for len(labels) < 3 {
   137  		labels = append(labels, "")
   138  	}
   139  
   140  	// We only have one metric in the registry, so skip over the family level
   141  	for _, family := range stats {
   142  		for _, metric := range family.Metric {
   143  			var pipeline, job, state string
   144  			for _, pair := range metric.Label {
   145  				switch *pair.Name {
   146  				case "pipeline":
   147  					pipeline = *pair.Value
   148  				case "job":
   149  					job = *pair.Value
   150  				case "state":
   151  					state = *pair.Value
   152  				default:
   153  					require.True(t, false, fmt.Sprintf("unexpected metric label: %s", *pair.Name))
   154  				}
   155  			}
   156  
   157  			metricLabels := []string{pipeline, job, state}
   158  			if reflect.DeepEqual(labels, metricLabels) {
   159  				cb(*metric)
   160  				return
   161  			}
   162  		}
   163  	}
   164  
   165  	require.True(t, false, fmt.Sprintf("no matching metric found for labels: %v", labels))
   166  }
   167  
   168  func requireCounter(t *testing.T, counter *prometheus.CounterVec, labels []string, value float64) {
   169  	requireMetric(t, counter, labels, func(m prometheus_proto.Metric) {
   170  		require.NotNil(t, m.Counter)
   171  		require.Equal(t, value, *m.Counter.Value)
   172  	})
   173  }
   174  
   175  func requireHistogram(t *testing.T, histogram *prometheus.HistogramVec, labels []string, value uint64) {
   176  	requireMetric(t, histogram, labels, func(m prometheus_proto.Metric) {
   177  		require.NotNil(t, m.Histogram)
   178  		require.Equal(t, value, *m.Histogram.SampleCount)
   179  	})
   180  }
   181  
   182  func TestUpdateCounter(t *testing.T) {
   183  	t.Parallel()
   184  	err := withTestEnv(func(env *testEnv) {
   185  		env.driver.pipelineInfo.ID = "foo"
   186  
   187  		counterVec := prometheus.NewCounterVec(
   188  			prometheus.CounterOpts{Namespace: "test", Subsystem: "driver", Name: "counter"},
   189  			[]string{"pipeline", "job"},
   190  		)
   191  
   192  		counterVecWithState := prometheus.NewCounterVec(
   193  			prometheus.CounterOpts{Namespace: "test", Subsystem: "driver", Name: "counter_with_state"},
   194  			[]string{"pipeline", "job", "state"},
   195  		)
   196  
   197  		// Passing a state to the stateless counter should error
   198  		requireLogs(t, []string{"expected 2 label values but got 3"}, func(logger logs.TaggedLogger) {
   199  			env.driver.updateCounter(counterVec, logger, "bar", func(c prometheus.Counter) {
   200  				require.True(t, false, "should have errored")
   201  			})
   202  		})
   203  
   204  		// updateCounter should pass a valid counter with the selected tags
   205  		requireLogs(t, nil, func(logger logs.TaggedLogger) {
   206  			env.driver.updateCounter(counterVec, logger, "", func(c prometheus.Counter) {
   207  				c.Add(1)
   208  			})
   209  		})
   210  
   211  		// Check that the counter was incremented
   212  		requireCounter(t, counterVec, []string{"foo", "job-id"}, 1)
   213  
   214  		// Not passing a state to the stateful counter should error
   215  		requireLogs(t, []string{"expected 3 label values but got 2"}, func(logger logs.TaggedLogger) {
   216  			env.driver.updateCounter(counterVecWithState, logger, "", func(c prometheus.Counter) {
   217  				require.True(t, false, "should have errored")
   218  			})
   219  		})
   220  
   221  		// updateCounter should pass a valid counter with the selected tags
   222  		requireLogs(t, nil, func(logger logs.TaggedLogger) {
   223  			env.driver.updateCounter(counterVecWithState, logger, "bar", func(c prometheus.Counter) {
   224  				c.Add(1)
   225  			})
   226  		})
   227  
   228  		// Check that the counter was incremented
   229  		requireCounter(t, counterVecWithState, []string{"foo", "job-id", "bar"}, 1)
   230  	})
   231  	require.NoError(t, err)
   232  }
   233  
   234  func TestUpdateHistogram(t *testing.T) {
   235  	t.Parallel()
   236  	err := withTestEnv(func(env *testEnv) {
   237  		env.driver.pipelineInfo.ID = "foo"
   238  
   239  		histogramVec := prometheus.NewHistogramVec(
   240  			prometheus.HistogramOpts{
   241  				Namespace: "test", Subsystem: "driver", Name: "histogram",
   242  				Buckets: prometheus.ExponentialBuckets(1.0, 2.0, 20),
   243  			},
   244  			[]string{"pipeline", "job"},
   245  		)
   246  
   247  		histogramVecWithState := prometheus.NewHistogramVec(
   248  			prometheus.HistogramOpts{
   249  				Namespace: "test", Subsystem: "driver", Name: "histogram_with_state",
   250  				Buckets: prometheus.ExponentialBuckets(1.0, 2.0, 20),
   251  			},
   252  			[]string{"pipeline", "job", "state"},
   253  		)
   254  
   255  		// Passing a state to the stateless histogram should error
   256  		requireLogs(t, []string{"expected 2 label values but got 3"}, func(logger logs.TaggedLogger) {
   257  			env.driver.updateHistogram(histogramVec, logger, "bar", func(h prometheus.Observer) {
   258  				require.True(t, false, "should have errored")
   259  			})
   260  		})
   261  
   262  		requireLogs(t, nil, func(logger logs.TaggedLogger) {
   263  			env.driver.updateHistogram(histogramVec, logger, "", func(h prometheus.Observer) {
   264  				h.Observe(0)
   265  			})
   266  		})
   267  
   268  		// Check that the counter was incremented
   269  		requireHistogram(t, histogramVec, []string{"foo", "job-id"}, 1)
   270  
   271  		// Not passing a state to the stateful histogram should error
   272  		requireLogs(t, []string{"expected 3 label values but got 2"}, func(logger logs.TaggedLogger) {
   273  			env.driver.updateHistogram(histogramVecWithState, logger, "", func(h prometheus.Observer) {
   274  				require.True(t, false, "should have errored")
   275  			})
   276  		})
   277  
   278  		requireLogs(t, nil, func(logger logs.TaggedLogger) {
   279  			env.driver.updateHistogram(histogramVecWithState, logger, "bar", func(h prometheus.Observer) {
   280  				h.Observe(0)
   281  			})
   282  		})
   283  
   284  		// Check that the counter was incremented
   285  		requireHistogram(t, histogramVecWithState, []string{"foo", "job-id", "bar"}, 1)
   286  	})
   287  	require.NoError(t, err)
   288  }
   289  
   290  type inputData struct {
   291  	path     string
   292  	contents string
   293  	regex    string
   294  	found    bool
   295  }
   296  
   297  func newInputDataRegex(path string, regex string) *inputData {
   298  	return &inputData{path: filepath.Clean(path), regex: regex}
   299  }
   300  
   301  func newInputData(path string, contents string) *inputData {
   302  	return &inputData{path: filepath.Clean(path), contents: contents}
   303  }
   304  
   305  func requireEmptyScratch(t *testing.T, inputDir string) {
   306  	entries, err := ioutil.ReadDir(filepath.Join(inputDir, client.PPSScratchSpace))
   307  
   308  	if !errors.Is(err, os.ErrNotExist) {
   309  		require.ElementsEqual(t, []os.FileInfo{}, entries)
   310  	}
   311  }
   312  
   313  func requireContents(t *testing.T, dir string, data []*inputData) {
   314  	checkFile := func(fullPath string, relPath string) {
   315  		for _, checkData := range data {
   316  			if checkData.path == relPath {
   317  				contents, err := ioutil.ReadFile(fullPath)
   318  				require.NoError(t, err)
   319  				if checkData.regex != "" {
   320  					require.Matches(t, checkData.regex, string(contents), "Incorrect contents for input file: %s", relPath)
   321  				} else {
   322  					require.Equal(t, checkData.contents, string(contents), "Incorrect contents for input file: %s", relPath)
   323  				}
   324  				checkData.found = true
   325  				return
   326  			}
   327  		}
   328  		require.True(t, false, "Unexpected input file found: %s", relPath)
   329  	}
   330  
   331  	err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
   332  		require.NoError(t, err)
   333  		if info.Name() == ".git" || info.Name() == client.PPSScratchSpace {
   334  			return filepath.SkipDir
   335  		}
   336  		if !info.IsDir() {
   337  			path = filepath.Clean(path)
   338  			relPath := strings.TrimLeft(strings.TrimPrefix(path, dir), "/\\")
   339  			checkFile(path, relPath)
   340  		}
   341  		return nil
   342  	})
   343  	require.NoError(t, err)
   344  
   345  	for _, checkData := range data {
   346  		require.True(t, checkData.found, "Expected input file not found: %s", checkData.path)
   347  	}
   348  }
   349  
   350  func TestWithDataEmpty(t *testing.T) {
   351  	t.Parallel()
   352  	err := withTestEnv(func(env *testEnv) {
   353  		requireLogs(t, []string{"finished downloading data"}, func(logger logs.TaggedLogger) {
   354  			_, err := env.driver.WithData(
   355  				[]*common.Input{},
   356  				nil,
   357  				logger,
   358  				func(dir string, stats *pps.ProcessStats) error {
   359  					requireContents(t, dir, []*inputData{})
   360  					return nil
   361  				},
   362  			)
   363  			require.NoError(t, err)
   364  			requireEmptyScratch(t, env.driver.InputDir())
   365  			requireContents(t, env.driver.InputDir(), []*inputData{})
   366  		})
   367  	})
   368  	require.NoError(t, err)
   369  }
   370  
   371  func TestWithDataSpout(t *testing.T) {
   372  	t.Parallel()
   373  	err := withTestEnv(func(env *testEnv) {
   374  		env.driver.pipelineInfo.Spout = &pps.Spout{}
   375  		requireLogs(t, []string{"finished downloading data"}, func(logger logs.TaggedLogger) {
   376  			_, err := env.driver.WithData(
   377  				[]*common.Input{},
   378  				nil,
   379  				logger,
   380  				func(dir string, stats *pps.ProcessStats) error {
   381  					// A spout pipeline should have created a 'pfs/out` fifo for the user
   382  					// code to write to
   383  					requireContents(t, dir, []*inputData{newInputData("out", "")})
   384  					return nil
   385  				},
   386  			)
   387  			require.NoError(t, err)
   388  			requireEmptyScratch(t, env.driver.InputDir())
   389  			requireContents(t, env.driver.InputDir(), []*inputData{})
   390  		})
   391  	})
   392  	require.NoError(t, err)
   393  }
   394  
   395  // Shitty helper function to create possibly-not-malformed input structures
   396  func newInput(repo string, path string) *common.Input {
   397  	return &common.Input{
   398  		FileInfo: &pfs.FileInfo{
   399  			File: &pfs.File{
   400  				Commit: &pfs.Commit{
   401  					Repo: &pfs.Repo{
   402  						Name: repo,
   403  					},
   404  					ID: "commit-id-string",
   405  				},
   406  				Path: path,
   407  			},
   408  			FileType: pfs.FileType_FILE,
   409  		},
   410  		Name:   repo,
   411  		Branch: "master",
   412  	}
   413  }
   414  
   415  func TestWithDataCancel(t *testing.T) {
   416  	t.Parallel()
   417  	err := withTestEnv(func(env *testEnv) {
   418  		requireLogs(t, []string{"errored downloading data", "context canceled"}, func(logger logs.TaggedLogger) {
   419  			ctx, cancel := context.WithCancel(env.Context)
   420  			driver := env.driver.WithContext(ctx)
   421  
   422  			// Cancel the context during the download
   423  			env.MockPachd.PFS.WalkFile.Use(func(req *pfs.WalkFileRequest, serv pfs.API_WalkFileServer) error {
   424  				cancel()
   425  				<-serv.Context().Done()
   426  				return errors.Errorf("WalkFile canceled")
   427  			})
   428  
   429  			_, err := driver.WithData(
   430  				[]*common.Input{newInput("repo", "input.txt")},
   431  				nil,
   432  				logger,
   433  				func(dir string, stats *pps.ProcessStats) error {
   434  					require.True(t, false, "Should have been canceled before the callback")
   435  					cancel()
   436  					return nil
   437  				},
   438  			)
   439  			require.YesError(t, err, "WithData call should have been canceled")
   440  			requireEmptyScratch(t, env.driver.InputDir())
   441  			requireContents(t, env.driver.InputDir(), []*inputData{})
   442  		})
   443  	})
   444  	require.NoError(t, err)
   445  }
   446  
   447  // Check that the driver will download the requested inputs, put them in place
   448  // during WithData, and clean them up after running the inner function.
   449  func TestWithDataDownload(t *testing.T) {
   450  	t.Parallel()
   451  	err := withTestEnv(func(env *testEnv) {
   452  		requireLogs(t, []string{"finished downloading data", "inner function"}, func(logger logs.TaggedLogger) {
   453  			// Mock out the calls that will be used to download the data
   454  			env.MockPachd.PFS.WalkFile.Use(func(req *pfs.WalkFileRequest, serv pfs.API_WalkFileServer) error {
   455  				return serv.Send(&pfs.FileInfo{
   456  					File:     req.File,
   457  					FileType: pfs.FileType_FILE,
   458  				})
   459  			})
   460  
   461  			env.MockPachd.PFS.GetFile.Use(func(req *pfs.GetFileRequest, serv pfs.API_GetFileServer) error {
   462  				return serv.Send(&types.BytesValue{Value: []byte(fmt.Sprintf("%s-data", req.File.Commit.Repo.Name))})
   463  			})
   464  
   465  			_, err := env.driver.WithData(
   466  				[]*common.Input{newInput("repoA", "input.txt"), newInput("repoB", "input.md")},
   467  				nil,
   468  				logger,
   469  				func(dir string, stats *pps.ProcessStats) error {
   470  					requireContents(t, dir, []*inputData{
   471  						newInputData("repoA/input.txt", "repoA-data"),
   472  						newInputData("repoB/input.md", "repoB-data"),
   473  					})
   474  					logger.Logf("inner function")
   475  					return nil
   476  				},
   477  			)
   478  			require.NoError(t, err)
   479  			requireEmptyScratch(t, env.driver.InputDir())
   480  			requireContents(t, env.driver.InputDir(), []*inputData{})
   481  		})
   482  	})
   483  	require.NoError(t, err)
   484  }
   485  
   486  // Create several files and directories inside WithData and verify that they are
   487  // cleaned up after WithData returns.
   488  func TestWithActiveDataCleanup(t *testing.T) {
   489  	t.Parallel()
   490  	err := withTestEnv(func(env *testEnv) {
   491  		create := func(relPath string) {
   492  			fullPath := filepath.Join(env.driver.InputDir(), relPath)
   493  			require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0777))
   494  			file, err := os.Create(fullPath)
   495  			require.NoError(t, err)
   496  			require.NoError(t, file.Close())
   497  		}
   498  
   499  		requireLogs(t, []string{"finished downloading data", "inner function"}, func(logger logs.TaggedLogger) {
   500  			_, err := env.driver.WithData(
   501  				[]*common.Input{},
   502  				nil,
   503  				logger,
   504  				func(dir string, stats *pps.ProcessStats) error {
   505  					requireContents(t, dir, []*inputData{})
   506  					logger.Logf("inner function")
   507  
   508  					expectedContents := []*inputData{
   509  						newInputData("c", ""),
   510  						newInputData("out/1", ""),
   511  						newInputData("out/2/a", ""),
   512  						newInputData("out/2/b", ""),
   513  						newInputData("out/2/3/c", ""),
   514  						newInputData("foo/barbaz", ""),
   515  						newInputData("foo/bar/baz", ""),
   516  						newInputData("floop/blarp/blazj/etc", ""),
   517  					}
   518  
   519  					err := env.driver.WithActiveData([]*common.Input{}, dir, func() error {
   520  						for _, x := range expectedContents {
   521  							create(x.path)
   522  						}
   523  
   524  						requireContents(t, env.driver.InputDir(), expectedContents)
   525  						return nil
   526  					})
   527  					require.NoError(t, err)
   528  					requireContents(t, dir, expectedContents)
   529  					requireContents(t, env.driver.InputDir(), []*inputData{})
   530  					return nil
   531  				},
   532  			)
   533  			require.NoError(t, err)
   534  			requireEmptyScratch(t, env.driver.InputDir())
   535  			requireContents(t, env.driver.InputDir(), []*inputData{})
   536  		})
   537  	})
   538  	require.NoError(t, err)
   539  }
   540  
   541  func newGitInput(repo string, url string) *common.Input {
   542  	return &common.Input{
   543  		FileInfo: &pfs.FileInfo{
   544  			File: &pfs.File{
   545  				Commit: &pfs.Commit{
   546  					Repo: &pfs.Repo{
   547  						Name: repo,
   548  					},
   549  					ID: "commit-id-string",
   550  				},
   551  				Path: "commit.json",
   552  			},
   553  			FileType: pfs.FileType_FILE,
   554  		},
   555  		GitURL: url,
   556  		Name:   repo,
   557  	}
   558  }
   559  
   560  func mockGitGetFile(env *testEnv, repo string, ref string, sha string, cb func(*pfs.GetFileRequest)) {
   561  	env.MockPachd.PFS.GetFile.Use(func(req *pfs.GetFileRequest, serv pfs.API_GetFileServer) (retErr error) {
   562  		payload := &github.PushPayload{
   563  			Ref:   ref,
   564  			After: sha,
   565  		}
   566  		payload.Repository.CloneURL = repo
   567  		jsonBytes, err := json.Marshal(payload)
   568  		if err != nil {
   569  			return err
   570  		}
   571  
   572  		if cb != nil {
   573  			cb(req)
   574  		}
   575  
   576  		return serv.Send(&types.BytesValue{Value: jsonBytes})
   577  	})
   578  }
   579  
   580  func TestWithDataGit(t *testing.T) {
   581  	t.Parallel()
   582  	err := withTestEnv(func(env *testEnv) {
   583  		requireLogs(t, []string{"finished downloading data"}, func(logger logs.TaggedLogger) {
   584  			var getFileReq *pfs.GetFileRequest
   585  			mockGitGetFile(env, inputGitRepo, "refs/heads/master", "9047fbfc251e7412ef3300868f743f2c24852539", func(req *pfs.GetFileRequest) {
   586  				getFileReq = req
   587  			})
   588  
   589  			_, err := env.driver.WithData(
   590  				[]*common.Input{newGitInput("artifacts", inputGitRepo)},
   591  				nil,
   592  				logger,
   593  				func(dir string, stats *pps.ProcessStats) error {
   594  					requireContents(t, dir, []*inputData{newInputDataRegex("artifacts/readme.md", "Test Artifacts")})
   595  					return nil
   596  				},
   597  			)
   598  			require.NoError(t, err)
   599  			require.NotNil(t, getFileReq)
   600  			require.Equal(t, getFileReq.File, client.NewFile("artifacts", "commit-id-string", "commit.json"))
   601  			requireEmptyScratch(t, env.driver.InputDir())
   602  			requireContents(t, env.driver.InputDir(), []*inputData{})
   603  		})
   604  	})
   605  	require.NoError(t, err)
   606  }
   607  
   608  func TestWithDataGitHookError(t *testing.T) {
   609  	t.Parallel()
   610  	err := withTestEnv(func(env *testEnv) {
   611  		requireLogs(t, []string{"errored downloading data"}, func(logger logs.TaggedLogger) {
   612  			mockGitGetFile(env, "", "", "", nil)
   613  
   614  			_, err := env.driver.WithData(
   615  				[]*common.Input{newGitInput("artifacts", inputGitRepo)},
   616  				nil,
   617  				logger,
   618  				func(dir string, stats *pps.ProcessStats) error {
   619  					require.True(t, false, "Should have errored before calling WithData callback")
   620  					return nil
   621  				},
   622  			)
   623  			require.YesError(t, err)
   624  			require.Matches(t, "payload does not specify", err.Error())
   625  			requireEmptyScratch(t, env.driver.InputDir())
   626  			requireContents(t, env.driver.InputDir(), []*inputData{})
   627  		})
   628  	})
   629  	require.NoError(t, err)
   630  }
   631  
   632  func TestWithDataGitRepoMissing(t *testing.T) {
   633  	t.Parallel()
   634  	err := withTestEnv(func(env *testEnv) {
   635  		requireLogs(t, []string{"errored downloading data"}, func(logger logs.TaggedLogger) {
   636  			mockGitGetFile(env, inputGitRepoFake, "refs/heads/master", "foobar", nil)
   637  
   638  			_, err := env.driver.WithData(
   639  				[]*common.Input{newGitInput("artifacts", inputGitRepo)},
   640  				nil,
   641  				logger,
   642  				func(dir string, stats *pps.ProcessStats) error {
   643  					require.True(t, false, "Should have errored before calling WithData callback")
   644  					return nil
   645  				},
   646  			)
   647  			require.YesError(t, err)
   648  			require.Matches(t, "authentication required", err.Error())
   649  			requireEmptyScratch(t, env.driver.InputDir())
   650  			requireContents(t, env.driver.InputDir(), []*inputData{})
   651  		})
   652  	})
   653  	require.NoError(t, err)
   654  }
   655  
   656  func TestWithDataGitInvalidSHA(t *testing.T) {
   657  	t.Parallel()
   658  	err := withTestEnv(func(env *testEnv) {
   659  		requireLogs(t, []string{"errored downloading data"}, func(logger logs.TaggedLogger) {
   660  			mockGitGetFile(env, inputGitRepo, "refs/heads/master", "foobar", nil)
   661  
   662  			_, err := env.driver.WithData(
   663  				[]*common.Input{newGitInput("artifacts", inputGitRepo)},
   664  				nil,
   665  				logger,
   666  				func(dir string, stats *pps.ProcessStats) error {
   667  					require.True(t, false, "Should have errored before calling WithData callback")
   668  					return nil
   669  				},
   670  			)
   671  			require.YesError(t, err)
   672  			require.Matches(t, "could not find SHA foobar", err.Error())
   673  			requireEmptyScratch(t, env.driver.InputDir())
   674  			requireContents(t, env.driver.InputDir(), []*inputData{})
   675  		})
   676  	})
   677  	require.NoError(t, err)
   678  }
   679  
   680  // Test that user code will successfully run and the output will be forwarded to logs
   681  func TestRunUserCode(t *testing.T) {
   682  	t.Parallel()
   683  	logMessage := "this is a user code log message"
   684  	err := withTestEnv(func(env *testEnv) {
   685  		env.driver.pipelineInfo.Transform.Cmd = []string{"echo", logMessage}
   686  		requireLogs(t, []string{logMessage}, func(logger logs.TaggedLogger) {
   687  			err := env.driver.RunUserCode(logger, []string{}, nil, nil)
   688  			require.NoError(t, err)
   689  		})
   690  	})
   691  	require.NoError(t, err)
   692  }
   693  
   694  func TestRunUserCodeError(t *testing.T) {
   695  	t.Parallel()
   696  	err := withTestEnv(func(env *testEnv) {
   697  		env.driver.pipelineInfo.Transform.Cmd = []string{"false"}
   698  		requireLogs(t, []string{"exit status 1"}, func(logger logs.TaggedLogger) {
   699  			err := env.driver.RunUserCode(logger, []string{}, nil, nil)
   700  			require.YesError(t, err)
   701  		})
   702  	})
   703  	require.NoError(t, err)
   704  }
   705  
   706  func TestRunUserCodeNoCommand(t *testing.T) {
   707  	t.Parallel()
   708  	err := withTestEnv(func(env *testEnv) {
   709  		env.driver.pipelineInfo.Transform.Cmd = []string{}
   710  		requireLogs(t, []string{"no command specified"}, func(logger logs.TaggedLogger) {
   711  			err := env.driver.RunUserCode(logger, []string{}, nil, nil)
   712  			require.YesError(t, err)
   713  		})
   714  	})
   715  	require.NoError(t, err)
   716  }
   717  
   718  func TestRunUserCodeTimeout(t *testing.T) {
   719  	t.Parallel()
   720  	err := withTestEnv(func(env *testEnv) {
   721  		env.driver.pipelineInfo.Transform.Cmd = []string{"sleep", "10"}
   722  		timeout := types.DurationProto(10 * time.Millisecond)
   723  		requireLogs(t, []string{"context deadline exceeded"}, func(logger logs.TaggedLogger) {
   724  			err := env.driver.RunUserCode(logger, []string{}, nil, timeout)
   725  			require.YesError(t, err)
   726  			require.Matches(t, "context deadline exceeded", err.Error())
   727  		})
   728  	})
   729  	require.NoError(t, err)
   730  }
   731  
   732  func TestRunUserCodeEnv(t *testing.T) {
   733  	t.Parallel()
   734  	err := withTestEnv(func(env *testEnv) {
   735  		env.driver.pipelineInfo.Transform.Cmd = []string{"env"}
   736  		requireLogs(t, []string{"FOO=password", "BAR=hunter2"}, func(logger logs.TaggedLogger) {
   737  			err := env.driver.RunUserCode(logger, []string{"FOO=password", "BAR=hunter2"}, nil, nil)
   738  			require.NoError(t, err)
   739  		})
   740  	})
   741  	require.NoError(t, err)
   742  }
   743  
   744  func TestRunUserCodeWithData(t *testing.T) {
   745  	t.Parallel()
   746  	err := withTestEnv(func(env *testEnv) {
   747  		env.driver.pipelineInfo.Transform.Cmd = []string{"bash", "-c", "cat pfs/repoA/input.txt pfs/repoB/input.md > pfs/out/output.txt"}
   748  		requireLogs(t, []string{"finished running user code"}, func(logger logs.TaggedLogger) {
   749  			// Mock out the calls that will be used to download the data
   750  			env.MockPachd.PFS.WalkFile.Use(func(req *pfs.WalkFileRequest, serv pfs.API_WalkFileServer) error {
   751  				return serv.Send(&pfs.FileInfo{
   752  					File:     req.File,
   753  					FileType: pfs.FileType_FILE,
   754  				})
   755  			})
   756  
   757  			env.MockPachd.PFS.GetFile.Use(func(req *pfs.GetFileRequest, serv pfs.API_GetFileServer) error {
   758  				return serv.Send(&types.BytesValue{Value: []byte(fmt.Sprintf("%s-data", req.File.Commit.Repo.Name))})
   759  			})
   760  
   761  			inputs := []*common.Input{newInput("repoA", "input.txt"), newInput("repoB", "input.md")}
   762  			_, err := env.driver.WithData(
   763  				inputs,
   764  				nil,
   765  				logger,
   766  				func(dir string, stats *pps.ProcessStats) error {
   767  					requireContents(t, dir, []*inputData{
   768  						newInputData("repoA/input.txt", "repoA-data"),
   769  						newInputData("repoB/input.md", "repoB-data"),
   770  					})
   771  
   772  					err := env.driver.WithActiveData(inputs, dir, func() error {
   773  						requireContents(t, env.driver.InputDir(), []*inputData{
   774  							newInputData("repoA/input.txt", "repoA-data"),
   775  							newInputData("repoB/input.md", "repoB-data"),
   776  						})
   777  
   778  						err := env.driver.RunUserCode(logger, []string{}, nil, nil)
   779  						require.NoError(t, err)
   780  
   781  						requireContents(t, env.driver.InputDir(), []*inputData{
   782  							newInputData("repoA/input.txt", "repoA-data"),
   783  							newInputData("repoB/input.md", "repoB-data"),
   784  							newInputData("out/output.txt", "repoA-datarepoB-data"),
   785  						})
   786  						return nil
   787  					})
   788  					require.NoError(t, err)
   789  
   790  					requireContents(t, dir, []*inputData{
   791  						newInputData("repoA/input.txt", "repoA-data"),
   792  						newInputData("repoB/input.md", "repoB-data"),
   793  						newInputData("out/output.txt", "repoA-datarepoB-data"),
   794  					})
   795  					return nil
   796  				},
   797  			)
   798  			require.NoError(t, err)
   799  			requireEmptyScratch(t, env.driver.InputDir())
   800  			requireContents(t, env.driver.InputDir(), []*inputData{})
   801  		})
   802  	})
   803  	require.NoError(t, err)
   804  }