github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/dispatch/dispatch.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dispatch
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/logutil"
    22  
    23  	"github.com/google/uuid"
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    26  	"github.com/matrixorigin/matrixone/pkg/container/types"
    27  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    28  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    29  	"github.com/matrixorigin/matrixone/pkg/vm"
    30  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    31  )
    32  
    33  const argName = "dispatch"
    34  
    35  func (arg *Argument) String(buf *bytes.Buffer) {
    36  	buf.WriteString(argName)
    37  	buf.WriteString(": dispatch")
    38  }
    39  
    40  func (arg *Argument) Prepare(proc *process.Process) error {
    41  	ap := arg
    42  	ctr := new(container)
    43  	ap.ctr = ctr
    44  	ctr.localRegsCnt = len(ap.LocalRegs)
    45  	ctr.remoteRegsCnt = len(ap.RemoteRegs)
    46  	ctr.aliveRegCnt = ctr.localRegsCnt + ctr.remoteRegsCnt
    47  
    48  	switch ap.FuncId {
    49  	case SendToAllFunc:
    50  		if ctr.remoteRegsCnt == 0 {
    51  			return moerr.NewInternalError(proc.Ctx, "SendToAllFunc should include RemoteRegs")
    52  		}
    53  		if len(ap.LocalRegs) == 0 {
    54  			ctr.sendFunc = sendToAllRemoteFunc
    55  		} else {
    56  			ctr.sendFunc = sendToAllFunc
    57  		}
    58  		return ap.prepareRemote(proc)
    59  
    60  	case ShuffleToAllFunc:
    61  		ap.ctr.sendFunc = shuffleToAllFunc
    62  		if ap.ctr.remoteRegsCnt > 0 {
    63  			if err := ap.prepareRemote(proc); err != nil {
    64  				return err
    65  			}
    66  		} else {
    67  			ap.prepareLocal()
    68  		}
    69  		ap.ctr.batchCnt = make([]int, ctr.aliveRegCnt)
    70  		ap.ctr.rowCnt = make([]int, ctr.aliveRegCnt)
    71  
    72  	case SendToAnyFunc:
    73  		if ctr.remoteRegsCnt == 0 {
    74  			return moerr.NewInternalError(proc.Ctx, "SendToAnyFunc should include RemoteRegs")
    75  		}
    76  		if len(ap.LocalRegs) == 0 {
    77  			ctr.sendFunc = sendToAnyRemoteFunc
    78  		} else {
    79  			ctr.sendFunc = sendToAnyFunc
    80  		}
    81  		return ap.prepareRemote(proc)
    82  
    83  	case SendToAllLocalFunc:
    84  		if ctr.remoteRegsCnt != 0 {
    85  			return moerr.NewInternalError(proc.Ctx, "SendToAllLocalFunc should not send to remote")
    86  		}
    87  		ctr.sendFunc = sendToAllLocalFunc
    88  		ap.prepareLocal()
    89  
    90  	case SendToAnyLocalFunc:
    91  		if ctr.remoteRegsCnt != 0 {
    92  			return moerr.NewInternalError(proc.Ctx, "SendToAnyLocalFunc should not send to remote")
    93  		}
    94  		ap.ctr.sendFunc = sendToAnyLocalFunc
    95  		ap.prepareLocal()
    96  
    97  	default:
    98  		return moerr.NewInternalError(proc.Ctx, "wrong sendFunc id for dispatch")
    99  	}
   100  
   101  	return nil
   102  }
   103  
   104  func printShuffleResult(arg *Argument) {
   105  	if arg.ctr.batchCnt != nil && arg.ctr.rowCnt != nil {
   106  		logutil.Debugf("shuffle type %v,  dispatch result: batchcnt %v, rowcnt %v", arg.ShuffleType, arg.ctr.batchCnt, arg.ctr.rowCnt)
   107  	}
   108  }
   109  
   110  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
   111  	if err, isCancel := vm.CancelCheck(proc); isCancel {
   112  		return vm.CancelResult, err
   113  	}
   114  
   115  	ap := arg
   116  
   117  	result, err := arg.Children[0].Call(proc)
   118  	if err != nil {
   119  		return result, err
   120  	}
   121  
   122  	bat := result.Batch
   123  
   124  	if result.Batch == nil {
   125  		if ap.RecSink {
   126  			bat, err = makeEndBatch(proc)
   127  			if err != nil {
   128  				return result, err
   129  			}
   130  			defer func() {
   131  				if bat != nil {
   132  					proc.PutBatch(bat)
   133  				}
   134  			}()
   135  		} else {
   136  			printShuffleResult(ap)
   137  			result.Status = vm.ExecStop
   138  			return result, nil
   139  		}
   140  	}
   141  
   142  	if bat.Last() {
   143  		if !ap.ctr.hasData {
   144  			bat.SetEnd()
   145  		} else {
   146  			ap.ctr.hasData = false
   147  		}
   148  	} else if bat.IsEmpty() {
   149  		result.Batch = batch.EmptyBatch
   150  		return result, nil
   151  	} else {
   152  		ap.ctr.hasData = true
   153  	}
   154  	bat.AddCnt(1)
   155  	ok, err := ap.ctr.sendFunc(bat, ap, proc)
   156  	if ok {
   157  		result.Status = vm.ExecStop
   158  		return result, err
   159  	} else {
   160  		// result.Batch = nil
   161  		return result, err
   162  	}
   163  }
   164  
   165  func makeEndBatch(proc *process.Process) (*batch.Batch, error) {
   166  	b := batch.NewWithSize(1)
   167  	b.Attrs = []string{
   168  		"recursive_col",
   169  	}
   170  	b.SetVector(0, proc.GetVector(types.T_varchar.ToType()))
   171  	err := vector.AppendBytes(b.GetVector(0), []byte("check recursive status"), false, proc.GetMPool())
   172  	if err == nil {
   173  		batch.SetLength(b, 1)
   174  		b.SetEnd()
   175  	}
   176  	return b, err
   177  }
   178  
   179  func (arg *Argument) waitRemoteRegsReady(proc *process.Process) (bool, error) {
   180  	cnt := len(arg.RemoteRegs)
   181  	for cnt > 0 {
   182  		timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), waitNotifyTimeout)
   183  		select {
   184  		case <-timeoutCtx.Done():
   185  			timeoutCancel()
   186  			return false, moerr.NewInternalErrorNoCtx("wait notify message timeout")
   187  
   188  		case <-proc.Ctx.Done():
   189  			timeoutCancel()
   190  			arg.ctr.prepared = true
   191  			return true, nil
   192  
   193  		case csinfo := <-proc.DispatchNotifyCh:
   194  			timeoutCancel()
   195  			arg.ctr.remoteReceivers = append(arg.ctr.remoteReceivers, csinfo)
   196  			cnt--
   197  		}
   198  	}
   199  	arg.ctr.prepared = true
   200  	return false, nil
   201  }
   202  
   203  func (arg *Argument) prepareRemote(proc *process.Process) error {
   204  	arg.ctr.prepared = false
   205  	arg.ctr.isRemote = true
   206  	arg.ctr.remoteReceivers = make([]process.WrapCs, 0, arg.ctr.remoteRegsCnt)
   207  	arg.ctr.remoteToIdx = make(map[uuid.UUID]int)
   208  	for i, rr := range arg.RemoteRegs {
   209  		if arg.FuncId == ShuffleToAllFunc {
   210  			arg.ctr.remoteToIdx[rr.Uuid] = arg.ShuffleRegIdxRemote[i]
   211  		}
   212  		if err := colexec.Get().PutProcIntoUuidMap(rr.Uuid, proc); err != nil {
   213  			return err
   214  		}
   215  	}
   216  	return nil
   217  }
   218  
   219  func (arg *Argument) prepareLocal() {
   220  	arg.ctr.prepared = true
   221  	arg.ctr.isRemote = false
   222  	arg.ctr.remoteReceivers = nil
   223  }