github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/dispatch/sendfunc.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dispatch
    16  
    17  import (
    18  	"context"
    19  	"hash/crc32"
    20  	"sync/atomic"
    21  
    22  	plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/cnservice/cnclient"
    25  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    26  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    27  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    28  	"github.com/matrixorigin/matrixone/pkg/container/types"
    29  	"github.com/matrixorigin/matrixone/pkg/pb/pipeline"
    30  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    31  )
    32  
    33  // common sender: send to all LocalReceiver
    34  func sendToAllLocalFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
    35  	var refCountAdd int64
    36  	var err error
    37  	if !ap.RecSink {
    38  		refCountAdd = int64(ap.ctr.localRegsCnt - 1)
    39  		atomic.AddInt64(&bat.Cnt, refCountAdd)
    40  		if jm, ok := bat.AuxData.(*hashmap.JoinMap); ok {
    41  			jm.IncRef(refCountAdd)
    42  			jm.SetDupCount(int64(ap.ctr.localRegsCnt))
    43  		}
    44  	}
    45  	var bats []*batch.Batch
    46  	if ap.RecSink {
    47  		bats = append(bats, bat)
    48  		for k := 1; k < len(ap.LocalRegs); k++ {
    49  			bat, err = bat.Dup(proc.Mp())
    50  			if err != nil {
    51  				return false, err
    52  			}
    53  			bats = append(bats, bat)
    54  		}
    55  	}
    56  
    57  	for i, reg := range ap.LocalRegs {
    58  		if ap.RecSink {
    59  			bat = bats[i]
    60  		}
    61  		select {
    62  		case <-proc.Ctx.Done():
    63  			handleUnsent(proc, bat, refCountAdd, int64(i))
    64  			return true, nil
    65  
    66  		case <-reg.Ctx.Done():
    67  			if ap.IsSink {
    68  				atomic.AddInt64(&bat.Cnt, -1)
    69  				continue
    70  			}
    71  			handleUnsent(proc, bat, refCountAdd, int64(i))
    72  			return true, nil
    73  
    74  		case reg.Ch <- bat:
    75  		}
    76  	}
    77  	return false, nil
    78  }
    79  
    80  // common sender: send to all RemoteReceiver
    81  func sendToAllRemoteFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
    82  	if !ap.ctr.prepared {
    83  		end, err := ap.waitRemoteRegsReady(proc)
    84  		if err != nil {
    85  			return false, err
    86  		}
    87  		if end {
    88  			return true, nil
    89  		}
    90  	}
    91  
    92  	{ // send to remote regs
    93  		encodeData, errEncode := types.Encode(bat)
    94  		if errEncode != nil {
    95  			return false, errEncode
    96  		}
    97  		for _, r := range ap.ctr.remoteReceivers {
    98  			if err := sendBatchToClientSession(proc.Ctx, encodeData, r); err != nil {
    99  				return false, err
   100  			}
   101  		}
   102  	}
   103  
   104  	return false, nil
   105  }
   106  
   107  func sendBatToIndex(ap *Argument, proc *process.Process, bat *batch.Batch, regIndex uint32) (err error) {
   108  	for i, reg := range ap.LocalRegs {
   109  		batIndex := uint32(ap.ShuffleRegIdxLocal[i])
   110  		if regIndex == batIndex {
   111  			if bat != nil && bat.RowCount() != 0 {
   112  				select {
   113  				case <-proc.Ctx.Done():
   114  					return nil
   115  
   116  				case <-reg.Ctx.Done():
   117  					return nil
   118  
   119  				case reg.Ch <- bat:
   120  				}
   121  			}
   122  		}
   123  	}
   124  
   125  	forShuffle := false
   126  	for _, r := range ap.ctr.remoteReceivers {
   127  		batIndex := uint32(ap.ctr.remoteToIdx[r.Uid])
   128  		if regIndex == batIndex {
   129  			if bat != nil && bat.RowCount() != 0 {
   130  				forShuffle = true
   131  
   132  				encodeData, errEncode := types.Encode(bat)
   133  				if errEncode != nil {
   134  					err = errEncode
   135  					break
   136  				}
   137  				if err = sendBatchToClientSession(proc.Ctx, encodeData, r); err != nil {
   138  					break
   139  				}
   140  			}
   141  		}
   142  	}
   143  
   144  	if forShuffle {
   145  		// in shuffle dispatch, this batch only send to remote CN, we can safely put it back into pool
   146  		proc.PutBatch(bat)
   147  	}
   148  	return err
   149  }
   150  
   151  func sendBatToLocalMatchedReg(ap *Argument, proc *process.Process, bat *batch.Batch, regIndex uint32) error {
   152  	localRegsCnt := uint32(ap.ctr.localRegsCnt)
   153  	for i, reg := range ap.LocalRegs {
   154  		batIndex := uint32(ap.ShuffleRegIdxLocal[i])
   155  		if regIndex%localRegsCnt == batIndex%localRegsCnt {
   156  			if bat != nil && bat.RowCount() != 0 {
   157  				select {
   158  				case <-proc.Ctx.Done():
   159  					return nil
   160  
   161  				case <-reg.Ctx.Done():
   162  					return nil
   163  
   164  				case reg.Ch <- bat:
   165  				}
   166  			}
   167  		}
   168  	}
   169  	return nil
   170  }
   171  
   172  func sendBatToMultiMatchedReg(ap *Argument, proc *process.Process, bat *batch.Batch, regIndex uint32) error {
   173  	localRegsCnt := uint32(ap.ctr.localRegsCnt)
   174  	atomic.AddInt64(&bat.Cnt, 1)
   175  	defer atomic.AddInt64(&bat.Cnt, -1)
   176  	for i, reg := range ap.LocalRegs {
   177  		batIndex := uint32(ap.ShuffleRegIdxLocal[i])
   178  		if regIndex%localRegsCnt == batIndex%localRegsCnt {
   179  			if bat != nil && bat.RowCount() != 0 {
   180  				select {
   181  				case <-proc.Ctx.Done():
   182  					return nil
   183  
   184  				case <-reg.Ctx.Done():
   185  					return nil
   186  
   187  				case reg.Ch <- bat:
   188  				}
   189  			}
   190  		}
   191  	}
   192  	for _, r := range ap.ctr.remoteReceivers {
   193  		batIndex := uint32(ap.ctr.remoteToIdx[r.Uid])
   194  		if regIndex%localRegsCnt == batIndex%localRegsCnt {
   195  			if bat != nil && bat.RowCount() != 0 {
   196  				encodeData, errEncode := types.Encode(bat)
   197  				if errEncode != nil {
   198  					return errEncode
   199  				}
   200  				if err := sendBatchToClientSession(proc.Ctx, encodeData, r); err != nil {
   201  					return err
   202  				}
   203  			}
   204  		}
   205  	}
   206  	return nil
   207  }
   208  
   209  // shuffle to all receiver (include LocalReceiver and RemoteReceiver)
   210  func shuffleToAllFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
   211  	if !ap.ctr.prepared {
   212  		end, err := ap.waitRemoteRegsReady(proc)
   213  		if err != nil {
   214  			return false, err
   215  		}
   216  		if end {
   217  			return true, nil
   218  		}
   219  	}
   220  
   221  	ap.ctr.batchCnt[bat.ShuffleIDX]++
   222  	ap.ctr.rowCnt[bat.ShuffleIDX] += bat.RowCount()
   223  	if ap.ShuffleType == plan2.ShuffleToRegIndex {
   224  		return false, sendBatToIndex(ap, proc, bat, uint32(bat.ShuffleIDX))
   225  	} else if ap.ShuffleType == plan2.ShuffleToLocalMatchedReg {
   226  		return false, sendBatToLocalMatchedReg(ap, proc, bat, uint32(bat.ShuffleIDX))
   227  	} else {
   228  		return false, sendBatToMultiMatchedReg(ap, proc, bat, uint32(bat.ShuffleIDX))
   229  	}
   230  }
   231  
   232  // send to all receiver (include LocalReceiver and RemoteReceiver)
   233  func sendToAllFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
   234  	end, remoteErr := sendToAllRemoteFunc(bat, ap, proc)
   235  	if remoteErr != nil || end {
   236  		return end, remoteErr
   237  	}
   238  
   239  	return sendToAllLocalFunc(bat, ap, proc)
   240  }
   241  
   242  // common sender: send to any LocalReceiver
   243  // if the reg which you want to send to is closed
   244  // send it to next one.
   245  func sendToAnyLocalFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
   246  	for {
   247  		sendto := ap.ctr.sendCnt % ap.ctr.localRegsCnt
   248  		reg := ap.LocalRegs[sendto]
   249  		select {
   250  		case <-proc.Ctx.Done():
   251  			return true, nil
   252  
   253  		case <-reg.Ctx.Done():
   254  			ap.LocalRegs = append(ap.LocalRegs[:sendto], ap.LocalRegs[sendto+1:]...)
   255  			ap.ctr.localRegsCnt--
   256  			ap.ctr.aliveRegCnt--
   257  			if ap.ctr.localRegsCnt == 0 {
   258  				return true, nil
   259  			}
   260  
   261  		case reg.Ch <- bat:
   262  			ap.ctr.sendCnt++
   263  			return false, nil
   264  		}
   265  	}
   266  }
   267  
   268  // common sender: send to any RemoteReceiver
   269  // if the reg which you want to send to is closed
   270  // send it to next one.
   271  func sendToAnyRemoteFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
   272  	if !ap.ctr.prepared {
   273  		end, err := ap.waitRemoteRegsReady(proc)
   274  		if err != nil {
   275  			return false, err
   276  		}
   277  		// update the cnt
   278  		ap.ctr.remoteRegsCnt = len(ap.ctr.remoteReceivers)
   279  		ap.ctr.aliveRegCnt = ap.ctr.remoteRegsCnt + ap.ctr.localRegsCnt
   280  		if end || ap.ctr.remoteRegsCnt == 0 {
   281  			return true, nil
   282  		}
   283  	}
   284  	select {
   285  	case <-proc.Ctx.Done():
   286  		return true, nil
   287  
   288  	default:
   289  	}
   290  
   291  	encodeData, errEncode := types.Encode(bat)
   292  	if errEncode != nil {
   293  		return false, errEncode
   294  	}
   295  
   296  	for {
   297  		regIdx := ap.ctr.sendCnt % ap.ctr.remoteRegsCnt
   298  		reg := ap.ctr.remoteReceivers[regIdx]
   299  
   300  		if err := sendBatchToClientSession(proc.Ctx, encodeData, reg); err != nil {
   301  			if moerr.IsMoErrCode(err, moerr.ErrStreamClosed) {
   302  				ap.ctr.remoteReceivers = append(ap.ctr.remoteReceivers[:regIdx], ap.ctr.remoteReceivers[regIdx+1:]...)
   303  				ap.ctr.remoteRegsCnt--
   304  				ap.ctr.aliveRegCnt--
   305  				if ap.ctr.remoteRegsCnt == 0 {
   306  					return true, nil
   307  				}
   308  				ap.ctr.sendCnt++
   309  				continue
   310  			} else {
   311  				return false, err
   312  			}
   313  		}
   314  		ap.ctr.sendCnt++
   315  		return false, nil
   316  	}
   317  }
   318  
   319  // Make sure enter this function LocalReceiver and RemoteReceiver are both not equal 0
   320  func sendToAnyFunc(bat *batch.Batch, ap *Argument, proc *process.Process) (bool, error) {
   321  	toLocal := (ap.ctr.sendCnt % ap.ctr.aliveRegCnt) < ap.ctr.localRegsCnt
   322  	if toLocal {
   323  		allclosed, err := sendToAnyLocalFunc(bat, ap, proc)
   324  		if err != nil {
   325  			return false, nil
   326  		}
   327  		if allclosed { // all local reg closed, change sendFunc to send remote only
   328  			proc.PutBatch(bat)
   329  			ap.ctr.sendFunc = sendToAnyRemoteFunc
   330  			return ap.ctr.sendFunc(bat, ap, proc)
   331  		}
   332  	} else {
   333  		allclosed, err := sendToAnyRemoteFunc(bat, ap, proc)
   334  		if err != nil {
   335  			return false, nil
   336  		}
   337  		if allclosed { // all remote reg closed, change sendFunc to send local only
   338  			ap.ctr.sendFunc = sendToAnyLocalFunc
   339  			return ap.ctr.sendFunc(bat, ap, proc)
   340  		}
   341  	}
   342  	return false, nil
   343  
   344  }
   345  
   346  func sendBatchToClientSession(ctx context.Context, encodeBatData []byte, wcs process.WrapCs) error {
   347  	checksum := crc32.ChecksumIEEE(encodeBatData)
   348  	if len(encodeBatData) <= maxMessageSizeToMoRpc {
   349  		msg := cnclient.AcquireMessage()
   350  		{
   351  			msg.Id = wcs.MsgId
   352  			msg.Data = encodeBatData
   353  			msg.Cmd = pipeline.Method_BatchMessage
   354  			msg.Sid = pipeline.Status_Last
   355  			msg.Checksum = checksum
   356  		}
   357  		if err := wcs.Cs.Write(ctx, msg); err != nil {
   358  			return err
   359  		}
   360  		return nil
   361  	}
   362  
   363  	start := 0
   364  	for start < len(encodeBatData) {
   365  		end := start + maxMessageSizeToMoRpc
   366  		sid := pipeline.Status_WaitingNext
   367  		if end > len(encodeBatData) {
   368  			end = len(encodeBatData)
   369  			sid = pipeline.Status_Last
   370  		}
   371  		msg := cnclient.AcquireMessage()
   372  		{
   373  			msg.Id = wcs.MsgId
   374  			msg.Data = encodeBatData[start:end]
   375  			msg.Cmd = pipeline.Method_BatchMessage
   376  			msg.Sid = sid
   377  			msg.Checksum = checksum
   378  		}
   379  
   380  		if err := wcs.Cs.Write(ctx, msg); err != nil {
   381  			return err
   382  		}
   383  		start = end
   384  	}
   385  	return nil
   386  }
   387  
   388  // success count is always no greater than refcnt
   389  func handleUnsent(proc *process.Process, bat *batch.Batch, refCnt int64, successCnt int64) {
   390  	diff := successCnt - refCnt
   391  	atomic.AddInt64(&bat.Cnt, diff)
   392  	if jm, ok := bat.AuxData.(*hashmap.JoinMap); ok {
   393  		jm.IncRef(diff)
   394  		jm.SetDupCount(diff)
   395  	}
   396  
   397  	proc.PutBatch(bat)
   398  }