
     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    14  package cvs
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"math/rand"
    20  	"sync"
    21  	"time"
    23  	""
    24  	pb ""
    25  	""
    26  	frameModel ""
    27  	""
    28  	dcontext ""
    29  	""
    30  	""
    31  	""
    32  	""
    33  	""
    34  	""
    35  )
    37  const (
    38  	bufferSize = 1024
    39  )
    41  type strPair struct {
    42  	firstStr  string
    43  	secondStr string
    44  }
// Config is the configuration of a single cvs task, which copies one file,
// line by line, from the service at SrcHost to the service at DstHost.
type Config struct {
	// Idx is the index of the file read from SrcHost and written to DstHost.
	Idx int `json:"Idx"`
	// SrcHost is the address lines are read from.
	SrcHost string `json:"SrcHost"`
	// DstHost is the address lines are written to.
	DstHost string `json:"DstHost"`
	// DstDir is sent as the Dir of every write request.
	// NOTE(review): the JSON key is "DstIdx", not "DstDir". It looks like a
	// typo, but renaming it changes the encoded form of existing configs, so
	// confirm every producer/consumer before touching it.
	DstDir string `json:"DstIdx"`
	// StartLoc is sent as ReadLinesRequest.LineNo, i.e. where reading starts.
	StartLoc string `json:"StartLoc"`
}

// Status represents the business status of a cvs task. Status() encodes it
// as JSON into the ExtBytes of the worker status.
type Status struct {
	// TaskConfig is the config the task was created with.
	TaskConfig Config `json:"Config"`
	// CurrentLoc is the key of the last line written to DstHost, or
	// StartLoc if nothing has been written yet.
	CurrentLoc string `json:"CurLoc"`
	// Count is the number of lines written to DstHost so far.
	Count int64 `json:"Cnt"`
}
    62  type connPool struct {
    63  	sync.Mutex
    65  	pool map[string]connArray
    66  }
    68  var pool connPool = connPool{pool: make(map[string]connArray)}
    70  func (c *connPool) getConn(addr string) (*grpc.ClientConn, error) {
    71  	c.Lock()
    72  	defer c.Unlock()
    73  	arr, ok := c.pool[addr]
    74  	if !ok {
    75  		for i := 0; i < 5; i++ {
    76  			conn, err := grpc.Dial(addr, grpc.WithInsecure())
    77  			if err != nil {
    78  				return nil, err
    79  			}
    80  			arr = append(arr, conn)
    81  		}
    82  		c.pool[addr] = arr
    83  	}
    84  	i := rand.Intn(5)
    85  	return arr[i], nil
    86  }
    88  type connArray []*grpc.ClientConn
