go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/dsmapper/controller_test.go

// Copyright 2018 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dsmapper

import (
	"context"
	"testing"
	"time"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/clock/testclock"
	"go.chromium.org/luci/common/errors"
	"go.chromium.org/luci/common/logging/gologger"
	"go.chromium.org/luci/common/retry/transient"
	"go.chromium.org/luci/gae/filter/txndefer"
	"go.chromium.org/luci/gae/impl/memory"
	"go.chromium.org/luci/gae/service/datastore"

	"go.chromium.org/luci/server/dsmapper/dsmapperpb"
	"go.chromium.org/luci/server/dsmapper/internal/splitter"
	"go.chromium.org/luci/server/dsmapper/internal/tasks"
	"go.chromium.org/luci/server/tq"
	"go.chromium.org/luci/server/tq/tqtesting"

	. "github.com/smartystreets/goconvey/convey"
	. "go.chromium.org/luci/common/testing/assertions"
)

var (
	testTime        = testclock.TestRecentTimeUTC.Round(time.Millisecond)
	testTimeAsProto = timestamppb.New(testTime)
)

type intEnt struct {
	ID int64 `gae:"$id"`
}

func TestController(t *testing.T) {
	t.Parallel()

	Convey("With controller", t, func() {
		ctx := txndefer.FilterRDS(memory.Use(context.Background()))
		ctx = gologger.StdConfig.Use(ctx)
		ctx, tc := testclock.UseTime(ctx, testTime)
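		// Auto-advance the fake clock whenever the task scheduler waits on a
		// tqtesting-tagged timer, so delayed tasks run without real sleeps.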
		tc.SetTimerCallback(func(d time.Duration, t clock.Timer) {
			if testclock.HasTags(t, tqtesting.ClockTag) {
				tc.Add(d)
			}
		})

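		// In-memory task queue dispatcher and scheduler used to drive the
		// mapper and control tasks in this test.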
		dispatcher := &tq.Dispatcher{}
		ctx, sched := tq.TestingContext(ctx, dispatcher)

		ctl := Controller{
			MapperQueue:  "mapper-queue",
			ControlQueue: "control-queue",
		}
		ctl.Install(dispatcher)

		// mapperFunc is set by test cases.
		var mapperFunc func(params []byte, shardIdx int, keys []*datastore.Key) error

		const testMapperID ID = "test-mapper"
		ctl.RegisterFactory(testMapperID, func(_ context.Context, j *Job, idx int) (Mapper, error) {
			return func(_ context.Context, keys []*datastore.Key) error {
				if mapperFunc == nil {
					return nil
				}
				return mapperFunc(j.Config.Params, idx, keys)
			}, nil
		})

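		// spinUntilDone runs the scheduler until all queues are drained and
		// returns payloads of the tasks that succeeded. Unless expectErrors is
		// set, any failed task aborts the test.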
		spinUntilDone := func(expectErrors bool) (executed []proto.Message) {
			var succeeded tqtesting.TaskList
			sched.TaskSucceeded = tqtesting.TasksCollector(&succeeded)
			sched.TaskFailed = func(ctx context.Context, task *tqtesting.Task) {
				if !expectErrors {
					t.Fatalf("task %q %s failed unexpectedly", task.Name, task.Payload)
				}
			}
			sched.Run(ctx, tqtesting.StopWhenDrained())
			return succeeded.Payloads()
		}

		// Create a bunch of entities to run the mapper over.
		entities := make([]intEnt, 512)
		So(datastore.Put(ctx, entities), ShouldBeNil)
		datastore.GetTestable(ctx).CatchupIndexes()

		Convey("LaunchJob works", func() {
			cfg := JobConfig{
				Query: Query{
					Kind: "intEnt",
				},
				Mapper:        testMapperID,
				Params:        []byte("zzz"),
				ShardCount:    4,
				PageSize:      33, // make it weird to trigger "incomplete" pages
				PagesPerTask:  2,  // to trigger multiple mapping tasks in a chain
				TrackProgress: true,
			}

			// Before we start, there's no job with ID 1.
			j, err := ctl.GetJob(ctx, 1)
			So(err, ShouldEqual, ErrNoSuchJob)
			So(j, ShouldBeNil)

			jobID, err := ctl.LaunchJob(ctx, &cfg)
			So(err, ShouldBeNil)
			So(jobID, ShouldEqual, 1)

			// In "starting" state.
			job, err := ctl.GetJob(ctx, jobID)
			So(err, ShouldBeNil)
			So(job, ShouldResemble, &Job{
				ID:      jobID,
				Config:  cfg,
				State:   dsmapperpb.State_STARTING,
				Created: testTime,
				Updated: testTime,
			})

			// No shards in the info yet.
			info, err := job.FetchInfo(ctx)
			So(err, ShouldBeNil)
			So(info, ShouldResembleProto, &dsmapperpb.JobInfo{
				Id:            int64(jobID),
				State:         dsmapperpb.State_STARTING,
				Created:       testTimeAsProto,
				Updated:       testTimeAsProto,
				TotalEntities: -1,
			})

			// Roll TQ forward.
			sched.Run(ctx, tqtesting.StopBeforeTask("dsmapper-fan-out-shards"))

			// Switched into "running" state.
			job, err = ctl.GetJob(ctx, jobID)
			So(err, ShouldBeNil)
			So(job.State, ShouldEqual, dsmapperpb.State_RUNNING)

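			// expectedShard builds the shard entity we expect to find in
			// datastore; -1 means an unbounded edge of the key range.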
			expectedShard := func(id int64, idx int, l, r, expected int64) shard {
				rng := splitter.Range{}
				if l != -1 {
					rng.Start = datastore.KeyForObj(ctx, &intEnt{ID: l})
				}
				if r != -1 {
					rng.End = datastore.KeyForObj(ctx, &intEnt{ID: r})
				}
				return shard{
					ID:            id,
					JobID:         jobID,
					Index:         idx,
					State:         dsmapperpb.State_STARTING,
					Range:         rng,
					ExpectedCount: expected,
					Created:       testTime,
					Updated:       testTime,
				}
			}

			// Created the shard entities.
			shards, err := job.fetchShards(ctx)
			So(err, ShouldBeNil)
			So(shards, ShouldResemble, []shard{
				expectedShard(1, 0, -1, 136, 136),
				expectedShard(2, 1, 136, 268, 132),
				expectedShard(3, 2, 268, 399, 131),
				expectedShard(4, 3, 399, -1, 113),
			})

			// Shards also appear in the info now.
			info, err = job.FetchInfo(ctx)
			So(err, ShouldBeNil)

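			// expectedShardInfo is the ShardInfo proto expected while a shard
			// is still in the STARTING state.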
			expectedShardInfo := func(idx, total int) *dsmapperpb.ShardInfo {
				return &dsmapperpb.ShardInfo{
					Index:         int32(idx),
					State:         dsmapperpb.State_STARTING,
					Created:       testTimeAsProto,
					Updated:       testTimeAsProto,
					TotalEntities: int64(total),
				}
			}
			So(info, ShouldResembleProto, &dsmapperpb.JobInfo{
				Id:            int64(jobID),
				State:         dsmapperpb.State_RUNNING,
				Created:       testTimeAsProto,
				Updated:       testTimeAsProto,
				TotalEntities: 512,
				Shards: []*dsmapperpb.ShardInfo{
					expectedShardInfo(0, 136),
					expectedShardInfo(1, 132),
					expectedShardInfo(2, 131),
					expectedShardInfo(3, 113),
				},
			})

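			// visitShards asserts that all shards exist and invokes cb on each.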
			visitShards := func(cb func(s shard)) {
				visitedShards, err := job.fetchShards(ctx)
				So(err, ShouldBeNil)
				So(visitedShards, ShouldHaveLength, cfg.ShardCount)
				for _, s := range visitedShards {
					cb(s)
				}
			}

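			// seen records IDs of entities already visited by the mapper, used
			// to verify each entity is mapped exactly once.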
			seen := make(map[int64]struct{}, len(entities))

			updateSeen := func(keys []*datastore.Key) {
				for _, k := range keys {
					_, ok := seen[k.IntID()]
					So(ok, ShouldBeFalse)
					seen[k.IntID()] = struct{}{}
				}
			}

			assertAllSeen := func() {
				So(len(seen), ShouldEqual, len(entities))
				for _, e := range entities {
					_, ok := seen[e.ID]
					So(ok, ShouldBeTrue)
				}
			}

			Convey("No errors when processing shards", func() {
				mapperFunc = func(params []byte, shardIdx int, keys []*datastore.Key) error {
					So(len(keys), ShouldBeLessThanOrEqualTo, cfg.PageSize)
					So(params, ShouldResemble, cfg.Params)
					updateSeen(keys)
					return nil
				}

				spinUntilDone(false)

				visitShards(func(s shard) {
					So(s.State, ShouldEqual, dsmapperpb.State_SUCCESS)
					So(s.ProcessTaskNum, ShouldEqual, 2)
					So(s.ProcessedCount, ShouldEqual, []int64{
						136, 132, 131, 113,
					}[s.Index])
				})

				assertAllSeen()

				job, err = ctl.GetJob(ctx, jobID)
				So(err, ShouldBeNil)
				So(job.State, ShouldEqual, dsmapperpb.State_SUCCESS)

				info, err := job.FetchInfo(ctx)
				So(err, ShouldBeNil)

				expectedShardInfo := func(idx, total int) *dsmapperpb.ShardInfo {
					return &dsmapperpb.ShardInfo{
						Index:             int32(idx),
						State:             dsmapperpb.State_SUCCESS,
						Created:           testTimeAsProto,
						Updated:           testTimeAsProto,
						TotalEntities:     int64(total),
						ProcessedEntities: int64(total),
					}
				}
				So(info, ShouldResembleProto, &dsmapperpb.JobInfo{
					Id:      int64(jobID),
					State:   dsmapperpb.State_SUCCESS,
					Created: testTimeAsProto,
					// There's a 2 sec delay before the UpdateJobState task.
					Updated:           timestamppb.New(testTime.Add(2 * time.Second)),
					TotalEntities:     512,
					ProcessedEntities: 512,
					EntitiesPerSec:    256,
					Shards: []*dsmapperpb.ShardInfo{
						expectedShardInfo(0, 136),
						expectedShardInfo(1, 132),
						expectedShardInfo(2, 131),
						expectedShardInfo(3, 113),
					},
				})
			})

			Convey("One shard fails", func() {
				page := 0
				processed := 0

				mapperFunc = func(_ []byte, shardIdx int, keys []*datastore.Key) error {
					if shardIdx == 1 {
						page++
						if page == 2 {
							return errors.New("boom")
						}
					}
					processed += len(keys)
					return nil
				}

				spinUntilDone(true)

				visitShards(func(s shard) {
					if s.Index == 1 {
						So(s.State, ShouldEqual, dsmapperpb.State_FAIL)
						So(s.Error, ShouldEqual, `while mapping 33 keys: boom`)
					} else {
						So(s.State, ShouldEqual, dsmapperpb.State_SUCCESS)
						So(s.ProcessTaskNum, ShouldEqual, 2)
					}
					So(s.ProcessedCount, ShouldEqual, []int64{
						136, 33, 131, 113, // the failed shard is incomplete
					}[s.Index])
				})

				// The failed shard has 4 pages of 33 entities. We aborted on the
				// second one, so 3 pages are skipped.
				So(processed, ShouldEqual, len(entities)-3*cfg.PageSize)

				job, err = ctl.GetJob(ctx, jobID)
				So(err, ShouldBeNil)
				So(job.State, ShouldEqual, dsmapperpb.State_FAIL)
			})

			Convey("Job aborted midway", func() {
				processed := 0

				mapperFunc = func(_ []byte, shardIdx int, keys []*datastore.Key) error {
					processed += len(keys)

					job, err = ctl.AbortJob(ctx, jobID)
					So(err, ShouldBeNil)
					So(job.State, ShouldEqual, dsmapperpb.State_ABORTING)

					return nil
				}

				spinUntilDone(false)

				// All shards eventually discovered that the job was aborted.
				visitShards(func(s shard) {
					So(s.State, ShouldEqual, dsmapperpb.State_ABORTED)
				})

				// And the job itself eventually switched into ABORTED state.
				job, err = ctl.GetJob(ctx, jobID)
				So(err, ShouldBeNil)
				So(job.State, ShouldEqual, dsmapperpb.State_ABORTED)

				// Processed 2 pages (instead of 1), since processShardHandler doesn't
				// check job state inside the processing loop (only at the beginning).
				So(processed, ShouldEqual, 2*cfg.PageSize)
			})

			Convey("processShardHandler saves state on transient errors", func() {
				pages := 0

				mapperFunc = func(_ []byte, shardIdx int, keys []*datastore.Key) error {
					pages++
					if pages == 2 {
						return errors.New("boom", transient.Tag)
					}
					return nil
				}

				err := ctl.processShardHandler(ctx, &tasks.ProcessShard{
					JobId:   int64(job.ID),
					ShardId: shards[0].ID,
				})
				So(transient.Tag.In(err), ShouldBeTrue)

				// Shard's resume point is updated. Its taskNum is left unchanged, since
				// we are going to retry the task.
				sh, err := getActiveShard(ctx, shards[0].ID, shards[0].ProcessTaskNum)
				So(err, ShouldBeNil)
				So(sh.ResumeFrom, ShouldNotBeNil)
				So(sh.ProcessedCount, ShouldEqual, 33)
			})
		})

		Convey("With simple starting job", func() {
			cfg := JobConfig{
				Query:      Query{Kind: "intEnt"},
				Mapper:     testMapperID,
				ShardCount: 4,
				PageSize:   64,
			}

			jobID, err := ctl.LaunchJob(ctx, &cfg)
			So(err, ShouldBeNil)
			So(jobID, ShouldEqual, 1)

			// In "starting" state initially.
			job, err := ctl.GetJob(ctx, jobID)
			So(err, ShouldBeNil)
			So(job.State, ShouldEqual, dsmapperpb.State_STARTING)

			Convey("Abort right after start", func() {
				job, err := ctl.AbortJob(ctx, jobID)
				So(err, ShouldBeNil)
				So(job.State, ShouldEqual, dsmapperpb.State_ABORTED) // aborted right away

				// Didn't actually launch any shards.
				So(spinUntilDone(false), ShouldResembleProto, []proto.Message{
					&tasks.SplitAndLaunch{JobId: int64(jobID)},
				})
			})

			Convey("Abort after shards are created", func() {
				// Stop right after we created the shards, before we launch them.
				sched.Run(ctx, tqtesting.StopBeforeTask("dsmapper-fan-out-shards"))

				job, err := ctl.AbortJob(ctx, jobID)
				So(err, ShouldBeNil)
				So(job.State, ShouldEqual, dsmapperpb.State_ABORTING) // waits for shards to die

				spinUntilDone(false)

				job, err = ctl.AbortJob(ctx, jobID)
				So(err, ShouldBeNil)
				So(job.State, ShouldEqual, dsmapperpb.State_ABORTED) // all shards are dead now

				// Dead indeed.
				info, err := job.FetchInfo(ctx)
				So(err, ShouldBeNil)
				So(info.Shards, ShouldHaveLength, 4)
				for _, s := range info.Shards {
					So(s.State, ShouldEqual, dsmapperpb.State_ABORTED)
				}
			})
		})
	})
}