github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/consensus/dbft/core/message_set.go (about)

     1  package core
     2  
import (
	"fmt"
	"math/big"
	"sort"
	"strings"
	"sync"

	"github.com/quickchainproject/quickchain/common"
	bft "github.com/quickchainproject/quickchain/consensus/dbft"
)
    12  
    13  // Construct a new message set to accumulate messages for given sequence/view number.
    14  func newMessageSet(valSet bft.ValidatorSet) *messageSet {
    15  	return &messageSet{
    16  		view: &bft.View{
    17  			Round:    new(big.Int),
    18  			Sequence: new(big.Int),
    19  		},
    20  		messagesMu: new(sync.Mutex),
    21  		messages:   make(map[common.Address]*message),
    22  		valSet:     valSet,
    23  	}
    24  }
    25  
    26  // ----------------------------------------------------------------------------
    27  
// messageSet accumulates at most one message per sender address. Messages
// are accepted only from addresses present in valSet (see verify).
type messageSet struct {
	// view holds the round and sequence numbers associated with this set.
	view *bft.View
	// valSet is consulted to decide whether a sender is authorized.
	valSet bft.ValidatorSet
	// messagesMu guards messages; the exported methods acquire it.
	messagesMu *sync.Mutex
	// messages maps a sender address to the latest message added from it.
	messages map[common.Address]*message
}
    34  
    35  func (ms *messageSet) View() *bft.View {
    36  	return ms.view
    37  }
    38  
    39  func (ms *messageSet) Add(msg *message) error {
    40  	ms.messagesMu.Lock()
    41  	defer ms.messagesMu.Unlock()
    42  
    43  	if err := ms.verify(msg); err != nil {
    44  		return err
    45  	}
    46  
    47  	return ms.addVerifiedMessage(msg)
    48  }
    49  
    50  func (ms *messageSet) Values() (result []*message) {
    51  	ms.messagesMu.Lock()
    52  	defer ms.messagesMu.Unlock()
    53  
    54  	for _, v := range ms.messages {
    55  		result = append(result, v)
    56  	}
    57  
    58  	return result
    59  }
    60  
    61  func (ms *messageSet) Size() int {
    62  	ms.messagesMu.Lock()
    63  	defer ms.messagesMu.Unlock()
    64  	return len(ms.messages)
    65  }
    66  
    67  func (ms *messageSet) Get(addr common.Address) *message {
    68  	ms.messagesMu.Lock()
    69  	defer ms.messagesMu.Unlock()
    70  	return ms.messages[addr]
    71  }
    72  
    73  // ----------------------------------------------------------------------------
    74  
    75  func (ms *messageSet) verify(msg *message) error {
    76  	// verify if the message comes from one of the validators
    77  	if _, v := ms.valSet.GetByAddress(msg.Address); v == nil {
    78  		return bft.ErrUnauthorizedAddress
    79  	}
    80  
    81  	// TODO: check view number and sequence number
    82  
    83  	return nil
    84  }
    85  
    86  func (ms *messageSet) addVerifiedMessage(msg *message) error {
    87  	ms.messages[msg.Address] = msg
    88  	return nil
    89  }
    90  
    91  func (ms *messageSet) String() string {
    92  	ms.messagesMu.Lock()
    93  	defer ms.messagesMu.Unlock()
    94  	addresses := make([]string, 0, len(ms.messages))
    95  	for _, v := range ms.messages {
    96  		addresses = append(addresses, v.Address.String())
    97  	}
    98  	return fmt.Sprintf("[%v]", strings.Join(addresses, ", "))
    99  }