github.com/pachyderm/pachyderm@v1.13.4/src/server/worker/driver/driver_test.go (about) 1 package driver 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "reflect" 12 "strings" 13 "testing" 14 "time" 15 16 "github.com/gogo/protobuf/types" 17 "github.com/prometheus/client_golang/prometheus" 18 prometheus_proto "github.com/prometheus/client_model/go" 19 "gopkg.in/go-playground/webhooks.v5/github" 20 21 "github.com/pachyderm/pachyderm/src/client" 22 "github.com/pachyderm/pachyderm/src/client/enterprise" 23 "github.com/pachyderm/pachyderm/src/client/pfs" 24 "github.com/pachyderm/pachyderm/src/client/pkg/errors" 25 "github.com/pachyderm/pachyderm/src/client/pkg/require" 26 "github.com/pachyderm/pachyderm/src/client/pps" 27 "github.com/pachyderm/pachyderm/src/server/pkg/testpachd" 28 tu "github.com/pachyderm/pachyderm/src/server/pkg/testutil" 29 "github.com/pachyderm/pachyderm/src/server/worker/common" 30 "github.com/pachyderm/pachyderm/src/server/worker/logs" 31 ) 32 33 var inputRepo = "inputRepo" 34 var inputGitRepo = "https://github.com/pachyderm/test-artifacts.git" 35 var inputGitRepoFake = "https://github.com/pachyderm/test-artifacts-fake.git" 36 37 func testPipelineInfo() *pps.PipelineInfo { 38 return &pps.PipelineInfo{ 39 Pipeline: client.NewPipeline("testPipeline"), 40 Transform: &pps.Transform{ 41 Cmd: []string{"cp", filepath.Join("pfs", inputRepo, "file"), "pfs/out/file"}, 42 }, 43 ParallelismSpec: &pps.ParallelismSpec{ 44 Constant: 1, 45 }, 46 ResourceRequests: &pps.ResourceSpec{ 47 Memory: "100M", 48 Cpu: 0.5, 49 }, 50 Input: client.NewPFSInput(inputRepo, "/*"), 51 } 52 } 53 54 type testEnv struct { 55 testpachd.MockEnv 56 driver *driver 57 } 58 59 func withTestEnv(cb func(*testEnv)) error { 60 return testpachd.WithMockEnv(func(mockEnv *testpachd.MockEnv) (err error) { 61 env := &testEnv{MockEnv: *mockEnv} 62 63 // Mock out the enterprise.GetState call that happens during driver construction 64 env.MockPachd.Enterprise.GetState.Use(func(context.Context, *enterprise.GetStateRequest) (*enterprise.GetStateResponse, error) { 65 return &enterprise.GetStateResponse{State: enterprise.State_NONE}, nil 66 }) 67 68 var d Driver 69 d, err = NewDriver( 70 testPipelineInfo(), 71 env.PachClient, 72 env.EtcdClient, 73 tu.UniqueString("driverTest"), 74 filepath.Clean(filepath.Join(env.Directory, "hashtrees")), 75 filepath.Clean(filepath.Join(env.Directory, "pfs")), 76 "namespace", 77 ) 78 if err != nil { 79 return err 80 } 81 d = d.WithContext(env.Context) 82 env.driver = d.(*driver) 83 env.driver.pipelineInfo.Transform.WorkingDir = env.Directory 84 85 cb(env) 86 87 return nil 88 }) 89 } 90 91 // collectLogs provides the given callback with a mock TaggedLogger object which 92 // will be used to collect all the logs and return them. This is pretty naive 93 // and just splits log statements based on newlines because when running user 94 // code, it is just used as an io.Writer and doesn't know when one message ends 95 // and the next begins. 96 func collectLogs(cb func(logs.TaggedLogger)) []string { 97 logger := logs.NewMockLogger() 98 buffer := &bytes.Buffer{} 99 logger.Writer = buffer 100 logger.Job = "job-id" 101 102 cb(logger) 103 104 logStmts := strings.Split(buffer.String(), "\n") 105 if len(logStmts) > 0 && logStmts[len(logStmts)-1] == "" { 106 return logStmts[0 : len(logStmts)-1] 107 } 108 return logStmts 109 } 110 111 // requireLogs wraps collectLogs and ensures that certain log statements were 112 // made. These are specified as regular expressions in the patterns parameter, 113 // and each pattern must match at least one log line. The patterns are run 114 // separately against each log line, not against the entire output. If the 115 // patterns parameter is nil, we require that there are no log statements. 116 func requireLogs(t *testing.T, patterns []string, cb func(logs.TaggedLogger)) { 117 logStmts := collectLogs(cb) 118 119 if patterns == nil { 120 require.Equal(t, 0, len(logStmts), "callback should not have logged anything") 121 } else { 122 for _, pattern := range patterns { 123 require.OneOfMatches(t, pattern, logStmts, "callback did not log the expected message") 124 } 125 } 126 } 127 128 func requireMetric(t *testing.T, metric prometheus.Collector, labels []string, cb func(prometheus_proto.Metric)) { 129 reg := prometheus.NewRegistry() 130 require.NoError(t, reg.Register(metric)) 131 132 stats, err := reg.Gather() 133 require.NoError(t, err) 134 135 // Add a placeholder for the state label even if it isn't used 136 for len(labels) < 3 { 137 labels = append(labels, "") 138 } 139 140 // We only have one metric in the registry, so skip over the family level 141 for _, family := range stats { 142 for _, metric := range family.Metric { 143 var pipeline, job, state string 144 for _, pair := range metric.Label { 145 switch *pair.Name { 146 case "pipeline": 147 pipeline = *pair.Value 148 case "job": 149 job = *pair.Value 150 case "state": 151 state = *pair.Value 152 default: 153 require.True(t, false, fmt.Sprintf("unexpected metric label: %s", *pair.Name)) 154 } 155 } 156 157 metricLabels := []string{pipeline, job, state} 158 if reflect.DeepEqual(labels, metricLabels) { 159 cb(*metric) 160 return 161 } 162 } 163 } 164 165 require.True(t, false, fmt.Sprintf("no matching metric found for labels: %v", labels)) 166 } 167 168 func requireCounter(t *testing.T, counter *prometheus.CounterVec, labels []string, value float64) { 169 requireMetric(t, counter, labels, func(m prometheus_proto.Metric) { 170 require.NotNil(t, m.Counter) 171 require.Equal(t, value, *m.Counter.Value) 172 }) 173 } 174 175 func requireHistogram(t *testing.T, histogram *prometheus.HistogramVec, labels []string, value uint64) { 176 requireMetric(t, histogram, labels, func(m prometheus_proto.Metric) { 177 require.NotNil(t, m.Histogram) 178 require.Equal(t, value, *m.Histogram.SampleCount) 179 }) 180 } 181 182 func TestUpdateCounter(t *testing.T) { 183 t.Parallel() 184 err := withTestEnv(func(env *testEnv) { 185 env.driver.pipelineInfo.ID = "foo" 186 187 counterVec := prometheus.NewCounterVec( 188 prometheus.CounterOpts{Namespace: "test", Subsystem: "driver", Name: "counter"}, 189 []string{"pipeline", "job"}, 190 ) 191 192 counterVecWithState := prometheus.NewCounterVec( 193 prometheus.CounterOpts{Namespace: "test", Subsystem: "driver", Name: "counter_with_state"}, 194 []string{"pipeline", "job", "state"}, 195 ) 196 197 // Passing a state to the stateless counter should error 198 requireLogs(t, []string{"expected 2 label values but got 3"}, func(logger logs.TaggedLogger) { 199 env.driver.updateCounter(counterVec, logger, "bar", func(c prometheus.Counter) { 200 require.True(t, false, "should have errored") 201 }) 202 }) 203 204 // updateCounter should pass a valid counter with the selected tags 205 requireLogs(t, nil, func(logger logs.TaggedLogger) { 206 env.driver.updateCounter(counterVec, logger, "", func(c prometheus.Counter) { 207 c.Add(1) 208 }) 209 }) 210 211 // Check that the counter was incremented 212 requireCounter(t, counterVec, []string{"foo", "job-id"}, 1) 213 214 // Not passing a state to the stateful counter should error 215 requireLogs(t, []string{"expected 3 label values but got 2"}, func(logger logs.TaggedLogger) { 216 env.driver.updateCounter(counterVecWithState, logger, "", func(c prometheus.Counter) { 217 require.True(t, false, "should have errored") 218 }) 219 }) 220 221 // updateCounter should pass a valid counter with the selected tags 222 requireLogs(t, nil, func(logger logs.TaggedLogger) { 223 env.driver.updateCounter(counterVecWithState, logger, "bar", func(c prometheus.Counter) { 224 c.Add(1) 225 }) 226 }) 227 228 // Check that the counter was incremented 229 requireCounter(t, counterVecWithState, []string{"foo", "job-id", "bar"}, 1) 230 }) 231 require.NoError(t, err) 232 } 233 234 func TestUpdateHistogram(t *testing.T) { 235 t.Parallel() 236 err := withTestEnv(func(env *testEnv) { 237 env.driver.pipelineInfo.ID = "foo" 238 239 histogramVec := prometheus.NewHistogramVec( 240 prometheus.HistogramOpts{ 241 Namespace: "test", Subsystem: "driver", Name: "histogram", 242 Buckets: prometheus.ExponentialBuckets(1.0, 2.0, 20), 243 }, 244 []string{"pipeline", "job"}, 245 ) 246 247 histogramVecWithState := prometheus.NewHistogramVec( 248 prometheus.HistogramOpts{ 249 Namespace: "test", Subsystem: "driver", Name: "histogram_with_state", 250 Buckets: prometheus.ExponentialBuckets(1.0, 2.0, 20), 251 }, 252 []string{"pipeline", "job", "state"}, 253 ) 254 255 // Passing a state to the stateless histogram should error 256 requireLogs(t, []string{"expected 2 label values but got 3"}, func(logger logs.TaggedLogger) { 257 env.driver.updateHistogram(histogramVec, logger, "bar", func(h prometheus.Observer) { 258 require.True(t, false, "should have errored") 259 }) 260 }) 261 262 requireLogs(t, nil, func(logger logs.TaggedLogger) { 263 env.driver.updateHistogram(histogramVec, logger, "", func(h prometheus.Observer) { 264 h.Observe(0) 265 }) 266 }) 267 268 // Check that the counter was incremented 269 requireHistogram(t, histogramVec, []string{"foo", "job-id"}, 1) 270 271 // Not passing a state to the stateful histogram should error 272 requireLogs(t, []string{"expected 3 label values but got 2"}, func(logger logs.TaggedLogger) { 273 env.driver.updateHistogram(histogramVecWithState, logger, "", func(h prometheus.Observer) { 274 require.True(t, false, "should have errored") 275 }) 276 }) 277 278 requireLogs(t, nil, func(logger logs.TaggedLogger) { 279 env.driver.updateHistogram(histogramVecWithState, logger, "bar", func(h prometheus.Observer) { 280 h.Observe(0) 281 }) 282 }) 283 284 // Check that the counter was incremented 285 requireHistogram(t, histogramVecWithState, []string{"foo", "job-id", "bar"}, 1) 286 }) 287 require.NoError(t, err) 288 } 289 290 type inputData struct { 291 path string 292 contents string 293 regex string 294 found bool 295 } 296 297 func newInputDataRegex(path string, regex string) *inputData { 298 return &inputData{path: filepath.Clean(path), regex: regex} 299 } 300 301 func newInputData(path string, contents string) *inputData { 302 return &inputData{path: filepath.Clean(path), contents: contents} 303 } 304 305 func requireEmptyScratch(t *testing.T, inputDir string) { 306 entries, err := ioutil.ReadDir(filepath.Join(inputDir, client.PPSScratchSpace)) 307 308 if !errors.Is(err, os.ErrNotExist) { 309 require.ElementsEqual(t, []os.FileInfo{}, entries) 310 } 311 } 312 313 func requireContents(t *testing.T, dir string, data []*inputData) { 314 checkFile := func(fullPath string, relPath string) { 315 for _, checkData := range data { 316 if checkData.path == relPath { 317 contents, err := ioutil.ReadFile(fullPath) 318 require.NoError(t, err) 319 if checkData.regex != "" { 320 require.Matches(t, checkData.regex, string(contents), "Incorrect contents for input file: %s", relPath) 321 } else { 322 require.Equal(t, checkData.contents, string(contents), "Incorrect contents for input file: %s", relPath) 323 } 324 checkData.found = true 325 return 326 } 327 } 328 require.True(t, false, "Unexpected input file found: %s", relPath) 329 } 330 331 err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { 332 require.NoError(t, err) 333 if info.Name() == ".git" || info.Name() == client.PPSScratchSpace { 334 return filepath.SkipDir 335 } 336 if !info.IsDir() { 337 path = filepath.Clean(path) 338 relPath := strings.TrimLeft(strings.TrimPrefix(path, dir), "/\\") 339 checkFile(path, relPath) 340 } 341 return nil 342 }) 343 require.NoError(t, err) 344 345 for _, checkData := range data { 346 require.True(t, checkData.found, "Expected input file not found: %s", checkData.path) 347 } 348 } 349 350 func TestWithDataEmpty(t *testing.T) { 351 t.Parallel() 352 err := withTestEnv(func(env *testEnv) { 353 requireLogs(t, []string{"finished downloading data"}, func(logger logs.TaggedLogger) { 354 _, err := env.driver.WithData( 355 []*common.Input{}, 356 nil, 357 logger, 358 func(dir string, stats *pps.ProcessStats) error { 359 requireContents(t, dir, []*inputData{}) 360 return nil 361 }, 362 ) 363 require.NoError(t, err) 364 requireEmptyScratch(t, env.driver.InputDir()) 365 requireContents(t, env.driver.InputDir(), []*inputData{}) 366 }) 367 }) 368 require.NoError(t, err) 369 } 370 371 func TestWithDataSpout(t *testing.T) { 372 t.Parallel() 373 err := withTestEnv(func(env *testEnv) { 374 env.driver.pipelineInfo.Spout = &pps.Spout{} 375 requireLogs(t, []string{"finished downloading data"}, func(logger logs.TaggedLogger) { 376 _, err := env.driver.WithData( 377 []*common.Input{}, 378 nil, 379 logger, 380 func(dir string, stats *pps.ProcessStats) error { 381 // A spout pipeline should have created a 'pfs/out` fifo for the user 382 // code to write to 383 requireContents(t, dir, []*inputData{newInputData("out", "")}) 384 return nil 385 }, 386 ) 387 require.NoError(t, err) 388 requireEmptyScratch(t, env.driver.InputDir()) 389 requireContents(t, env.driver.InputDir(), []*inputData{}) 390 }) 391 }) 392 require.NoError(t, err) 393 } 394 395 // Shitty helper function to create possibly-not-malformed input structures 396 func newInput(repo string, path string) *common.Input { 397 return &common.Input{ 398 FileInfo: &pfs.FileInfo{ 399 File: &pfs.File{ 400 Commit: &pfs.Commit{ 401 Repo: &pfs.Repo{ 402 Name: repo, 403 }, 404 ID: "commit-id-string", 405 }, 406 Path: path, 407 }, 408 FileType: pfs.FileType_FILE, 409 }, 410 Name: repo, 411 Branch: "master", 412 } 413 } 414 415 func TestWithDataCancel(t *testing.T) { 416 t.Parallel() 417 err := withTestEnv(func(env *testEnv) { 418 requireLogs(t, []string{"errored downloading data", "context canceled"}, func(logger logs.TaggedLogger) { 419 ctx, cancel := context.WithCancel(env.Context) 420 driver := env.driver.WithContext(ctx) 421 422 // Cancel the context during the download 423 env.MockPachd.PFS.WalkFile.Use(func(req *pfs.WalkFileRequest, serv pfs.API_WalkFileServer) error { 424 cancel() 425 <-serv.Context().Done() 426 return errors.Errorf("WalkFile canceled") 427 }) 428 429 _, err := driver.WithData( 430 []*common.Input{newInput("repo", "input.txt")}, 431 nil, 432 logger, 433 func(dir string, stats *pps.ProcessStats) error { 434 require.True(t, false, "Should have been canceled before the callback") 435 cancel() 436 return nil 437 }, 438 ) 439 require.YesError(t, err, "WithData call should have been canceled") 440 requireEmptyScratch(t, env.driver.InputDir()) 441 requireContents(t, env.driver.InputDir(), []*inputData{}) 442 }) 443 }) 444 require.NoError(t, err) 445 } 446 447 // Check that the driver will download the requested inputs, put them in place 448 // during WithData, and clean them up after running the inner function. 449 func TestWithDataDownload(t *testing.T) { 450 t.Parallel() 451 err := withTestEnv(func(env *testEnv) { 452 requireLogs(t, []string{"finished downloading data", "inner function"}, func(logger logs.TaggedLogger) { 453 // Mock out the calls that will be used to download the data 454 env.MockPachd.PFS.WalkFile.Use(func(req *pfs.WalkFileRequest, serv pfs.API_WalkFileServer) error { 455 return serv.Send(&pfs.FileInfo{ 456 File: req.File, 457 FileType: pfs.FileType_FILE, 458 }) 459 }) 460 461 env.MockPachd.PFS.GetFile.Use(func(req *pfs.GetFileRequest, serv pfs.API_GetFileServer) error { 462 return serv.Send(&types.BytesValue{Value: []byte(fmt.Sprintf("%s-data", req.File.Commit.Repo.Name))}) 463 }) 464 465 _, err := env.driver.WithData( 466 []*common.Input{newInput("repoA", "input.txt"), newInput("repoB", "input.md")}, 467 nil, 468 logger, 469 func(dir string, stats *pps.ProcessStats) error { 470 requireContents(t, dir, []*inputData{ 471 newInputData("repoA/input.txt", "repoA-data"), 472 newInputData("repoB/input.md", "repoB-data"), 473 }) 474 logger.Logf("inner function") 475 return nil 476 }, 477 ) 478 require.NoError(t, err) 479 requireEmptyScratch(t, env.driver.InputDir()) 480 requireContents(t, env.driver.InputDir(), []*inputData{}) 481 }) 482 }) 483 require.NoError(t, err) 484 } 485 486 // Create several files and directories inside WithData and verify that they are 487 // cleaned up after WithData returns. 488 func TestWithActiveDataCleanup(t *testing.T) { 489 t.Parallel() 490 err := withTestEnv(func(env *testEnv) { 491 create := func(relPath string) { 492 fullPath := filepath.Join(env.driver.InputDir(), relPath) 493 require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0777)) 494 file, err := os.Create(fullPath) 495 require.NoError(t, err) 496 require.NoError(t, file.Close()) 497 } 498 499 requireLogs(t, []string{"finished downloading data", "inner function"}, func(logger logs.TaggedLogger) { 500 _, err := env.driver.WithData( 501 []*common.Input{}, 502 nil, 503 logger, 504 func(dir string, stats *pps.ProcessStats) error { 505 requireContents(t, dir, []*inputData{}) 506 logger.Logf("inner function") 507 508 expectedContents := []*inputData{ 509 newInputData("c", ""), 510 newInputData("out/1", ""), 511 newInputData("out/2/a", ""), 512 newInputData("out/2/b", ""), 513 newInputData("out/2/3/c", ""), 514 newInputData("foo/barbaz", ""), 515 newInputData("foo/bar/baz", ""), 516 newInputData("floop/blarp/blazj/etc", ""), 517 } 518 519 err := env.driver.WithActiveData([]*common.Input{}, dir, func() error { 520 for _, x := range expectedContents { 521 create(x.path) 522 } 523 524 requireContents(t, env.driver.InputDir(), expectedContents) 525 return nil 526 }) 527 require.NoError(t, err) 528 requireContents(t, dir, expectedContents) 529 requireContents(t, env.driver.InputDir(), []*inputData{}) 530 return nil 531 }, 532 ) 533 require.NoError(t, err) 534 requireEmptyScratch(t, env.driver.InputDir()) 535 requireContents(t, env.driver.InputDir(), []*inputData{}) 536 }) 537 }) 538 require.NoError(t, err) 539 } 540 541 func newGitInput(repo string, url string) *common.Input { 542 return &common.Input{ 543 FileInfo: &pfs.FileInfo{ 544 File: &pfs.File{ 545 Commit: &pfs.Commit{ 546 Repo: &pfs.Repo{ 547 Name: repo, 548 }, 549 ID: "commit-id-string", 550 }, 551 Path: "commit.json", 552 }, 553 FileType: pfs.FileType_FILE, 554 }, 555 GitURL: url, 556 Name: repo, 557 } 558 } 559 560 func mockGitGetFile(env *testEnv, repo string, ref string, sha string, cb func(*pfs.GetFileRequest)) { 561 env.MockPachd.PFS.GetFile.Use(func(req *pfs.GetFileRequest, serv pfs.API_GetFileServer) (retErr error) { 562 payload := &github.PushPayload{ 563 Ref: ref, 564 After: sha, 565 } 566 payload.Repository.CloneURL = repo 567 jsonBytes, err := json.Marshal(payload) 568 if err != nil { 569 return err 570 } 571 572 if cb != nil { 573 cb(req) 574 } 575 576 return serv.Send(&types.BytesValue{Value: jsonBytes}) 577 }) 578 } 579 580 func TestWithDataGit(t *testing.T) { 581 t.Parallel() 582 err := withTestEnv(func(env *testEnv) { 583 requireLogs(t, []string{"finished downloading data"}, func(logger logs.TaggedLogger) { 584 var getFileReq *pfs.GetFileRequest 585 mockGitGetFile(env, inputGitRepo, "refs/heads/master", "9047fbfc251e7412ef3300868f743f2c24852539", func(req *pfs.GetFileRequest) { 586 getFileReq = req 587 }) 588 589 _, err := env.driver.WithData( 590 []*common.Input{newGitInput("artifacts", inputGitRepo)}, 591 nil, 592 logger, 593 func(dir string, stats *pps.ProcessStats) error { 594 requireContents(t, dir, []*inputData{newInputDataRegex("artifacts/readme.md", "Test Artifacts")}) 595 return nil 596 }, 597 ) 598 require.NoError(t, err) 599 require.NotNil(t, getFileReq) 600 require.Equal(t, getFileReq.File, client.NewFile("artifacts", "commit-id-string", "commit.json")) 601 requireEmptyScratch(t, env.driver.InputDir()) 602 requireContents(t, env.driver.InputDir(), []*inputData{}) 603 }) 604 }) 605 require.NoError(t, err) 606 } 607 608 func TestWithDataGitHookError(t *testing.T) { 609 t.Parallel() 610 err := withTestEnv(func(env *testEnv) { 611 requireLogs(t, []string{"errored downloading data"}, func(logger logs.TaggedLogger) { 612 mockGitGetFile(env, "", "", "", nil) 613 614 _, err := env.driver.WithData( 615 []*common.Input{newGitInput("artifacts", inputGitRepo)}, 616 nil, 617 logger, 618 func(dir string, stats *pps.ProcessStats) error { 619 require.True(t, false, "Should have errored before calling WithData callback") 620 return nil 621 }, 622 ) 623 require.YesError(t, err) 624 require.Matches(t, "payload does not specify", err.Error()) 625 requireEmptyScratch(t, env.driver.InputDir()) 626 requireContents(t, env.driver.InputDir(), []*inputData{}) 627 }) 628 }) 629 require.NoError(t, err) 630 } 631 632 func TestWithDataGitRepoMissing(t *testing.T) { 633 t.Parallel() 634 err := withTestEnv(func(env *testEnv) { 635 requireLogs(t, []string{"errored downloading data"}, func(logger logs.TaggedLogger) { 636 mockGitGetFile(env, inputGitRepoFake, "refs/heads/master", "foobar", nil) 637 638 _, err := env.driver.WithData( 639 []*common.Input{newGitInput("artifacts", inputGitRepo)}, 640 nil, 641 logger, 642 func(dir string, stats *pps.ProcessStats) error { 643 require.True(t, false, "Should have errored before calling WithData callback") 644 return nil 645 }, 646 ) 647 require.YesError(t, err) 648 require.Matches(t, "authentication required", err.Error()) 649 requireEmptyScratch(t, env.driver.InputDir()) 650 requireContents(t, env.driver.InputDir(), []*inputData{}) 651 }) 652 }) 653 require.NoError(t, err) 654 } 655 656 func TestWithDataGitInvalidSHA(t *testing.T) { 657 t.Parallel() 658 err := withTestEnv(func(env *testEnv) { 659 requireLogs(t, []string{"errored downloading data"}, func(logger logs.TaggedLogger) { 660 mockGitGetFile(env, inputGitRepo, "refs/heads/master", "foobar", nil) 661 662 _, err := env.driver.WithData( 663 []*common.Input{newGitInput("artifacts", inputGitRepo)}, 664 nil, 665 logger, 666 func(dir string, stats *pps.ProcessStats) error { 667 require.True(t, false, "Should have errored before calling WithData callback") 668 return nil 669 }, 670 ) 671 require.YesError(t, err) 672 require.Matches(t, "could not find SHA foobar", err.Error()) 673 requireEmptyScratch(t, env.driver.InputDir()) 674 requireContents(t, env.driver.InputDir(), []*inputData{}) 675 }) 676 }) 677 require.NoError(t, err) 678 } 679 680 // Test that user code will successfully run and the output will be forwarded to logs 681 func TestRunUserCode(t *testing.T) { 682 t.Parallel() 683 logMessage := "this is a user code log message" 684 err := withTestEnv(func(env *testEnv) { 685 env.driver.pipelineInfo.Transform.Cmd = []string{"echo", logMessage} 686 requireLogs(t, []string{logMessage}, func(logger logs.TaggedLogger) { 687 err := env.driver.RunUserCode(logger, []string{}, nil, nil) 688 require.NoError(t, err) 689 }) 690 }) 691 require.NoError(t, err) 692 } 693 694 func TestRunUserCodeError(t *testing.T) { 695 t.Parallel() 696 err := withTestEnv(func(env *testEnv) { 697 env.driver.pipelineInfo.Transform.Cmd = []string{"false"} 698 requireLogs(t, []string{"exit status 1"}, func(logger logs.TaggedLogger) { 699 err := env.driver.RunUserCode(logger, []string{}, nil, nil) 700 require.YesError(t, err) 701 }) 702 }) 703 require.NoError(t, err) 704 } 705 706 func TestRunUserCodeNoCommand(t *testing.T) { 707 t.Parallel() 708 err := withTestEnv(func(env *testEnv) { 709 env.driver.pipelineInfo.Transform.Cmd = []string{} 710 requireLogs(t, []string{"no command specified"}, func(logger logs.TaggedLogger) { 711 err := env.driver.RunUserCode(logger, []string{}, nil, nil) 712 require.YesError(t, err) 713 }) 714 }) 715 require.NoError(t, err) 716 } 717 718 func TestRunUserCodeTimeout(t *testing.T) { 719 t.Parallel() 720 err := withTestEnv(func(env *testEnv) { 721 env.driver.pipelineInfo.Transform.Cmd = []string{"sleep", "10"} 722 timeout := types.DurationProto(10 * time.Millisecond) 723 requireLogs(t, []string{"context deadline exceeded"}, func(logger logs.TaggedLogger) { 724 err := env.driver.RunUserCode(logger, []string{}, nil, timeout) 725 require.YesError(t, err) 726 require.Matches(t, "context deadline exceeded", err.Error()) 727 }) 728 }) 729 require.NoError(t, err) 730 } 731 732 func TestRunUserCodeEnv(t *testing.T) { 733 t.Parallel() 734 err := withTestEnv(func(env *testEnv) { 735 env.driver.pipelineInfo.Transform.Cmd = []string{"env"} 736 requireLogs(t, []string{"FOO=password", "BAR=hunter2"}, func(logger logs.TaggedLogger) { 737 err := env.driver.RunUserCode(logger, []string{"FOO=password", "BAR=hunter2"}, nil, nil) 738 require.NoError(t, err) 739 }) 740 }) 741 require.NoError(t, err) 742 } 743 744 func TestRunUserCodeWithData(t *testing.T) { 745 t.Parallel() 746 err := withTestEnv(func(env *testEnv) { 747 env.driver.pipelineInfo.Transform.Cmd = []string{"bash", "-c", "cat pfs/repoA/input.txt pfs/repoB/input.md > pfs/out/output.txt"} 748 requireLogs(t, []string{"finished running user code"}, func(logger logs.TaggedLogger) { 749 // Mock out the calls that will be used to download the data 750 env.MockPachd.PFS.WalkFile.Use(func(req *pfs.WalkFileRequest, serv pfs.API_WalkFileServer) error { 751 return serv.Send(&pfs.FileInfo{ 752 File: req.File, 753 FileType: pfs.FileType_FILE, 754 }) 755 }) 756 757 env.MockPachd.PFS.GetFile.Use(func(req *pfs.GetFileRequest, serv pfs.API_GetFileServer) error { 758 return serv.Send(&types.BytesValue{Value: []byte(fmt.Sprintf("%s-data", req.File.Commit.Repo.Name))}) 759 }) 760 761 inputs := []*common.Input{newInput("repoA", "input.txt"), newInput("repoB", "input.md")} 762 _, err := env.driver.WithData( 763 inputs, 764 nil, 765 logger, 766 func(dir string, stats *pps.ProcessStats) error { 767 requireContents(t, dir, []*inputData{ 768 newInputData("repoA/input.txt", "repoA-data"), 769 newInputData("repoB/input.md", "repoB-data"), 770 }) 771 772 err := env.driver.WithActiveData(inputs, dir, func() error { 773 requireContents(t, env.driver.InputDir(), []*inputData{ 774 newInputData("repoA/input.txt", "repoA-data"), 775 newInputData("repoB/input.md", "repoB-data"), 776 }) 777 778 err := env.driver.RunUserCode(logger, []string{}, nil, nil) 779 require.NoError(t, err) 780 781 requireContents(t, env.driver.InputDir(), []*inputData{ 782 newInputData("repoA/input.txt", "repoA-data"), 783 newInputData("repoB/input.md", "repoB-data"), 784 newInputData("out/output.txt", "repoA-datarepoB-data"), 785 }) 786 return nil 787 }) 788 require.NoError(t, err) 789 790 requireContents(t, dir, []*inputData{ 791 newInputData("repoA/input.txt", "repoA-data"), 792 newInputData("repoB/input.md", "repoB-data"), 793 newInputData("out/output.txt", "repoA-datarepoB-data"), 794 }) 795 return nil 796 }, 797 ) 798 require.NoError(t, err) 799 requireEmptyScratch(t, env.driver.InputDir()) 800 requireContents(t, env.driver.InputDir(), []*inputData{}) 801 }) 802 }) 803 require.NoError(t, err) 804 }