github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/acyclic/causet/embedded/task.go

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package embedded
    15  
    16  import (
    17  	"math"
    18  
    19  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    20  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    21  	"github.com/whtcorpsinc/BerolinaSQL/charset"
    22  	"github.com/whtcorpsinc/milevadb/causet/property"
    23  	"github.com/whtcorpsinc/milevadb/causet/soliton"
    24  	"github.com/whtcorpsinc/milevadb/config"
    25  	"github.com/whtcorpsinc/milevadb/ekv"
    26  	"github.com/whtcorpsinc/milevadb/memex"
    27  	"github.com/whtcorpsinc/milevadb/memex/aggregation"
    28  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    29  	"github.com/whtcorpsinc/milevadb/soliton/plancodec"
    30  	"github.com/whtcorpsinc/milevadb/statistics"
    31  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    32  	"github.com/whtcorpsinc/milevadb/types"
    33  )
    34  
    35  // task is a new version of `PhysicalCausetInfo`. It stores cost information for a task.
    36  // A task may be CopTask, RootTask, MPPTask or a ParallelTask.
    37  type task interface {
    38  	count() float64
    39  	addCost(cost float64)
    40  	cost() float64
    41  	copy() task
    42  	plan() PhysicalCauset
    43  	invalid() bool
    44  }
    45  
    46  // copTask is a task that runs in a distributed ekv causetstore.
    47  // TODO: In future, we should split copTask to indexTask and blockTask.
    48  type copTask struct {
    49  	indexCauset PhysicalCauset
    50  	blockCauset PhysicalCauset
    51  	cst         float64
    52  	// indexCausetFinished means we have finished index plan.
    53  	indexCausetFinished bool
    54  	// keepOrder indicates if the plan scans data by order.
    55  	keepOrder bool
    56  	// doubleReadNeedProj means an extra prune is needed because
    57  	// in double read case, it may output one more column for handle(event id).
    58  	doubleReadNeedProj bool
    59  
    60  	extraHandleDefCaus   *memex.DeferredCauset
    61  	commonHandleDefCauss []*memex.DeferredCauset
     62  	// tblDefCausHists stores the original stats of the DataSource; it is used to get the
     63  	// average row width when computing network cost.
    64  	tblDefCausHists *statistics.HistDefCausl
     65  	// tblDefCauss stores the original columns of the DataSource before pruning; it
     66  	// is used to compute the average row width when computing scan cost.
    67  	tblDefCauss         []*memex.DeferredCauset
    68  	idxMergePartCausets []PhysicalCauset
     69  	// rootTaskConds stores select conditions containing virtual columns.
     70  	// These conditions can't be pushed to EinsteinDB, so we have to add a selection for the rootTask.
    71  	rootTaskConds []memex.Expression
    72  
    73  	// For causet partition.
    74  	partitionInfo PartitionInfo
    75  }
    76  
    77  func (t *copTask) invalid() bool {
    78  	return t.blockCauset == nil && t.indexCauset == nil
    79  }
    80  
    81  func (t *rootTask) invalid() bool {
    82  	return t.p == nil
    83  }
    84  
    85  func (t *copTask) count() float64 {
    86  	if t.indexCausetFinished {
    87  		return t.blockCauset.statsInfo().RowCount
    88  	}
    89  	return t.indexCauset.statsInfo().RowCount
    90  }
    91  
    92  func (t *copTask) addCost(cst float64) {
    93  	t.cst += cst
    94  }
    95  
    96  func (t *copTask) cost() float64 {
    97  	return t.cst
    98  }
    99  
   100  func (t *copTask) copy() task {
   101  	nt := *t
   102  	return &nt
   103  }
   104  
   105  func (t *copTask) plan() PhysicalCauset {
   106  	if t.indexCausetFinished {
   107  		return t.blockCauset
   108  	}
   109  	return t.indexCauset
   110  }
   111  
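        // attachCauset2Task attaches the physical operator p on top of the current plan of task t.
        // For a copTask it extends whichever side is still being built (the index plan until it is
        // finished, the causet plan afterwards); for a rootTask, p simply becomes the new root.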
   112  func attachCauset2Task(p PhysicalCauset, t task) task {
   113  	switch v := t.(type) {
   114  	case *copTask:
   115  		if v.indexCausetFinished {
   116  			p.SetChildren(v.blockCauset)
   117  			v.blockCauset = p
   118  		} else {
   119  			p.SetChildren(v.indexCauset)
   120  			v.indexCauset = p
   121  		}
   122  	case *rootTask:
   123  		p.SetChildren(v.p)
   124  		v.p = p
   125  	}
   126  	return t
   127  }
   128  
   129  // finishIndexCauset means we no longer add plans to the index plan; it also computes the network cost for the index plan.
   130  func (t *copTask) finishIndexCauset() {
   131  	if t.indexCausetFinished {
   132  		return
   133  	}
   134  	cnt := t.count()
   135  	t.indexCausetFinished = true
   136  	sessVars := t.indexCauset.SCtx().GetStochastikVars()
   137  	// Network cost of transferring rows of index scan to MilevaDB.
   138  	t.cst += cnt * sessVars.NetworkFactor * t.tblDefCausHists.GetAvgRowSize(t.indexCauset.SCtx(), t.indexCauset.Schema().DeferredCausets, true, false)
   139  
   140  	if t.blockCauset == nil {
   141  		return
   142  	}
   143  	// Calculate the IO cost of the causet scan here because we cannot know its stats until we finish the index plan.
   144  	t.blockCauset.(*PhysicalBlockScan).stats = t.indexCauset.statsInfo()
   145  	var p PhysicalCauset
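        	// Walk down the index side to its bottom-most operator, the PhysicalIndexScan,
        	// so we can read its index metadata (e.g. whether the index is unique).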
   146  	for p = t.indexCauset; len(p.Children()) > 0; p = p.Children()[0] {
   147  	}
   148  	rowSize := t.tblDefCausHists.GetIndexAvgRowSize(t.indexCauset.SCtx(), t.tblDefCauss, p.(*PhysicalIndexScan).Index.Unique)
   149  	t.cst += cnt * rowSize * sessVars.ScanFactor
   150  }
   151  
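        // getStoreType returns the ekv store this cop task targets. It walks down the causet plan:
        // a node with more than one child can only be a join pushed down to TiFlash; otherwise the
        // store type recorded on the bottom PhysicalBlockScan is used, defaulting to EinsteinDB.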
   152  func (t *copTask) getStoreType() ekv.StoreType {
   153  	if t.blockCauset == nil {
   154  		return ekv.EinsteinDB
   155  	}
   156  	tp := t.blockCauset
   157  	for len(tp.Children()) > 0 {
   158  		if len(tp.Children()) > 1 {
   159  			return ekv.TiFlash
   160  		}
   161  		tp = tp.Children()[0]
   162  	}
   163  	if ts, ok := tp.(*PhysicalBlockScan); ok {
   164  		return ts.StoreType
   165  	}
   166  	return ekv.EinsteinDB
   167  }
   168  
   169  func (p *basePhysicalCauset) attach2Task(tasks ...task) task {
   170  	t := finishCopTask(p.ctx, tasks[0].copy())
   171  	return attachCauset2Task(p.self, t)
   172  }
   173  
   174  func (p *PhysicalUnionScan) attach2Task(tasks ...task) task {
   175  	p.stats = tasks[0].plan().statsInfo()
   176  	return p.basePhysicalCauset.attach2Task(tasks...)
   177  }
   178  
   179  func (p *PhysicalApply) attach2Task(tasks ...task) task {
   180  	lTask := finishCopTask(p.ctx, tasks[0].copy())
   181  	rTask := finishCopTask(p.ctx, tasks[1].copy())
   182  	p.SetChildren(lTask.plan(), rTask.plan())
   183  	p.schemaReplicant = BuildPhysicalJoinSchema(p.JoinType, p)
   184  	return &rootTask{
   185  		p:   p,
   186  		cst: p.GetCost(lTask.count(), rTask.count(), lTask.cost(), rTask.cost()),
   187  	}
   188  }
   189  
   190  // GetCost computes the cost of apply operator.
   191  func (p *PhysicalApply) GetCost(lCount, rCount, lCost, rCost float64) float64 {
   192  	var cpuCost float64
   193  	sessVars := p.ctx.GetStochastikVars()
   194  	if len(p.LeftConditions) > 0 {
   195  		cpuCost += lCount * sessVars.CPUFactor
   196  		lCount *= SelectionFactor
   197  	}
   198  	if len(p.RightConditions) > 0 {
   199  		cpuCost += lCount * rCount * sessVars.CPUFactor
   200  		rCount *= SelectionFactor
   201  	}
   202  	if len(p.EqualConditions)+len(p.OtherConditions) > 0 {
   203  		if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin ||
   204  			p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin {
   205  			cpuCost += lCount * rCount * sessVars.CPUFactor * 0.5
   206  		} else {
   207  			cpuCost += lCount * rCount * sessVars.CPUFactor
   208  		}
   209  	}
   210  	// Apply uses a NestedLoop method for execution.
   211  	// For every row from the left (outer) side, it executes
   212  	// the whole right (inner) plan tree. So the cost of apply
   213  	// should be: apply cost + left cost + left count * right cost.
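        	// For instance (hypothetical numbers): with no filters, 1000 outer rows, a left cost of 100
        	// and an inner plan costing 5 per execution, the nested-loop term alone contributes
        	// 1000 * 5 = 5000, so the total is roughly 100 + 5000 plus the CPU cost computed above.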
   214  	return cpuCost + lCost + lCount*rCost
   215  }
   216  
   217  func (p *PhysicalIndexMergeJoin) attach2Task(tasks ...task) task {
   218  	innerTask := p.innerTask
   219  	outerTask := finishCopTask(p.ctx, tasks[1-p.InnerChildIdx].copy())
   220  	if p.InnerChildIdx == 1 {
   221  		p.SetChildren(outerTask.plan(), innerTask.plan())
   222  	} else {
   223  		p.SetChildren(innerTask.plan(), outerTask.plan())
   224  	}
   225  	return &rootTask{
   226  		p:   p,
   227  		cst: p.GetCost(outerTask, innerTask),
   228  	}
   229  }
   230  
   231  // GetCost computes the cost of index merge join operator and its children.
   232  func (p *PhysicalIndexMergeJoin) GetCost(outerTask, innerTask task) float64 {
   233  	var cpuCost float64
   234  	outerCnt, innerCnt := outerTask.count(), innerTask.count()
   235  	sessVars := p.ctx.GetStochastikVars()
   236  	// Add the cost of evaluating the outer filter. Since the inner filter of an index join
   237  	// is always empty, we can tell whether the outer filter is empty using the
   238  	// summed length of the left/right conditions.
   239  	if len(p.LeftConditions)+len(p.RightConditions) > 0 {
   240  		cpuCost += sessVars.CPUFactor * outerCnt
   241  		outerCnt *= SelectionFactor
   242  	}
   243  	// Cost of extracting lookup keys.
   244  	innerCPUCost := sessVars.CPUFactor * outerCnt
   245  	// Cost of sorting and removing duplicate lookup keys:
   246  	// (outerCnt / batchSize) * (sortFactor + 1.0) * batchSize * cpuFactor
   247  	// If `p.NeedOuterSort` is true, the sortFactor is Log2(batchSize).
   248  	// Otherwise, it's 0.
   249  	batchSize := math.Min(float64(p.ctx.GetStochastikVars().IndexJoinBatchSize), outerCnt)
   250  	sortFactor := 0.0
   251  	if p.NeedOuterSort {
   252  		sortFactor = math.Log2(float64(batchSize))
   253  	}
   254  	if batchSize > 2 {
   255  		innerCPUCost += outerCnt * (sortFactor + 1.0) * sessVars.CPUFactor
   256  	}
   257  	// Add the cost of building inner interlocks. CPU cost of building CausetTasks:
   258  	// (outerCnt / batchSize) * (batchSize * distinctFactor) * cpuFactor
   259  	// Since we don't know the number of CausetTasks built, ignore this network cost for now.
   260  	innerCPUCost += outerCnt * distinctFactor * sessVars.CPUFactor
   261  	innerConcurrency := float64(p.ctx.GetStochastikVars().IndexLookupJoinConcurrency())
   262  	cpuCost += innerCPUCost / innerConcurrency
   263  	// Cost of merge join in inner worker.
   264  	numPairs := outerCnt * innerCnt
   265  	if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin ||
   266  		p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin {
   267  		if len(p.OtherConditions) > 0 {
   268  			numPairs *= 0.5
   269  		} else {
   270  			numPairs = 0
   271  		}
   272  	}
   273  	avgProbeCnt := numPairs / outerCnt
   274  	var probeCost float64
   275  	// Inner workers do merge join in parallel, but they can only save the results of ONE outer
   276  	// batch. So once the number of outer batches exceeds the inner concurrency, execution falls
   277  	// back to being linear. In short, the merge join only runs in parallel for the first
   278  	// `innerConcurrency` inner tasks.
   279  	if outerCnt/batchSize >= innerConcurrency {
   280  		probeCost = (numPairs - batchSize*avgProbeCnt*(innerConcurrency-1)) * sessVars.CPUFactor
   281  	} else {
   282  		probeCost = batchSize * avgProbeCnt * sessVars.CPUFactor
   283  	}
   284  	cpuCost += probeCost + (innerConcurrency+1.0)*sessVars.ConcurrencyFactor
   285  
   286  	// Index merge join saves the join results in the inner workers,
   287  	// so the memory cost considers the result size of each batch.
   288  	memoryCost := innerConcurrency * (batchSize * avgProbeCnt) * sessVars.MemoryFactor
   289  
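        	// Cost of the inner child plan, charged once per outer row (mainly I/O and network cost).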
   290  	innerCausetCost := outerCnt * innerTask.cost()
   291  	return outerTask.cost() + innerCausetCost + cpuCost + memoryCost
   292  }
   293  
   294  func (p *PhysicalIndexHashJoin) attach2Task(tasks ...task) task {
   295  	innerTask := p.innerTask
   296  	outerTask := finishCopTask(p.ctx, tasks[1-p.InnerChildIdx].copy())
   297  	if p.InnerChildIdx == 1 {
   298  		p.SetChildren(outerTask.plan(), innerTask.plan())
   299  	} else {
   300  		p.SetChildren(innerTask.plan(), outerTask.plan())
   301  	}
   302  	return &rootTask{
   303  		p:   p,
   304  		cst: p.GetCost(outerTask, innerTask),
   305  	}
   306  }
   307  
   308  // GetCost computes the cost of the index hash join operator and its children.
   309  func (p *PhysicalIndexHashJoin) GetCost(outerTask, innerTask task) float64 {
   310  	var cpuCost float64
   311  	outerCnt, innerCnt := outerTask.count(), innerTask.count()
   312  	sessVars := p.ctx.GetStochastikVars()
   313  	// Add the cost of evaluating the outer filter. Since the inner filter of an index join
   314  	// is always empty, we can tell whether the outer filter is empty using the
   315  	// summed length of the left/right conditions.
   316  	if len(p.LeftConditions)+len(p.RightConditions) > 0 {
   317  		cpuCost += sessVars.CPUFactor * outerCnt
   318  		outerCnt *= SelectionFactor
   319  	}
   320  	// Cost of extracting lookup keys.
   321  	innerCPUCost := sessVars.CPUFactor * outerCnt
   322  	// Cost of sorting and removing duplicate lookup keys:
   323  	// (outerCnt / batchSize) * (batchSize * Log2(batchSize) + batchSize) * CPUFactor
   324  	batchSize := math.Min(float64(sessVars.IndexJoinBatchSize), outerCnt)
   325  	if batchSize > 2 {
   326  		innerCPUCost += outerCnt * (math.Log2(batchSize) + 1) * sessVars.CPUFactor
   327  	}
   328  	// Add the cost of building inner interlocks. CPU cost of building CausetTasks:
   329  	// (outerCnt / batchSize) * (batchSize * distinctFactor) * CPUFactor
   330  	// Since we don't know the number of CausetTasks built, ignore this network cost for now.
   331  	innerCPUCost += outerCnt * distinctFactor * sessVars.CPUFactor
   332  	concurrency := float64(sessVars.IndexLookupJoinConcurrency())
   333  	cpuCost += innerCPUCost / concurrency
   334  	// CPU cost of building hash causet for outer results concurrently.
   335  	// (outerCnt / batchSize) * (batchSize * CPUFactor)
   336  	outerCPUCost := outerCnt * sessVars.CPUFactor
   337  	cpuCost += outerCPUCost / concurrency
   338  	// Cost of probing hash causet concurrently.
   339  	numPairs := outerCnt * innerCnt
   340  	if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin ||
   341  		p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin {
   342  		if len(p.OtherConditions) > 0 {
   343  			numPairs *= 0.5
   344  		} else {
   345  			numPairs = 0
   346  		}
   347  	}
   348  	// Inner workers do hash join in parallel, but they can only save the results of ONE outer
   349  	// batch. So once the number of outer batches exceeds the inner concurrency, execution falls
   350  	// back to being linear. In short, the hash join only runs in parallel for the first
   351  	// `concurrency` inner tasks.
   352  	var probeCost float64
   353  	if outerCnt/batchSize >= concurrency {
   354  		probeCost = (numPairs - batchSize*innerCnt*(concurrency-1)) * sessVars.CPUFactor
   355  	} else {
   356  		probeCost = batchSize * innerCnt * sessVars.CPUFactor
   357  	}
   358  	cpuCost += probeCost
   359  	// Cost of additional concurrent goroutines.
   360  	cpuCost += (concurrency + 1.0) * sessVars.ConcurrencyFactor
   361  	// Memory cost of the hash blocks for outer rows. The computed result is an upper bound,
   362  	// since the interlock is pipelined and not all workers are always at full load.
   363  	memoryCost := concurrency * (batchSize * distinctFactor) * innerCnt * sessVars.MemoryFactor
   364  	// Cost of the inner child plan, i.e., mainly I/O and network cost.
   365  	innerCausetCost := outerCnt * innerTask.cost()
   366  	return outerTask.cost() + innerCausetCost + cpuCost + memoryCost
   367  }
   368  
   369  func (p *PhysicalIndexJoin) attach2Task(tasks ...task) task {
   370  	innerTask := p.innerTask
   371  	outerTask := finishCopTask(p.ctx, tasks[1-p.InnerChildIdx].copy())
   372  	if p.InnerChildIdx == 1 {
   373  		p.SetChildren(outerTask.plan(), innerTask.plan())
   374  	} else {
   375  		p.SetChildren(innerTask.plan(), outerTask.plan())
   376  	}
   377  	return &rootTask{
   378  		p:   p,
   379  		cst: p.GetCost(outerTask, innerTask),
   380  	}
   381  }
   382  
   383  // GetCost computes the cost of index join operator and its children.
   384  func (p *PhysicalIndexJoin) GetCost(outerTask, innerTask task) float64 {
   385  	var cpuCost float64
   386  	outerCnt, innerCnt := outerTask.count(), innerTask.count()
   387  	sessVars := p.ctx.GetStochastikVars()
   388  	// Add the cost of evaluating the outer filter. Since the inner filter of an index join
   389  	// is always empty, we can tell whether the outer filter is empty using the
   390  	// summed length of the left/right conditions.
   391  	if len(p.LeftConditions)+len(p.RightConditions) > 0 {
   392  		cpuCost += sessVars.CPUFactor * outerCnt
   393  		outerCnt *= SelectionFactor
   394  	}
   395  	// Cost of extracting lookup keys.
   396  	innerCPUCost := sessVars.CPUFactor * outerCnt
   397  	// Cost of sorting and removing duplicate lookup keys:
   398  	// (outerCnt / batchSize) * (batchSize * Log2(batchSize) + batchSize) * CPUFactor
   399  	batchSize := math.Min(float64(p.ctx.GetStochastikVars().IndexJoinBatchSize), outerCnt)
   400  	if batchSize > 2 {
   401  		innerCPUCost += outerCnt * (math.Log2(batchSize) + 1) * sessVars.CPUFactor
   402  	}
   403  	// Add the cost of building inner interlocks. CPU cost of building CausetTasks:
   404  	// (outerCnt / batchSize) * (batchSize * distinctFactor) * CPUFactor
   405  	// Since we don't know the number of CausetTasks built, ignore this network cost for now.
   406  	innerCPUCost += outerCnt * distinctFactor * sessVars.CPUFactor
   407  	// CPU cost of building hash causet for inner results:
   408  	// (outerCnt / batchSize) * (batchSize * distinctFactor) * innerCnt * CPUFactor
   409  	innerCPUCost += outerCnt * distinctFactor * innerCnt * sessVars.CPUFactor
   410  	innerConcurrency := float64(p.ctx.GetStochastikVars().IndexLookupJoinConcurrency())
   411  	cpuCost += innerCPUCost / innerConcurrency
   412  	// Cost of probing hash causet in main thread.
   413  	numPairs := outerCnt * innerCnt
   414  	if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin ||
   415  		p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin {
   416  		if len(p.OtherConditions) > 0 {
   417  			numPairs *= 0.5
   418  		} else {
   419  			numPairs = 0
   420  		}
   421  	}
   422  	probeCost := numPairs * sessVars.CPUFactor
   423  	// Cost of additional concurrent goroutines.
   424  	cpuCost += probeCost + (innerConcurrency+1.0)*sessVars.ConcurrencyFactor
   425  	// Memory cost of the hash blocks for inner rows. The computed result is an upper bound,
   426  	// since the interlock is pipelined and not all workers are always at full load.
   427  	memoryCost := innerConcurrency * (batchSize * distinctFactor) * innerCnt * sessVars.MemoryFactor
   428  	// Cost of the inner child plan, i.e., mainly I/O and network cost.
   429  	innerCausetCost := outerCnt * innerTask.cost()
   430  	return outerTask.cost() + innerCausetCost + cpuCost + memoryCost
   431  }
   432  
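        // getAvgRowSize estimates the average width (in bytes) of a row under the given schema.
        // It prefers the column histograms when they are available, and otherwise falls back to a
        // type-based estimate of each column's width.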
   433  func getAvgRowSize(stats *property.StatsInfo, schemaReplicant *memex.Schema) (size float64) {
   434  	if stats.HistDefCausl != nil {
   435  		size = stats.HistDefCausl.GetAvgRowSizeListInDisk(schemaReplicant.DeferredCausets)
   436  	} else {
   437  		// Estimate using just the type info.
   438  		defcaus := schemaReplicant.DeferredCausets
   439  		for _, col := range defcaus {
   440  			size += float64(chunk.EstimateTypeWidth(col.GetType()))
   441  		}
   442  	}
   443  	return
   444  }
   445  
   446  // GetCost computes cost of hash join operator itself.
   447  func (p *PhysicalHashJoin) GetCost(lCnt, rCnt float64) float64 {
   448  	buildCnt, probeCnt := lCnt, rCnt
   449  	build := p.children[0]
   450  	// Taking the right as the inner for right join or using the outer to build a hash causet.
   451  	if (p.InnerChildIdx == 1 && !p.UseOuterToBuild) || (p.InnerChildIdx == 0 && p.UseOuterToBuild) {
   452  		buildCnt, probeCnt = rCnt, lCnt
   453  		build = p.children[1]
   454  	}
   455  	sessVars := p.ctx.GetStochastikVars()
   456  	oomUseTmpStorage := config.GetGlobalConfig().OOMUseTmpStorage
   457  	memQuota := sessVars.StmtCtx.MemTracker.GetBytesLimit() // sessVars.MemQuotaQuery && hint
   458  	rowSize := getAvgRowSize(build.statsInfo(), build.Schema())
   459  	spill := oomUseTmpStorage && memQuota > 0 && rowSize*buildCnt > float64(memQuota)
   460  	// Cost of building hash causet.
   461  	cpuCost := buildCnt * sessVars.CPUFactor
   462  	memoryCost := buildCnt * sessVars.MemoryFactor
   463  	diskCost := buildCnt * sessVars.DiskFactor * rowSize
   464  	// Number of matched row pairs regarding the equal join conditions.
   465  	helper := &fullJoinRowCountHelper{
   466  		cartesian:     false,
   467  		leftProfile:   p.children[0].statsInfo(),
   468  		rightProfile:  p.children[1].statsInfo(),
   469  		leftJoinKeys:  p.LeftJoinKeys,
   470  		rightJoinKeys: p.RightJoinKeys,
   471  		leftSchema:    p.children[0].Schema(),
   472  		rightSchema:   p.children[1].Schema(),
   473  	}
   474  	numPairs := helper.estimate()
   475  	// For the semi-join class, if `OtherConditions` is empty, we already know
   476  	// the join results after querying the hash causet; otherwise, we have to
   477  	// evaluate the resulting row pairs after querying the hash causet. If we
   478  	// find one pair satisfying the `OtherConditions`, we then know the
   479  	// join result for this given outer row; otherwise we have to iterate
   480  	// to the end of those pairs. Since we have no idea when we can
   481  	// terminate the iteration, we assume that we need to iterate half of
   482  	// those pairs on average.
   483  	if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin ||
   484  		p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin {
   485  		if len(p.OtherConditions) > 0 {
   486  			numPairs *= 0.5
   487  		} else {
   488  			numPairs = 0
   489  		}
   490  	}
   491  	// Querying the hash causet is actually cheap, so we just compute the cost of
   492  	// evaluating `OtherConditions` and joining the row pairs.
   493  	probeCost := numPairs * sessVars.CPUFactor
   494  	probeDiskCost := numPairs * sessVars.DiskFactor * rowSize
   495  	// Cost of evaluating outer filter.
   496  	if len(p.LeftConditions)+len(p.RightConditions) > 0 {
   497  		// The input outer count for the above computation should be adjusted by SelectionFactor.
   498  		probeCost *= SelectionFactor
   499  		probeDiskCost *= SelectionFactor
   500  		probeCost += probeCnt * sessVars.CPUFactor
   501  	}
   502  	diskCost += probeDiskCost
   503  	probeCost /= float64(p.Concurrency)
   504  	// Cost of additional concurrent goroutines.
   505  	cpuCost += probeCost + float64(p.Concurrency+1)*sessVars.ConcurrencyFactor
   506  	// Cost of traversing the hash causet to resolve unmatched rows when the hash causet is built from the outer causet.
   507  	if p.UseOuterToBuild {
   508  		if spill {
   509  			// It runs sequentially when the build data is on disk. See handleUnmatchedRowsFromHashBlockInDisk.
   510  			cpuCost += buildCnt * sessVars.CPUFactor
   511  		} else {
   512  			cpuCost += buildCnt * sessVars.CPUFactor / float64(p.Concurrency)
   513  		}
   514  		diskCost += buildCnt * sessVars.DiskFactor * rowSize
   515  	}
   516  
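        	// If the build side is expected to exceed the memory quota and spilling to disk is enabled,
        	// only part of the hash causet stays in memory, so the memory cost is scaled down by
        	// memQuota / (rowSize * buildCnt); otherwise no disk is used and the disk cost is dropped.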
   517  	if spill {
   518  		memoryCost *= float64(memQuota) / (rowSize * buildCnt)
   519  	} else {
   520  		diskCost = 0
   521  	}
   522  	return cpuCost + memoryCost + diskCost
   523  }
   524  
   525  func (p *PhysicalHashJoin) attach2Task(tasks ...task) task {
   526  	lTask := finishCopTask(p.ctx, tasks[0].copy())
   527  	rTask := finishCopTask(p.ctx, tasks[1].copy())
   528  	p.SetChildren(lTask.plan(), rTask.plan())
   529  	task := &rootTask{
   530  		p:   p,
   531  		cst: lTask.cost() + rTask.cost() + p.GetCost(lTask.count(), rTask.count()),
   532  	}
   533  	return task
   534  }
   535  
   536  // GetCost computes cost of broadcast join operator itself.
   537  func (p *PhysicalBroadCastJoin) GetCost(lCnt, rCnt float64) float64 {
   538  	buildCnt := lCnt
   539  	if p.InnerChildIdx == 1 {
   540  		buildCnt = rCnt
   541  	}
   542  	sessVars := p.ctx.GetStochastikVars()
   543  	// Cost of building hash causet.
   544  	cpuCost := buildCnt * sessVars.CopCPUFactor
   545  	memoryCost := buildCnt * sessVars.MemoryFactor
   546  	// Number of matched row pairs regarding the equal join conditions.
   547  	helper := &fullJoinRowCountHelper{
   548  		cartesian:     false,
   549  		leftProfile:   p.children[0].statsInfo(),
   550  		rightProfile:  p.children[1].statsInfo(),
   551  		leftJoinKeys:  p.LeftJoinKeys,
   552  		rightJoinKeys: p.RightJoinKeys,
   553  		leftSchema:    p.children[0].Schema(),
   554  		rightSchema:   p.children[1].Schema(),
   555  	}
   556  	numPairs := helper.estimate()
   557  	probeCost := numPairs * sessVars.CopCPUFactor
   558  	// The probe cost should be divided by the concurrency in TiFlash, which should be the number of cores in the TiFlash nodes.
   559  	probeCost /= float64(sessVars.CopTiFlashConcurrencyFactor)
   560  	cpuCost += probeCost
   561  
   562  	// todo: since the TiFlash join is significantly faster than the MilevaDB join, maybe
   563  	//  we need to add a variable like 'tiflash_accelerate_factor' and divide
   564  	//  the final cost by that factor.
   565  	return cpuCost + memoryCost
   566  }
   567  
   568  func (p *PhysicalBroadCastJoin) attach2Task(tasks ...task) task {
   569  	lTask, lok := tasks[0].(*copTask)
   570  	rTask, rok := tasks[1].(*copTask)
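        	// Broadcast join can only run inside the storage layer: both children must be cop tasks
        	// and at least one of them must reside on TiFlash, otherwise this plan choice is invalid.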
   571  	if !lok || !rok || (lTask.getStoreType() != ekv.TiFlash && rTask.getStoreType() != ekv.TiFlash) {
   572  		return invalidTask
   573  	}
   574  	p.SetChildren(lTask.plan(), rTask.plan())
   575  	p.schemaReplicant = BuildPhysicalJoinSchema(p.JoinType, p)
   576  	if !lTask.indexCausetFinished {
   577  		lTask.finishIndexCauset()
   578  	}
   579  	if !rTask.indexCausetFinished {
   580  		rTask.finishIndexCauset()
   581  	}
   582  
   583  	lCost := lTask.cost()
   584  	rCost := rTask.cost()
   585  
   586  	task := &copTask{
   587  		tblDefCausHists:     rTask.tblDefCausHists,
   588  		indexCausetFinished: true,
   589  		blockCauset:         p,
   590  		cst:                 lCost + rCost + p.GetCost(lTask.count(), rTask.count()),
   591  	}
   592  	return task
   593  }
   594  
   595  // GetCost computes cost of merge join operator itself.
   596  func (p *PhysicalMergeJoin) GetCost(lCnt, rCnt float64) float64 {
   597  	outerCnt := lCnt
   598  	innerKeys := p.RightJoinKeys
   599  	innerSchema := p.children[1].Schema()
   600  	innerStats := p.children[1].statsInfo()
   601  	if p.JoinType == RightOuterJoin {
   602  		outerCnt = rCnt
   603  		innerKeys = p.LeftJoinKeys
   604  		innerSchema = p.children[0].Schema()
   605  		innerStats = p.children[0].statsInfo()
   606  	}
   607  	helper := &fullJoinRowCountHelper{
   608  		cartesian:     false,
   609  		leftProfile:   p.children[0].statsInfo(),
   610  		rightProfile:  p.children[1].statsInfo(),
   611  		leftJoinKeys:  p.LeftJoinKeys,
   612  		rightJoinKeys: p.RightJoinKeys,
   613  		leftSchema:    p.children[0].Schema(),
   614  		rightSchema:   p.children[1].Schema(),
   615  	}
   616  	numPairs := helper.estimate()
   617  	if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin ||
   618  		p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin {
   619  		if len(p.OtherConditions) > 0 {
   620  			numPairs *= 0.5
   621  		} else {
   622  			numPairs = 0
   623  		}
   624  	}
   625  	sessVars := p.ctx.GetStochastikVars()
   626  	probeCost := numPairs * sessVars.CPUFactor
   627  	// Cost of evaluating outer filters.
   628  	var cpuCost float64
   629  	if len(p.LeftConditions)+len(p.RightConditions) > 0 {
   630  		probeCost *= SelectionFactor
   631  		cpuCost += outerCnt * sessVars.CPUFactor
   632  	}
   633  	cpuCost += probeCost
   634  	// For merge join, only one group of rows with the same join key (not null) is cached,
   635  	// so we compute the average memory cost using the estimated group size.
   636  	NDV := getCardinality(innerKeys, innerSchema, innerStats)
   637  	memoryCost := (innerStats.RowCount / NDV) * sessVars.MemoryFactor
   638  	return cpuCost + memoryCost
   639  }
   640  
   641  func (p *PhysicalMergeJoin) attach2Task(tasks ...task) task {
   642  	lTask := finishCopTask(p.ctx, tasks[0].copy())
   643  	rTask := finishCopTask(p.ctx, tasks[1].copy())
   644  	p.SetChildren(lTask.plan(), rTask.plan())
   645  	return &rootTask{
   646  		p:   p,
   647  		cst: lTask.cost() + rTask.cost() + p.GetCost(lTask.count(), rTask.count()),
   648  	}
   649  }
   650  
   651  func buildIndexLookUpTask(ctx stochastikctx.Context, t *copTask) *rootTask {
   652  	newTask := &rootTask{cst: t.cst}
   653  	sessVars := ctx.GetStochastikVars()
   654  	p := PhysicalIndexLookUpReader{
   655  		blockCauset:          t.blockCauset,
   656  		indexCauset:          t.indexCauset,
   657  		ExtraHandleDefCaus:   t.extraHandleDefCaus,
   658  		CommonHandleDefCauss: t.commonHandleDefCauss,
   659  	}.Init(ctx, t.blockCauset.SelectBlockOffset())
   660  	p.PartitionInfo = t.partitionInfo
   661  	setBlockScanToBlockRowIDScan(p.blockCauset)
   662  	p.stats = t.blockCauset.statsInfo()
   663  	// Add the cost of building causet reader interlocks. Handles are extracted in batches;
   664  	// each handle is a range, so the CPU cost of building CausetTasks should be:
   665  	// (indexRows / batchSize) * batchSize * CPUFactor
   666  	// Since we don't know the number of CausetTasks built, ignore this network cost for now.
   667  	indexRows := t.indexCauset.statsInfo().RowCount
   668  	newTask.cst += indexRows * sessVars.CPUFactor
   669  	// Add cost of worker goroutines in index lookup.
   670  	numTblWorkers := float64(sessVars.IndexLookupConcurrency())
   671  	newTask.cst += (numTblWorkers + 1) * sessVars.ConcurrencyFactor
   672  	// When building the causet reader interlock for each batch, we sort the handles. The CPU
   673  	// cost of the sort is:
   674  	// CPUFactor * batchSize * Log2(batchSize) * (indexRows / batchSize)
   675  	indexLookupSize := float64(sessVars.IndexLookupSize)
   676  	batchSize := math.Min(indexLookupSize, indexRows)
   677  	if batchSize > 2 {
   678  		sortCPUCost := (indexRows * math.Log2(batchSize) * sessVars.CPUFactor) / numTblWorkers
   679  		newTask.cst += sortCPUCost
   680  	}
   681  	// Also, we need to sort the retrieved rows if the index lookup reader is expected to return
   682  	// ordered results. Note that the row counts of these two sorts can differ if there are
   683  	// operators above the causet scan.
   684  	blockRows := t.blockCauset.statsInfo().RowCount
   685  	selectivity := blockRows / indexRows
   686  	batchSize = math.Min(indexLookupSize*selectivity, blockRows)
   687  	if t.keepOrder && batchSize > 2 {
   688  		sortCPUCost := (blockRows * math.Log2(batchSize) * sessVars.CPUFactor) / numTblWorkers
   689  		newTask.cst += sortCPUCost
   690  	}
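        	// The double read may have produced an extra handle column; if so, add a projection on top
        	// of the reader to prune it, so the output schema matches the original data source schema.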
   691  	if t.doubleReadNeedProj {
   692  		schemaReplicant := p.IndexCausets[0].(*PhysicalIndexScan).dataSourceSchema
   693  		proj := PhysicalProjection{Exprs: memex.DeferredCauset2Exprs(schemaReplicant.DeferredCausets)}.Init(ctx, p.stats, t.blockCauset.SelectBlockOffset(), nil)
   694  		proj.SetSchema(schemaReplicant)
   695  		proj.SetChildren(p)
   696  		newTask.p = proj
   697  	} else {
   698  		newTask.p = p
   699  	}
   700  	return newTask
   701  }
   702  
   703  // finishCopTask closes the interlock task and creates a root task.
   704  func finishCopTask(ctx stochastikctx.Context, task task) task {
   705  	t, ok := task.(*copTask)
   706  	if !ok {
   707  		return task
   708  	}
   709  	sessVars := ctx.GetStochastikVars()
   710  	// CausetTasks run in parallel, so to make the estimated cost closer to the execution time, we amortize
   711  	// the cost over the cop iterator workers. According to `CopClient::Send`, the concurrency
   712  	// is Min(DistALLEGROSQLScanConcurrency, numRegionsInvolvedInScan); since we cannot infer
   713  	// the number of regions involved, we simply use DistALLEGROSQLScanConcurrency.
   714  	copIterWorkers := float64(t.plan().SCtx().GetStochastikVars().DistALLEGROSQLScanConcurrency())
   715  	t.finishIndexCauset()
   716  	// Network cost of transferring rows of causet scan to MilevaDB.
   717  	if t.blockCauset != nil {
   718  		t.cst += t.count() * sessVars.NetworkFactor * t.tblDefCausHists.GetAvgRowSize(ctx, t.blockCauset.Schema().DeferredCausets, false, false)
   719  
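        		// Walk down to the bottom-most PhysicalBlockScan; when a broadcast join has been pushed
        		// down, follow its probe (outer) side to find the scan whose columns need expanding.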
   720  		tp := t.blockCauset
   721  		for len(tp.Children()) > 0 {
   722  			if len(tp.Children()) == 1 {
   723  				tp = tp.Children()[0]
   724  			} else {
   725  				join := tp.(*PhysicalBroadCastJoin)
   726  				tp = join.children[1-join.InnerChildIdx]
   727  			}
   728  		}
   729  		ts := tp.(*PhysicalBlockScan)
   730  		ts.DeferredCausets = ExpandVirtualDeferredCauset(ts.DeferredCausets, ts.schemaReplicant, ts.Block.DeferredCausets)
   731  	}
   732  	t.cst /= copIterWorkers
   733  	newTask := &rootTask{
   734  		cst: t.cst,
   735  	}
   736  	if t.idxMergePartCausets != nil {
   737  		p := PhysicalIndexMergeReader{
   738  			partialCausets: t.idxMergePartCausets,
   739  			blockCauset:    t.blockCauset,
   740  		}.Init(ctx, t.idxMergePartCausets[0].SelectBlockOffset())
   741  		p.PartitionInfo = t.partitionInfo
   742  		setBlockScanToBlockRowIDScan(p.blockCauset)
   743  		newTask.p = p
   744  		return newTask
   745  	}
   746  	if t.indexCauset != nil && t.blockCauset != nil {
   747  		newTask = buildIndexLookUpTask(ctx, t)
   748  	} else if t.indexCauset != nil {
   749  		p := PhysicalIndexReader{indexCauset: t.indexCauset}.Init(ctx, t.indexCauset.SelectBlockOffset())
   750  		p.PartitionInfo = t.partitionInfo
   751  		p.stats = t.indexCauset.statsInfo()
   752  		newTask.p = p
   753  	} else {
   754  		tp := t.blockCauset
   755  		for len(tp.Children()) > 0 {
   756  			if len(tp.Children()) == 1 {
   757  				tp = tp.Children()[0]
   758  			} else {
   759  				join := tp.(*PhysicalBroadCastJoin)
   760  				tp = join.children[1-join.InnerChildIdx]
   761  			}
   762  		}
   763  		ts := tp.(*PhysicalBlockScan)
   764  		p := PhysicalBlockReader{
   765  			blockCauset:    t.blockCauset,
   766  			StoreType:      ts.StoreType,
   767  			IsCommonHandle: ts.Block.IsCommonHandle,
   768  		}.Init(ctx, t.blockCauset.SelectBlockOffset())
   769  		p.PartitionInfo = t.partitionInfo
   770  		p.stats = t.blockCauset.statsInfo()
   771  		newTask.p = p
   772  	}
   773  
   774  	if len(t.rootTaskConds) > 0 {
   775  		sel := PhysicalSelection{Conditions: t.rootTaskConds}.Init(ctx, newTask.p.statsInfo(), newTask.p.SelectBlockOffset())
   776  		sel.SetChildren(newTask.p)
   777  		newTask.p = sel
   778  	}
   779  
   780  	return newTask
   781  }
   782  
   783  // setBlockScanToBlockRowIDScan updates the isChildOfIndexLookUp attribute of the PhysicalBlockScan children.
   784  func setBlockScanToBlockRowIDScan(p PhysicalCauset) {
   785  	if ts, ok := p.(*PhysicalBlockScan); ok {
   786  		ts.SetIsChildOfIndexLookUp(true)
   787  	} else {
   788  		for _, child := range p.Children() {
   789  			setBlockScanToBlockRowIDScan(child)
   790  		}
   791  	}
   792  }
   793  
   794  // rootTask is the final sink node of a plan graph. It runs in a single goroutine on MilevaDB.
   795  type rootTask struct {
   796  	p   PhysicalCauset
   797  	cst float64
   798  }
   799  
   800  func (t *rootTask) copy() task {
   801  	return &rootTask{
   802  		p:   t.p,
   803  		cst: t.cst,
   804  	}
   805  }
   806  
   807  func (t *rootTask) count() float64 {
   808  	return t.p.statsInfo().RowCount
   809  }
   810  
   811  func (t *rootTask) addCost(cst float64) {
   812  	t.cst += cst
   813  }
   814  
   815  func (t *rootTask) cost() float64 {
   816  	return t.cst
   817  }
   818  
   819  func (t *rootTask) plan() PhysicalCauset {
   820  	return t.p
   821  }
   822  
   823  func (p *PhysicalLimit) attach2Task(tasks ...task) task {
   824  	t := tasks[0].copy()
   825  	sunk := false
   826  	if cop, ok := t.(*copTask); ok {
   827  		// For a double read that requires the order to be kept, the limit cannot be pushed down to the causet side,
   828  		// because the handles would be reordered before being sent to the causet scan.
   829  		if (!cop.keepOrder || !cop.indexCausetFinished || cop.indexCauset == nil) && len(cop.rootTaskConds) == 0 {
   830  			// When limit is pushed down, we should remove its offset.
   831  			newCount := p.Offset + p.Count
   832  			childProfile := cop.plan().statsInfo()
   833  			// Strictly speaking, for the row count of the stats, we should multiply newCount by "regionNum",
   834  			// but "regionNum" is unknown since the copTask can be a double read, so we ignore it for now.
   835  			stats := deriveLimitStats(childProfile, float64(newCount))
   836  			pushedDownLimit := PhysicalLimit{Count: newCount}.Init(p.ctx, stats, p.blockOffset)
   837  			cop = attachCauset2Task(pushedDownLimit, cop).(*copTask)
   838  		}
   839  		t = finishCopTask(p.ctx, cop)
   840  		sunk = p.sinkIntoIndexLookUp(t)
   841  	}
   842  	if sunk {
   843  		return t
   844  	}
   845  	return attachCauset2Task(p, t)
   846  }
   847  
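        // sinkIntoIndexLookUp tries to absorb the limit into a PhysicalIndexLookUpReader (possibly
        // below a projection) as a PushedDownLimit, so the reader fetches only Offset+Count rows.
        // It reports whether the limit was successfully sunk.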
   848  func (p *PhysicalLimit) sinkIntoIndexLookUp(t task) bool {
   849  	root := t.(*rootTask)
   850  	reader, isDoubleRead := root.p.(*PhysicalIndexLookUpReader)
   851  	proj, isProj := root.p.(*PhysicalProjection)
   852  	if !isDoubleRead && !isProj {
   853  		return false
   854  	}
   855  	if isProj {
   856  		reader, isDoubleRead = proj.Children()[0].(*PhysicalIndexLookUpReader)
   857  		if !isDoubleRead {
   858  			return false
   859  		}
   860  	}
   861  	// We can sink Limit into IndexLookUpReader only if blockCauset contains no Selection.
   862  	ts, isBlockScan := reader.blockCauset.(*PhysicalBlockScan)
   863  	if !isBlockScan {
   864  		return false
   865  	}
   866  	reader.PushedLimit = &PushedDownLimit{
   867  		Offset: p.Offset,
   868  		Count:  p.Count,
   869  	}
   870  	ts.stats = p.stats
   871  	reader.stats = p.stats
   872  	if isProj {
   873  		proj.stats = p.stats
   874  	}
   875  	return true
   876  }
   877  
   878  // GetCost computes cost of TopN operator itself.
   879  func (p *PhysicalTopN) GetCost(count float64, isRoot bool) float64 {
   880  	heapSize := float64(p.Offset + p.Count)
   881  	if heapSize < 2.0 {
   882  		heapSize = 2.0
   883  	}
   884  	sessVars := p.ctx.GetStochastikVars()
   885  	// Ignore the cost of `doCompaction` in the current implementation of `TopNInterDirc`, since it is a
   886  	// special side-effect of our Chunk format in the MilevaDB layer, which may not exist in the interlock's
   887  	// implementation, or may be removed in the future if we change the data format.
   888  	// Note that we use the worst-case complexity to compute the CPU cost, because it is simpler than
   889  	// considering the probabilities of the average complexity, i.e., we may not need to adjust the heap for
   890  	// each input row.
   891  	var cpuCost float64
   892  	if isRoot {
   893  		cpuCost = count * math.Log2(heapSize) * sessVars.CPUFactor
   894  	} else {
   895  		cpuCost = count * math.Log2(heapSize) * sessVars.CopCPUFactor
   896  	}
   897  	memoryCost := heapSize * sessVars.MemoryFactor
   898  	return cpuCost + memoryCost
   899  }
   900  
   901  // canPushDown checks if this topN can be pushed down. If each of the memexes can be converted to pb, it can be pushed.
   902  func (p *PhysicalTopN) canPushDown(cop *copTask) bool {
   903  	exprs := make([]memex.Expression, 0, len(p.ByItems))
   904  	for _, item := range p.ByItems {
   905  		exprs = append(exprs, item.Expr)
   906  	}
   907  	return memex.CanExprsPushDown(p.ctx.GetStochastikVars().StmtCtx, exprs, p.ctx.GetClient(), cop.getStoreType())
   908  }
   909  
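        // allDefCaussFromSchema reports whether every column referenced by the ByItems can be
        // resolved from the given schema, i.e. whether the TopN can be evaluated on that side alone.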
   910  func (p *PhysicalTopN) allDefCaussFromSchema(schemaReplicant *memex.Schema) bool {
   911  	defcaus := make([]*memex.DeferredCauset, 0, len(p.ByItems))
   912  	for _, item := range p.ByItems {
   913  		defcaus = append(defcaus, memex.ExtractDeferredCausets(item.Expr)...)
   914  	}
   915  	return len(schemaReplicant.DeferredCausetsIndices(defcaus)) > 0
   916  }
   917  
   918  // GetCost computes the cost of the in-memory sort.
   919  func (p *PhysicalSort) GetCost(count float64, schemaReplicant *memex.Schema) float64 {
   920  	if count < 2.0 {
   921  		count = 2.0
   922  	}
   923  	sessVars := p.ctx.GetStochastikVars()
   924  	cpuCost := count * math.Log2(count) * sessVars.CPUFactor
   925  	memoryCost := count * sessVars.MemoryFactor
   926  
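        	// If the sorted data is expected to exceed the memory quota and spilling to disk is enabled,
        	// charge disk I/O for the spilled rows and scale the in-memory cost down accordingly.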
   927  	oomUseTmpStorage := config.GetGlobalConfig().OOMUseTmpStorage
   928  	memQuota := sessVars.StmtCtx.MemTracker.GetBytesLimit() // sessVars.MemQuotaQuery && hint
   929  	rowSize := getAvgRowSize(p.statsInfo(), schemaReplicant)
   930  	spill := oomUseTmpStorage && memQuota > 0 && rowSize*count > float64(memQuota)
   931  	diskCost := count * sessVars.DiskFactor * rowSize
   932  	if !spill {
   933  		diskCost = 0
   934  	} else {
   935  		memoryCost *= float64(memQuota) / (rowSize * count)
   936  	}
   937  	return cpuCost + memoryCost + diskCost
   938  }
   939  
   940  func (p *PhysicalSort) attach2Task(tasks ...task) task {
   941  	t := tasks[0].copy()
   942  	t = attachCauset2Task(p, t)
   943  	t.addCost(p.GetCost(t.count(), p.Schema()))
   944  	return t
   945  }
   946  
   947  func (p *NominalSort) attach2Task(tasks ...task) task {
   948  	if p.OnlyDeferredCauset {
   949  		return tasks[0]
   950  	}
   951  	t := tasks[0].copy()
   952  	t = attachCauset2Task(p, t)
   953  	return t
   954  }
   955  
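        // getPushedDownTopN builds the TopN to be pushed below the cop task: the ByItems are cloned,
        // the offset is folded into the count (offset rows cannot be skipped on the storage side),
        // and stats are derived for the new, smaller row count.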
   956  func (p *PhysicalTopN) getPushedDownTopN(childCauset PhysicalCauset) *PhysicalTopN {
   957  	newByItems := make([]*soliton.ByItems, 0, len(p.ByItems))
   958  	for _, expr := range p.ByItems {
   959  		newByItems = append(newByItems, expr.Clone())
   960  	}
   961  	newCount := p.Offset + p.Count
   962  	childProfile := childCauset.statsInfo()
   963  	// Strictly speaking, for the row count of the pushed-down TopN, we should multiply newCount by "regionNum",
   964  	// but "regionNum" is unknown since the copTask can be a double read, so we ignore it for now.
   965  	stats := deriveLimitStats(childProfile, float64(newCount))
   966  	topN := PhysicalTopN{
   967  		ByItems: newByItems,
   968  		Count:   newCount,
   969  	}.Init(p.ctx, stats, p.blockOffset)
   970  	topN.SetChildren(childCauset)
   971  	return topN
   972  }
   973  
   974  func (p *PhysicalTopN) attach2Task(tasks ...task) task {
   975  	t := tasks[0].copy()
   976  	inputCount := t.count()
   977  	if copTask, ok := t.(*copTask); ok && p.canPushDown(copTask) && len(copTask.rootTaskConds) == 0 {
   978  		// If all columns in the TopN come from the index plan, we push it to the index plan; otherwise,
   979  		// we finish the index plan and push it to the causet plan.
   980  		var pushedDownTopN *PhysicalTopN
   981  		if !copTask.indexCausetFinished && p.allDefCaussFromSchema(copTask.indexCauset.Schema()) {
   982  			pushedDownTopN = p.getPushedDownTopN(copTask.indexCauset)
   983  			copTask.indexCauset = pushedDownTopN
   984  		} else {
   985  			copTask.finishIndexCauset()
   986  			pushedDownTopN = p.getPushedDownTopN(copTask.blockCauset)
   987  			copTask.blockCauset = pushedDownTopN
   988  		}
   989  		copTask.addCost(pushedDownTopN.GetCost(inputCount, false))
   990  	}
   991  	rootTask := finishCopTask(p.ctx, t)
   992  	rootTask.addCost(p.GetCost(rootTask.count(), true))
   993  	rootTask = attachCauset2Task(p, rootTask)
   994  	return rootTask
   995  }
   996  
   997  // GetCost computes the cost of projection operator itself.
   998  func (p *PhysicalProjection) GetCost(count float64) float64 {
   999  	sessVars := p.ctx.GetStochastikVars()
  1000  	cpuCost := count * sessVars.CPUFactor
  1001  	concurrency := float64(sessVars.ProjectionConcurrency())
  1002  	if concurrency <= 0 {
  1003  		return cpuCost
  1004  	}
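        	// Otherwise the projection runs with `concurrency` workers: the CPU work is amortized
        	// across them, and each worker adds a small goroutine scheduling overhead.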
  1005  	cpuCost /= concurrency
  1006  	concurrencyCost := (1 + concurrency) * sessVars.ConcurrencyFactor
  1007  	return cpuCost + concurrencyCost
  1008  }
  1009  
  1010  func (p *PhysicalProjection) attach2Task(tasks ...task) task {
  1011  	t := tasks[0].copy()
  1012  	if copTask, ok := t.(*copTask); ok {
  1013  		// TODO: support projection push down.
  1014  		t = finishCopTask(p.ctx, copTask)
  1015  	}
  1016  	t = attachCauset2Task(p, t)
  1017  	t.addCost(p.GetCost(t.count()))
  1018  	return t
  1019  }
  1020  
  1021  func (p *PhysicalUnionAll) attach2Task(tasks ...task) task {
  1022  	t := &rootTask{p: p}
  1023  	childCausets := make([]PhysicalCauset, 0, len(tasks))
  1024  	var childMaxCost float64
  1025  	for _, task := range tasks {
  1026  		task = finishCopTask(p.ctx, task)
  1027  		childCost := task.cost()
  1028  		if childCost > childMaxCost {
  1029  			childMaxCost = childCost
  1030  		}
  1031  		childCausets = append(childCausets, task.plan())
  1032  	}
  1033  	p.SetChildren(childCausets...)
  1034  	sessVars := p.ctx.GetStochastikVars()
  1035  	// Children of UnionInterDirc are executed in parallel.
  1036  	t.cst = childMaxCost + float64(1+len(tasks))*sessVars.ConcurrencyFactor
  1037  	return t
  1038  }
  1039  
  1040  func (sel *PhysicalSelection) attach2Task(tasks ...task) task {
  1041  	sessVars := sel.ctx.GetStochastikVars()
  1042  	t := finishCopTask(sel.ctx, tasks[0].copy())
  1043  	t.addCost(t.count() * sessVars.CPUFactor)
  1044  	t = attachCauset2Task(sel, t)
  1045  	return t
  1046  }
  1047  
  1048  // CheckAggCanPushCop checks whether the aggFuncs and groupByItems can
  1049  // be pushed down to interlock.
  1050  func CheckAggCanPushCop(sctx stochastikctx.Context, aggFuncs []*aggregation.AggFuncDesc, groupByItems []memex.Expression, storeType ekv.StoreType) bool {
  1051  	sc := sctx.GetStochastikVars().StmtCtx
  1052  	client := sctx.GetClient()
  1053  	for _, aggFunc := range aggFuncs {
  1054  		if memex.ContainVirtualDeferredCauset(aggFunc.Args) {
  1055  			return false
  1056  		}
  1057  		pb := aggregation.AggFuncToPBExpr(sc, client, aggFunc)
  1058  		if pb == nil {
  1059  			return false
  1060  		}
  1061  		if !aggregation.CheckAggPushDown(aggFunc, storeType) {
  1062  			return false
  1063  		}
  1064  		if !memex.CanExprsPushDown(sc, aggFunc.Args, client, storeType) {
  1065  			return false
  1066  		}
  1067  	}
  1068  	if memex.ContainVirtualDeferredCauset(groupByItems) {
  1069  		return false
  1070  	}
  1071  	return memex.CanExprsPushDown(sc, groupByItems, client, storeType)
  1072  }
  1073  
  1074  // AggInfo stores the information of an Aggregation.
  1075  type AggInfo struct {
  1076  	AggFuncs     []*aggregation.AggFuncDesc
  1077  	GroupByItems []memex.Expression
  1078  	Schema       *memex.Schema
  1079  }
  1080  
  1081  // BuildFinalModeAggregation splits either a LogicalAggregation or a PhysicalAggregation into finalAgg and partial1Agg,
  1082  // and returns the information of the partial and final aggregations.
  1083  // partialIsCop indicates whether the partial agg is a cop task.
  1084  func BuildFinalModeAggregation(
  1085  	sctx stochastikctx.Context, original *AggInfo, partialIsCop bool) (partial, final *AggInfo, funcMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) {
  1086  
  1087  	funcMap = make(map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc, len(original.AggFuncs))
  1088  	partial = &AggInfo{
  1089  		AggFuncs:     make([]*aggregation.AggFuncDesc, 0, len(original.AggFuncs)),
  1090  		GroupByItems: original.GroupByItems,
  1091  		Schema:       memex.NewSchema(),
  1092  	}
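        	// partialCursor tracks how many output columns the partial aggregation has produced so far;
        	// the final aggregate functions reference these partial output columns by position.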
  1093  	partialCursor := 0
  1094  	final = &AggInfo{
  1095  		AggFuncs:     make([]*aggregation.AggFuncDesc, len(original.AggFuncs)),
  1096  		GroupByItems: make([]memex.Expression, 0, len(original.GroupByItems)),
  1097  		Schema:       original.Schema,
  1098  	}
  1099  
  1100  	partialGbySchema := memex.NewSchema()
  1101  	// add group by columns
  1102  	for _, gbyExpr := range partial.GroupByItems {
  1103  		var gbyDefCaus *memex.DeferredCauset
  1104  		if col, ok := gbyExpr.(*memex.DeferredCauset); ok {
  1105  			gbyDefCaus = col
  1106  		} else {
  1107  			gbyDefCaus = &memex.DeferredCauset{
  1108  				UniqueID: sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
  1109  				RetType:  gbyExpr.GetType(),
  1110  			}
  1111  		}
  1112  		partialGbySchema.Append(gbyDefCaus)
  1113  		final.GroupByItems = append(final.GroupByItems, gbyDefCaus)
  1114  	}
  1115  
  1116  	// TODO: Refactor the way of constructing aggregation functions.
  1117  	// This for loop is ugly, but I have not found a proper way to reconstruct
  1118  	// it right away.
  1119  	for i, aggFunc := range original.AggFuncs {
  1120  		finalAggFunc := &aggregation.AggFuncDesc{HasDistinct: false}
  1121  		finalAggFunc.Name = aggFunc.Name
  1122  		args := make([]memex.Expression, 0, len(aggFunc.Args))
  1123  		if aggFunc.HasDistinct {
  1124  			/*
  1125  				eg: SELECT COUNT(DISTINCT a), SUM(b) FROM t GROUP BY c
  1126  
  1127  				change from
  1128  					[root] group by: c, funcs:count(distinct a), funcs:sum(b)
  1129  				to
  1130  					[root] group by: c, funcs:count(distinct a), funcs:sum(b)
  1131  						[cop]: group by: c, a
  1132  			*/
  1133  			for _, distinctArg := range aggFunc.Args {
  1134  				// 1. add all args to partial.GroupByItems
  1135  				foundInGroupBy := false
  1136  				for j, gbyExpr := range partial.GroupByItems {
  1137  					if gbyExpr.Equal(sctx, distinctArg) {
  1138  						foundInGroupBy = true
  1139  						args = append(args, partialGbySchema.DeferredCausets[j])
  1140  						break
  1141  					}
  1142  				}
  1143  				if !foundInGroupBy {
  1144  					partial.GroupByItems = append(partial.GroupByItems, distinctArg)
  1145  					var gbyDefCaus *memex.DeferredCauset
  1146  					if col, ok := distinctArg.(*memex.DeferredCauset); ok {
  1147  						gbyDefCaus = col
  1148  					} else {
  1149  						gbyDefCaus = &memex.DeferredCauset{
  1150  							UniqueID: sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
  1151  							RetType:  distinctArg.GetType(),
  1152  						}
  1153  					}
  1154  					partialGbySchema.Append(gbyDefCaus)
  1155  					if !partialIsCop {
  1156  						// If partial is a cop task, the firstrow function is redundant since the group by items are output
  1157  						// by the group by schemaReplicant, and the final functions use the group by schemaReplicant as their arguments.
  1158  						// If the partial agg is not a cop task, we must append the firstrow function & schemaReplicant to output the
  1159  						// group by items.
  1160  						// Maybe we can unify them sometime.
  1161  						firstRow, err := aggregation.NewAggFuncDesc(sctx, ast.AggFuncFirstRow, []memex.Expression{gbyDefCaus}, false)
  1162  						if err != nil {
  1163  							panic("NewAggFuncDesc FirstRow meets error: " + err.Error())
  1164  						}
  1165  						partial.AggFuncs = append(partial.AggFuncs, firstRow)
  1166  						newDefCaus, _ := gbyDefCaus.Clone().(*memex.DeferredCauset)
  1167  						newDefCaus.RetType = firstRow.RetTp
  1168  						partial.Schema.Append(newDefCaus)
  1169  						partialCursor++
  1170  					}
  1171  					args = append(args, gbyDefCaus)
  1172  				}
  1173  			}
  1174  
  1175  			finalAggFunc.HasDistinct = true
  1176  			finalAggFunc.Mode = aggregation.CompleteMode
  1177  		} else {
  1178  			if aggregation.NeedCount(finalAggFunc.Name) {
  1179  				ft := types.NewFieldType(allegrosql.TypeLonglong)
  1180  				ft.Flen, ft.Charset, ft.DefCauslate = 21, charset.CharsetBin, charset.DefCauslationBin
  1181  				partial.Schema.Append(&memex.DeferredCauset{
  1182  					UniqueID: sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
  1183  					RetType:  ft,
  1184  				})
  1185  				args = append(args, partial.Schema.DeferredCausets[partialCursor])
  1186  				partialCursor++
  1187  			}
  1188  			if finalAggFunc.Name == ast.AggFuncApproxCountDistinct {
  1189  				ft := types.NewFieldType(allegrosql.TypeString)
  1190  				ft.Charset, ft.DefCauslate = charset.CharsetBin, charset.DefCauslationBin
  1191  				ft.Flag |= allegrosql.NotNullFlag
  1192  				partial.Schema.Append(&memex.DeferredCauset{
  1193  					UniqueID: sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
  1194  					RetType:  ft,
  1195  				})
  1196  				args = append(args, partial.Schema.DeferredCausets[partialCursor])
  1197  				partialCursor++
  1198  			}
  1199  			if aggregation.NeedValue(finalAggFunc.Name) {
  1200  				partial.Schema.Append(&memex.DeferredCauset{
  1201  					UniqueID: sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
  1202  					RetType:  original.Schema.DeferredCausets[i].GetType(),
  1203  				})
  1204  				args = append(args, partial.Schema.DeferredCausets[partialCursor])
  1205  				partialCursor++
  1206  			}
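        			// AVG cannot be merged from partial results directly, so the partial aggregation computes
        			// COUNT and SUM instead; the final AVG is then evaluated from those two partial columns.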
  1207  			if aggFunc.Name == ast.AggFuncAvg {
  1208  				cntAgg := *aggFunc
  1209  				cntAgg.Name = ast.AggFuncCount
  1210  				cntAgg.RetTp = partial.Schema.DeferredCausets[partialCursor-2].GetType()
  1211  				cntAgg.RetTp.Flag = aggFunc.RetTp.Flag
  1212  				sumAgg := *aggFunc
  1213  				sumAgg.Name = ast.AggFuncSum
  1214  				sumAgg.RetTp = partial.Schema.DeferredCausets[partialCursor-1].GetType()
  1215  				partial.AggFuncs = append(partial.AggFuncs, &cntAgg, &sumAgg)
  1216  			} else if aggFunc.Name == ast.AggFuncApproxCountDistinct {
  1217  				approxCountDistinctAgg := *aggFunc
  1218  				approxCountDistinctAgg.Name = ast.AggFuncApproxCountDistinct
  1219  				approxCountDistinctAgg.RetTp = partial.Schema.DeferredCausets[partialCursor-1].GetType()
  1220  				partial.AggFuncs = append(partial.AggFuncs, &approxCountDistinctAgg)
  1221  			} else {
  1222  				partial.AggFuncs = append(partial.AggFuncs, aggFunc)
  1223  			}
  1224  
  1225  			finalAggFunc.Mode = aggregation.FinalMode
  1226  			funcMap[aggFunc] = finalAggFunc
  1227  		}
  1228  
  1229  		finalAggFunc.Args = args
  1230  		finalAggFunc.RetTp = aggFunc.RetTp
  1231  		final.AggFuncs[i] = finalAggFunc
  1232  	}
  1233  	partial.Schema.Append(partialGbySchema.DeferredCausets...)
  1234  	return
  1235  }
  1236  
  1237  func (p *basePhysicalAgg) newPartialAggregate(copTaskType ekv.StoreType) (partial, final PhysicalCauset) {
  1238  	// Check if this aggregation can push down.
  1239  	if !CheckAggCanPushCop(p.ctx, p.AggFuncs, p.GroupByItems, copTaskType) {
  1240  		return nil, p.self
  1241  	}
  1242  	partialPref, finalPref, funcMap := BuildFinalModeAggregation(p.ctx, &AggInfo{
  1243  		AggFuncs:     p.AggFuncs,
  1244  		GroupByItems: p.GroupByItems,
  1245  		Schema:       p.Schema().Clone(),
  1246  	}, true)
  1247  	if p.tp == plancodec.TypeStreamAgg && len(partialPref.GroupByItems) != len(finalPref.GroupByItems) {
  1248  		return nil, p.self
  1249  	}
  1250  	// Remove unnecessary FirstRow.
  1251  	partialPref.AggFuncs = RemoveUnnecessaryFirstRow(p.ctx,
  1252  		finalPref.AggFuncs, finalPref.GroupByItems,
  1253  		partialPref.AggFuncs, partialPref.GroupByItems, partialPref.Schema, funcMap)
  1254  	if copTaskType == ekv.MilevaDB {
  1255  		// For the partial agg of a MilevaDB cop task, since the MilevaDB interlock reuses the MilevaDB interlock
  1256  		// and the MilevaDB aggregation interlock won't output the group by value,
  1257  		// we need to add a `firstrow` aggregation function to output the group by value.
  1258  		aggFuncs, err := genFirstRowAggForGroupBy(p.ctx, partialPref.GroupByItems)
  1259  		if err != nil {
  1260  			return nil, p.self
  1261  		}
  1262  		partialPref.AggFuncs = append(partialPref.AggFuncs, aggFuncs...)
  1263  	}
  1264  	p.AggFuncs = partialPref.AggFuncs
  1265  	p.GroupByItems = partialPref.GroupByItems
  1266  	p.schemaReplicant = partialPref.Schema
  1267  	partialAgg := p.self
  1268  	// Create physical "final" aggregation.
  1269  	prop := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64}
  1270  	if p.tp == plancodec.TypeStreamAgg {
  1271  		finalAgg := basePhysicalAgg{
  1272  			AggFuncs:     finalPref.AggFuncs,
  1273  			GroupByItems: finalPref.GroupByItems,
  1274  		}.initForStream(p.ctx, p.stats, p.blockOffset, prop)
  1275  		finalAgg.schemaReplicant = finalPref.Schema
  1276  		return partialAgg, finalAgg
  1277  	}
  1278  
  1279  	finalAgg := basePhysicalAgg{
  1280  		AggFuncs:     finalPref.AggFuncs,
  1281  		GroupByItems: finalPref.GroupByItems,
  1282  	}.initForHash(p.ctx, p.stats, p.blockOffset, prop)
  1283  	finalAgg.schemaReplicant = finalPref.Schema
  1284  	return partialAgg, finalAgg
  1285  }
  1286  
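// genFirstRowAggForGroupBy builds a `firstrow` aggregation function for every group-by item,
// so that the pushed-down partial aggregation also outputs the group-by values.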
  1287  func genFirstRowAggForGroupBy(ctx stochastikctx.Context, groupByItems []memex.Expression) ([]*aggregation.AggFuncDesc, error) {
  1288  	aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems))
  1289  	for _, groupBy := range groupByItems {
  1290  		agg, err := aggregation.NewAggFuncDesc(ctx, ast.AggFuncFirstRow, []memex.Expression{groupBy}, false)
  1291  		if err != nil {
  1292  			return nil, err
  1293  		}
  1294  		aggFuncs = append(aggFuncs, agg)
  1295  	}
  1296  	return aggFuncs, nil
  1297  }
  1298  
  1299  // RemoveUnnecessaryFirstRow removes unnecessary FirstRow of the aggregation. This function can be
  1300  // used for both LogicalAggregation and PhysicalAggregation.
1301  // When the selected column is the same as the group-by key, the firstrow column can be removed
1302  // and its value taken from the group-by key. e.g.
1303  // select a, count(b) from t group by a;
1304  // The schemaReplicant is [firstrow(a), count(b), a]; the column firstrow(a) is unnecessary,
1305  // so we can optimize the schemaReplicant to [count(b), a] and change the index used to get the value.
  1306  func RemoveUnnecessaryFirstRow(
  1307  	sctx stochastikctx.Context,
  1308  	finalAggFuncs []*aggregation.AggFuncDesc,
  1309  	finalGbyItems []memex.Expression,
  1310  	partialAggFuncs []*aggregation.AggFuncDesc,
  1311  	partialGbyItems []memex.Expression,
  1312  	partialSchema *memex.Schema,
  1313  	funcMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) []*aggregation.AggFuncDesc {
  1314  
  1315  	partialCursor := 0
  1316  	newAggFuncs := make([]*aggregation.AggFuncDesc, 0, len(partialAggFuncs))
  1317  	for _, aggFunc := range partialAggFuncs {
  1318  		if aggFunc.Name == ast.AggFuncFirstRow {
  1319  			canOptimize := false
  1320  			for j, gbyExpr := range partialGbyItems {
  1321  				if j >= len(finalGbyItems) {
1322  				// After distinct push-down, len(partialGbyItems) may be larger than len(finalGbyItems).
1323  				// For example,
1324  				// select /*+ HASH_AGG() */ a, count(distinct a) from t;
1325  				// generates:
1326  				//   HashAgg root  funcs:count(distinct a), funcs:firstrow(a)
1327  				//     HashAgg cop  group by:a, funcs:firstrow(a)->DeferredCauset#6
1328  				// The firstrow in the root task cannot be removed.
  1329  					break
  1330  				}
  1331  				if gbyExpr.Equal(sctx, aggFunc.Args[0]) {
  1332  					canOptimize = true
  1333  					funcMap[aggFunc].Args[0] = finalGbyItems[j]
  1334  					break
  1335  				}
  1336  			}
  1337  			if canOptimize {
  1338  				partialSchema.DeferredCausets = append(partialSchema.DeferredCausets[:partialCursor], partialSchema.DeferredCausets[partialCursor+1:]...)
  1339  				continue
  1340  			}
  1341  		}
  1342  		partialCursor += computePartialCursorOffset(aggFunc.Name)
  1343  		newAggFuncs = append(newAggFuncs, aggFunc)
  1344  	}
  1345  	return newAggFuncs
  1346  }
  1347  
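// computePartialCursorOffset returns how many columns of the partial schema the given
// aggregate function occupies: one for its count part, one for its value part, and one more
// for the intermediate state of approx_count_distinct.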
  1348  func computePartialCursorOffset(name string) int {
  1349  	offset := 0
  1350  	if aggregation.NeedCount(name) {
  1351  		offset++
  1352  	}
  1353  	if aggregation.NeedValue(name) {
  1354  		offset++
  1355  	}
  1356  	if name == ast.AggFuncApproxCountDistinct {
  1357  		offset++
  1358  	}
  1359  	return offset
  1360  }
  1361  
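// attach2Task attaches the stream aggregation to the given task. For a cop task, it tries to
// push a partial aggregation down to the causetstore side and keeps the final aggregation in
// the root task; when the cop task is a double read that must keep order, or carries
// root-only conditions, the aggregation is not pushed down at all.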
  1362  func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task {
  1363  	t := tasks[0].copy()
  1364  	inputRows := t.count()
  1365  	if cop, ok := t.(*copTask); ok {
1366  		// We should not push the agg down across a double read, since the data of the second read is ordered by handle instead of by index.
1367  		// The `extraHandleDefCaus` is added if the double read needs to keep order, so we use it to decide
1368  		// whether the following plan is a double read with order preserved.
  1369  		if cop.extraHandleDefCaus != nil || len(cop.rootTaskConds) > 0 {
  1370  			t = finishCopTask(p.ctx, cop)
  1371  			inputRows = t.count()
  1372  			attachCauset2Task(p, t)
  1373  		} else {
  1374  			copTaskType := cop.getStoreType()
  1375  			partialAgg, finalAgg := p.newPartialAggregate(copTaskType)
  1376  			if partialAgg != nil {
  1377  				if cop.blockCauset != nil {
  1378  					cop.finishIndexCauset()
  1379  					partialAgg.SetChildren(cop.blockCauset)
  1380  					cop.blockCauset = partialAgg
  1381  				} else {
  1382  					partialAgg.SetChildren(cop.indexCauset)
  1383  					cop.indexCauset = partialAgg
  1384  				}
  1385  				cop.addCost(p.GetCost(inputRows, false))
  1386  			}
  1387  			t = finishCopTask(p.ctx, cop)
  1388  			inputRows = t.count()
  1389  			attachCauset2Task(finalAgg, t)
  1390  		}
  1391  	} else {
  1392  		attachCauset2Task(p, t)
  1393  	}
  1394  	t.addCost(p.GetCost(inputRows, true))
  1395  	return t
  1396  }
  1397  
  1398  // GetCost computes cost of stream aggregation considering CPU/memory.
1399  // GetCost computes the cost of stream aggregation, considering CPU and memory.
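// Roughly, the cost computed below is:
//   cpuCost    = inputRows * CPUFactor * aggFuncFactor            (CopCPUFactor for cop tasks)
//   memoryCost = (inputRows / outputRowCount) * distinctFactor * MemoryFactor * numDistinctFunc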
  1400  	aggFuncFactor := p.getAggFuncCostFactor()
  1401  	var cpuCost float64
  1402  	sessVars := p.ctx.GetStochastikVars()
  1403  	if isRoot {
  1404  		cpuCost = inputRows * sessVars.CPUFactor * aggFuncFactor
  1405  	} else {
  1406  		cpuCost = inputRows * sessVars.CopCPUFactor * aggFuncFactor
  1407  	}
  1408  	rowsPerGroup := inputRows / p.statsInfo().RowCount
  1409  	memoryCost := rowsPerGroup * distinctFactor * sessVars.MemoryFactor * float64(p.numDistinctFunc())
  1410  	return cpuCost + memoryCost
  1411  }
  1412  
  1413  // cpuCostDivisor computes the concurrency to which we would amortize CPU cost
  1414  // for hash aggregation.
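// It returns (0, 0) when a distinct aggregate is present or when both concurrencies are 1;
// otherwise it returns the smaller of the final and partial concurrencies as the divisor,
// together with the total number of goroutines (finalCon + partialCon) used to charge for
// the extra goroutine startup cost in GetCost.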
  1415  func (p *PhysicalHashAgg) cpuCostDivisor(hasDistinct bool) (float64, float64) {
  1416  	if hasDistinct {
  1417  		return 0, 0
  1418  	}
  1419  	stochastikVars := p.ctx.GetStochastikVars()
  1420  	finalCon, partialCon := stochastikVars.HashAggFinalConcurrency(), stochastikVars.HashAggPartialConcurrency()
  1421  	// According to `ValidateSetSystemVar`, `finalCon` and `partialCon` cannot be less than or equal to 0.
  1422  	if finalCon == 1 && partialCon == 1 {
  1423  		return 0, 0
  1424  	}
1425  	// It is tricky to decide which concurrency we should use to amortize the CPU cost. Since the cost of hash
1426  	// aggregation tends to be under-estimated, as explained in `attach2Task`, we choose the smaller
1427  	// concurrency to compensate a little.
  1428  	return math.Min(float64(finalCon), float64(partialCon)), float64(finalCon + partialCon)
  1429  }
  1430  
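// attach2Task attaches the hash aggregation to the given task. For a cop task without
// root-only conditions, it pushes a partial aggregation down to the causetstore side and
// attaches the final aggregation after the cop task is finished; otherwise the whole
// aggregation runs in the root task.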
  1431  func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
  1432  	t := tasks[0].copy()
  1433  	inputRows := t.count()
  1434  	if cop, ok := t.(*copTask); ok {
  1435  		if len(cop.rootTaskConds) == 0 {
  1436  			copTaskType := cop.getStoreType()
  1437  			partialAgg, finalAgg := p.newPartialAggregate(copTaskType)
  1438  			if partialAgg != nil {
  1439  				if cop.blockCauset != nil {
  1440  					cop.finishIndexCauset()
  1441  					partialAgg.SetChildren(cop.blockCauset)
  1442  					cop.blockCauset = partialAgg
  1443  				} else {
  1444  					partialAgg.SetChildren(cop.indexCauset)
  1445  					cop.indexCauset = partialAgg
  1446  				}
  1447  				cop.addCost(p.GetCost(inputRows, false))
  1448  			}
1449  			// In `newPartialAggregate`, we use the stats of the final aggregation as the stats
1450  			// of `partialAgg`, so the network cost of transferring the result rows of `partialAgg`
1451  			// to MilevaDB is normally under-estimated for hash aggregation, since the group-by
1452  			// column may be independent of the column used for region distribution. A closer
1453  			// estimation of the network cost for hash aggregation would factor in the number of
1454  			// regions involved in `partialAgg`, which is, however, unknown.
  1455  			t = finishCopTask(p.ctx, cop)
  1456  			inputRows = t.count()
  1457  			attachCauset2Task(finalAgg, t)
  1458  		} else {
  1459  			t = finishCopTask(p.ctx, cop)
  1460  			inputRows = t.count()
  1461  			attachCauset2Task(p, t)
  1462  		}
  1463  	} else {
  1464  		attachCauset2Task(p, t)
  1465  	}
1466  	// Strictly speaking, we may actually have 3-phase hash aggregation, so we should
1467  	// calculate the cost of each phase and sum the results. In practice we don't have
1468  	// region-level causet stats, and the concurrency of `partialAgg`,
1469  	// i.e. max(number_of_regions, DistALLEGROSQLScanConcurrency), is unknown either, so it is hard
1470  	// to compute the costs separately. We ignore region-level parallelism for both hash
1471  	// aggregation and stream aggregation when calculating cost. Although this leads to inaccuracy,
1472  	// the inaccuracy is imposed on both aggregation implementations,
1473  	// so they remain comparable horizontally.
1474  	// Also, we use the stats of `partialAgg` as the input for cost computation of the MilevaDB-layer
1475  	// hash aggregation, which causes under-estimation for the reason mentioned in the comment above.
1476  	// To keep it simple, we also treat 2-phase parallel hash aggregation in the MilevaDB layer as
1477  	// 1-phase when computing cost.
  1478  	t.addCost(p.GetCost(inputRows, true))
  1479  	return t
  1480  }
  1481  
  1482  // GetCost computes the cost of hash aggregation considering CPU/memory.
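// Roughly, the cost computed below is:
//   cpuCost    = inputRows * CPUFactor * aggFuncFactor, amortized over the concurrency from
//                cpuCostDivisor plus goroutine startup cost (root), or
//                inputRows * CopCPUFactor * aggFuncFactor (cop)
//   memoryCost = outputRowCount * MemoryFactor * len(AggFuncs)
//              + inputRows * distinctFactor * MemoryFactor * numDistinctFunc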
  1483  func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool) float64 {
  1484  	cardinality := p.statsInfo().RowCount
  1485  	numDistinctFunc := p.numDistinctFunc()
  1486  	aggFuncFactor := p.getAggFuncCostFactor()
  1487  	var cpuCost float64
  1488  	sessVars := p.ctx.GetStochastikVars()
  1489  	if isRoot {
  1490  		cpuCost = inputRows * sessVars.CPUFactor * aggFuncFactor
  1491  		divisor, con := p.cpuCostDivisor(numDistinctFunc > 0)
  1492  		if divisor > 0 {
  1493  			cpuCost /= divisor
  1494  			// Cost of additional goroutines.
  1495  			cpuCost += (con + 1) * sessVars.ConcurrencyFactor
  1496  		}
  1497  	} else {
  1498  		cpuCost = inputRows * sessVars.CopCPUFactor * aggFuncFactor
  1499  	}
  1500  	memoryCost := cardinality * sessVars.MemoryFactor * float64(len(p.AggFuncs))
1501  	// When the aggregation has the distinct flag, we allocate a map for each group to
1502  	// check for duplicates.
  1503  	memoryCost += inputRows * distinctFactor * sessVars.MemoryFactor * float64(numDistinctFunc)
  1504  	return cpuCost + memoryCost
  1505  }