// cvsTask is a worker that copies one file, line by line, from Config.SrcHost
// to Config.DstHost. InitImpl starts two goroutines: Receive fills buffer and
// send drains it.
type cvsTask struct {
	// BaseWorker is not set in newCvsTask; presumably the framework injects
	// it after construction — confirm before reordering or renaming.
	framework.BaseWorker
	Config
	// counter is the number of lines written to the destination.
	counter *atomic.Int64
	// curLoc is the key of the last line written, initially Config.StartLoc.
	// NOTE(review): send writes it and Status reads it from another goroutine
	// without synchronization (data race); consider an atomic or a mutex.
	curLoc string
	// cancelFn cancels the context shared by the Receive and send goroutines.
	cancelFn func()
	// buffer carries key/value pairs from Receive to send; Receive closes it
	// when the source reports end of file.
	buffer chan strPair
	// NOTE(review): isEOF is read in Receive but never set to true in this file.
	isEOF bool

	// statusCode holds the current worker state, guarded by its own lock.
	statusCode struct {
		sync.RWMutex
		code frameModel.WorkerState
	}
	// runError holds the error that made the task fail, guarded by its own lock.
	runError struct {
		sync.RWMutex
		err error
	}

	// statusRateLimiter throttles how often Tick reports status.
	statusRateLimiter *rate.Limiter
}
   111  // RegisterWorker is used to register cvs task worker into global registry
   112  func RegisterWorker() {
   113  	factory := registry.NewSimpleWorkerFactory(newCvsTask)
   114  	registry.GlobalWorkerRegistry().MustRegisterWorkerType(frameModel.CvsTask, factory)
   115  }
   117  func newCvsTask(ctx *dcontext.Context, _workerID frameModel.WorkerID, masterID frameModel.MasterID, conf *Config) *cvsTask {
   118  	task := &cvsTask{
   119  		Config:            *conf,
   120  		curLoc:            conf.StartLoc,
   121  		buffer:            make(chan strPair, bufferSize),
   122  		statusRateLimiter: rate.NewLimiter(rate.Every(time.Second), 1),
   123  		counter:           atomic.NewInt64(0),
   124  	}
   125  	return task
   126  }
   128  // InitImpl implements WorkerImpl.InitImpl
   129  func (task *cvsTask) InitImpl(ctx context.Context) error {
   130  	log.Info("init the task  ", zap.Any("task id :", task.ID()))
   131  	task.setState(frameModel.WorkerStateNormal)
   132  	// Don't use the ctx from the caller. Caller may cancel the ctx after InitImpl returns.
   133  	ctx, task.cancelFn = context.WithCancel(context.Background())
   134  	go func() {
   135  		err := task.Receive(ctx)
   136  		if err != nil {
   137  			log.Error("error happened when reading data from the upstream ", zap.String("id", task.ID()), zap.Any("message", err.Error()))
   138  			task.setRunError(err)
   139  			task.setState(frameModel.WorkerStateError)
   140  		}
   141  	}()
   142  	go func() {
   143  		err := task.send(ctx)
   144  		if err != nil {
   145  			log.Error("error happened when writing data to the downstream ", zap.String("id", task.ID()), zap.Any("message", err.Error()))
   146  			task.setRunError(err)
   147  			task.setState(frameModel.WorkerStateError)
   148  		} else {
   149  			task.setState(frameModel.WorkerStateFinished)
   150  		}
   151  	}()
   153  	return nil
   154  }
   156  // Tick is called on a fixed interval.
   157  func (task *cvsTask) Tick(ctx context.Context) error {
   158  	// log.Info("cvs task tick", zap.Any(" task id ", string(task.ID())+" -- "+strconv.FormatInt(task.counter, 10)))
   159  	if task.statusRateLimiter.Allow() {
   160  		err := task.BaseWorker.UpdateStatus(ctx, task.Status())
   161  		if errors.Is(err, errors.ErrWorkerUpdateStatusTryAgain) {
   162  			log.Warn("update status try again later", zap.String("id", task.ID()), zap.String("error", err.Error()))
   163  			return nil
   164  		}
   165  		return err
   166  	}
   168  	exitReason := framework.ExitReasonUnknown
   169  	switch task.getState() {
   170  	case frameModel.WorkerStateFinished:
   171  		exitReason = framework.ExitReasonFinished
   172  	case frameModel.WorkerStateError:
   173  		exitReason = framework.ExitReasonFailed
   174  	case frameModel.WorkerStateStopped:
   175  		exitReason = framework.ExitReasonCanceled
   176  	default:
   177  	}
   179  	if exitReason == framework.ExitReasonUnknown {
   180  		return nil
   181  	}
   183  	return task.BaseWorker.Exit(ctx, exitReason, task.getRunError(), task.Status().ExtBytes)
   184  }
   186  // Status returns a short worker status to be periodically sent to the master.
   187  func (task *cvsTask) Status() frameModel.WorkerStatus {
   188  	stats := &Status{
   189  		TaskConfig: task.Config,
   190  		CurrentLoc: task.curLoc,
   191  		Count:      task.counter.Load(),
   192  	}
   193  	statsBytes, err := json.Marshal(stats)
   194  	if err != nil {
   195  		log.Panic("get stats error", zap.String("id", task.ID()), zap.Error(err))
   196  	}
   197  	return frameModel.WorkerStatus{
   198  		State:    task.getState(),
   199  		ErrorMsg: "",
   200  		ExtBytes: statsBytes,
   201  	}
   202  }
   204  func (task *cvsTask) OnMasterMessage(ctx context.Context, topic p2p.Topic, message p2p.MessageValue) error {
   205  	switch msg := message.(type) {
   206  	case *frameModel.StatusChangeRequest:
   207  		switch msg.ExpectState {
   208  		case frameModel.WorkerStateStopped:
   209  			task.setState(frameModel.WorkerStateStopped)
   210  		default:
   211  			log.Info("FakeWorker: ignore status change state", zap.Int32("state", int32(msg.ExpectState)))
   212  		}
   213  	default:
   214  		log.Info("unsupported message", zap.Any("message", message))
   215  	}
   217  	return nil
   218  }
   220  // CloseImpl tells the WorkerImpl to quitrunStatusWorker and release resources.
   221  func (task *cvsTask) CloseImpl(ctx context.Context) {
   222  	if task.cancelFn != nil {
   223  		task.cancelFn()
   224  	}
   225  }
   227  func (task *cvsTask) Receive(ctx context.Context) error {
   228  	conn, err := pool.getConn(task.SrcHost)
   229  	if err != nil {
   230  		log.Error("cann't connect with the source address ", zap.String("id", task.ID()), zap.Any("message", task.SrcHost))
   231  		return err
   232  	}
   233  	client := pb.NewDataRWServiceClient(conn)
   234  	reader, err := client.ReadLines(ctx, &pb.ReadLinesRequest{FileIdx: int32(task.Idx), LineNo: []byte(task.StartLoc)})
   235  	if err != nil {
   236  		log.Error("read data from file failed ", zap.String("id", task.ID()), zap.Error(err))
   237  		return err
   238  	}
   239  	for {
   240  		reply, err := reader.Recv()
   241  		if err != nil {
   242  			log.Error("read data failed", zap.String("id", task.ID()), zap.Error(err))
   243  			if !task.isEOF {
   244  				task.cancelFn()
   245  			}
   246  			return err
   247  		}
   248  		if reply.IsEof {
   249  			log.Info("Reach the end of the file ", zap.String("id", task.ID()), zap.Any("fileID", task.Idx))
   250  			close(task.buffer)
   251  			break
   252  		}
   253  		select {
   254  		case <-ctx.Done():
   255  			return nil
   256  		case task.buffer <- strPair{firstStr: string(reply.Key), secondStr: string(reply.Val)}:
   257  		}
   258  		// waiting longer time to read lines slowly
   259  	}
   260  	return nil
   261  }
   263  func (task *cvsTask) send(ctx context.Context) error {
   264  	conn, err := pool.getConn(task.DstHost)
   265  	if err != nil {
   266  		log.Error("can't connect with the destination address ", zap.Any("id", task.ID()), zap.Error(err))
   267  		return err
   268  	}
   269  	client := pb.NewDataRWServiceClient(conn)
   270  	writer, err := client.WriteLines(ctx)
   271  	if err != nil {
   272  		log.Error("call write data rpc failed", zap.String("id", task.ID()), zap.Error(err))
   273  		task.cancelFn()
   274  		return err
   275  	}
   276  	for {
   277  		select {
   278  		case kv, more := <-task.buffer:
   279  			if !more {
   280  				log.Info("Reach the end of the file ", zap.String("id", task.ID()), zap.Any("cnt", task.counter.Load()), zap.String("last write", task.curLoc))
   281  				resp, err := writer.CloseAndRecv()
   282  				if err != nil {
   283  					return err
   284  				}
   285  				if len(resp.ErrMsg) > 0 {
   286  					log.Warn("close writing meet error", zap.String("id", task.ID()))
   287  				}
   288  				return nil
   289  			}
   290  			err := writer.Send(&pb.WriteLinesRequest{FileIdx: int32(task.Idx), Key: []byte(kv.firstStr), Value: []byte(kv.secondStr), Dir: task.DstDir})
   291  			if err != nil {
   292  				log.Error("call write data rpc failed ", zap.String("id", task.ID()), zap.Error(err))
   293  				task.cancelFn()
   294  				return err
   295  			}
   296  			task.counter.Add(1)
   297  			task.curLoc = kv.firstStr
   298  		case <-ctx.Done():
   299  			return ctx.Err()
   300  		}
   301  	}
   302  }
   304  func (task *cvsTask) getState() frameModel.WorkerState {
   305  	task.statusCode.RLock()
   306  	defer task.statusCode.RUnlock()
   307  	return task.statusCode.code
   308  }
   310  func (task *cvsTask) setState(status frameModel.WorkerState) {
   311  	task.statusCode.Lock()
   312  	defer task.statusCode.Unlock()
   313  	task.statusCode.code = status
   314  }
   316  func (task *cvsTask) getRunError() error {
   317  	task.runError.RLock()
   318  	defer task.runError.RUnlock()
   319  	return task.runError.err
   320  }
   322  func (task *cvsTask) setRunError(err error) {
   323  	task.runError.Lock()
   324  	defer task.runError.Unlock()
   325  	task.runError.err = err
   326  }