github.com/cranelv/ethereum_mpc@v0.0.0-20191031014521-23aeb1415092/mpcService/step/base_step.go (about)

     1  package step
     2  
     3  import (
     4  	"github.com/ethereum/go-ethereum/mpcService/protocol"
     5  	"time"
     6  	"github.com/ethereum/go-ethereum/log"
     7  	"github.com/ethereum/go-ethereum/common"
     8  )
     9  
    10  type BaseStep struct {
    11  	nodeInfo protocol.MpcNodeInterface
    12  	mpcResult protocol.MpcResultInterface
    13  	msgChan chan *protocol.StepMessage
    14  	msgFilter     map[common.Hash]bool
    15  	finish  chan error
    16  	waiting int
    17  
    18  }
    19  
    20  func CreateBaseStep(result protocol.MpcResultInterface, nodeInfo protocol.MpcNodeInterface, wait int, bFilter bool) *BaseStep {
    21  	step := &BaseStep{mpcResult:result,nodeInfo: nodeInfo, msgChan: make(chan *protocol.StepMessage, wait+3), finish: make(chan error, 3)}
    22  	step.waiting = wait
    23  	if bFilter {
    24  		step.msgFilter = make(map[common.Hash]bool)
    25  	}
    26  	return step
    27  }
    28  
    29  func (step *BaseStep) InitMessageLoop(msger protocol.GetMessageInterface) error {
    30  	log.Info("BaseStep.InitMessageLoop begin")
    31  	if step.waiting <= 0 {
    32  		step.finishMessage(nil)
    33  	} else {
    34  		go func() {
    35  			log.Info("InitMessageLoop begin")
    36  
    37  			for {
    38  //				err := step.HandleMessage(msger)
    39  				if step.HandleMessage(msger) {
    40  //					if err != protocol.ErrQuit {
    41  //						log.Error("InitMessageLoop fail, get message err, err:%s", err.Error())
    42  //					}
    43  
    44  					break
    45  				}
    46  			}
    47  		}()
    48  	}
    49  
    50  	return nil
    51  }
    52  func (step *BaseStep) quitMessage(){
    53  	select {
    54  	case step.msgChan <- nil:
    55  	default:
    56  	}
    57  }
    58  func (step *BaseStep) finishMessage(err error){
    59  	select {
    60  	case step.finish <- err:
    61  	default:
    62  	}
    63  }
    64  func (step *BaseStep) Quit(err error) {
    65  	step.quitMessage()
    66  	step.finishMessage(err)
    67  }
    68  
    69  func (step *BaseStep) FinishStep() error {
    70  	select {
    71  	case err := <-step.finish:
    72  		if err != nil {
    73  //			log.Error("BaseStep.FinishStep, get a step finish error.","error", err.Error())
    74  		}
    75  
    76  		step.quitMessage()
    77  		return err
    78  	case <-time.After(protocol.MPCTimeOut):
    79  //		log.Error("BaseStep.FinishStep, wait step finish timeout")
    80  		step.quitMessage()
    81  		return protocol.ErrTimeOut
    82  	}
    83  }
    84  
    85  func (step *BaseStep) GetMessageChan() chan *protocol.StepMessage {
    86  	return step.msgChan
    87  }
    88  
    89  func (step *BaseStep) HandleMessage(msger protocol.GetMessageInterface) bool {
    90  	var msg *protocol.StepMessage
    91  	select {
    92  	case msg = <-step.msgChan:
    93  		if msg == nil {
    94  			log.Info("BaseStep get a quit msg")
    95  			return true
    96  		}
    97  		if step.msgFilter != nil {
    98  			_, exist := step.msgFilter[common.BytesToHash(msg.PeerID[:32])]
    99  			if exist {
   100  				log.Error("BaseStep.HandleMessage, get message from peerID fail", "peer", msg.PeerID)
   101  				return false
   102  			}
   103  			step.msgFilter[common.BytesToHash(msg.PeerID[:32])] = true
   104  		}
   105  		if step.waiting > 0 && msger.HandleMessage(msg) {
   106  			step.waiting--
   107  			if step.waiting <= 0 {
   108  				step.finishMessage(nil)
   109  				return true
   110  			}
   111  		}
   112  	}
   113  
   114  	return false
   115  }