github.com/igggame/nebulas-go@v2.1.0+incompatible/net/stream_manager.go (about)

     1  // Copyright (C) 2018 go-nebulas authors
     2  //
     3  // This file is part of the go-nebulas library.
     4  //
     5  // the go-nebulas library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // the go-nebulas library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU General Public License
    16  // along with the go-nebulas library.  If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  
    19  package net
    20  
    21  import (
    22  	"errors"
    23  	"fmt"
    24  	"hash/crc32"
    25  	"sort"
    26  	"strconv"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/sirupsen/logrus"
    31  
    32  	"github.com/gogo/protobuf/proto"
    33  	libnet "github.com/libp2p/go-libp2p-net"
    34  	peer "github.com/libp2p/go-libp2p-peer"
    35  	"github.com/nebulasio/go-nebulas/util/logging"
    36  )
    37  
    38  // const
    39  const (
    40  	CleanupInterval = time.Second * 60
    41  	// MaxStreamNum      = 500
    42  	// ReservedStreamNum = 50 // of MaxStreamNum
    43  )
    44  
// Sentinel errors related to stream management.
//
// ErrElimination is the reason passed to a stream's close method when
// StreamManager.cleanup evicts it for having a low value.
// ErrExceedMaxStreamNum and ErrDeprecatedStream are not referenced in this
// file; judging by their messages they presumably report a full stream table
// and a superseded stream respectively — confirm against their callers.
var (
	ErrExceedMaxStreamNum = errors.New("too many streams connected")
	ErrElimination        = errors.New("eliminated for low value")
	ErrDeprecatedStream   = errors.New("deprecated stream")
)
    51  
// StreamManager keeps track of all peer streams of the node and, from a
// background loop, periodically evicts low-value streams once the configured
// limit has been reached.
//
// Fields:
//   - mu: held by AddStream and RemoveStream while they modify allStreams
//     and activePeersCount.
//   - quitCh: one-slot buffered channel; Stop sends on it to end the loop.
//   - allStreams: maps a peer ID string (pid.Pretty()) to its *Stream.
//   - activePeersCount: incremented on every store into allStreams and
//     decremented on every delete, i.e. the number of tracked streams.
//   - maxStreamNum: stream limit, taken from Config.StreamLimits; AddStream
//     rejects new streams once activePeersCount reaches it.
//   - reservedStreamNum: taken from Config.ReservedStreamLimits; cleanup
//     retains at most maxStreamNum-reservedStreamNum streams.
type StreamManager struct {
	mu                sync.Mutex
	quitCh            chan bool
	allStreams        *sync.Map
	activePeersCount  int32
	maxStreamNum      int32
	reservedStreamNum int32
}
    61  
    62  // NewStreamManager return a new stream manager
    63  func NewStreamManager(config *Config) *StreamManager {
    64  	return &StreamManager{
    65  		quitCh:            make(chan bool, 1),
    66  		allStreams:        new(sync.Map),
    67  		activePeersCount:  0,
    68  		maxStreamNum:      config.StreamLimits,
    69  		reservedStreamNum: config.ReservedStreamLimits,
    70  	}
    71  }
    72  
    73  // Count return active peers count in the stream manager
    74  func (sm *StreamManager) Count() int32 {
    75  	return sm.activePeersCount
    76  }
    77  
    78  // Start stream manager service
    79  func (sm *StreamManager) Start() {
    80  	logging.CLog().Info("Starting NebService StreamManager...")
    81  
    82  	go sm.loop()
    83  }
    84  
    85  // Stop stream manager service
    86  func (sm *StreamManager) Stop() {
    87  	logging.CLog().Info("Stopping NebService StreamManager...")
    88  
    89  	sm.quitCh <- true
    90  }
    91  
    92  // Add a new stream into the stream manager
    93  func (sm *StreamManager) Add(s libnet.Stream, node *Node) {
    94  	stream := NewStream(s, node)
    95  	sm.AddStream(stream)
    96  }
    97  
    98  // AddStream into the stream manager
    99  func (sm *StreamManager) AddStream(stream *Stream) {
   100  
   101  	sm.mu.Lock()
   102  	defer sm.mu.Unlock()
   103  
   104  	if sm.activePeersCount >= sm.maxStreamNum {
   105  		if stream.stream != nil {
   106  			stream.stream.Close()
   107  		}
   108  		return
   109  	}
   110  
   111  	// check & close old stream
   112  	if v, ok := sm.allStreams.Load(stream.pid.Pretty()); ok {
   113  		old, _ := v.(*Stream)
   114  
   115  		logging.VLog().WithFields(logrus.Fields{
   116  			"pid": old.pid.Pretty(),
   117  		}).Debug("Removing old stream.")
   118  
   119  		sm.activePeersCount--
   120  		sm.allStreams.Delete(old.pid.Pretty())
   121  
   122  		if old.stream != nil {
   123  			old.stream.Close()
   124  		}
   125  	}
   126  
   127  	logging.VLog().WithFields(logrus.Fields{
   128  		"stream": stream.String(),
   129  	}).Debug("Added a new stream.")
   130  
   131  	sm.activePeersCount++
   132  	sm.allStreams.Store(stream.pid.Pretty(), stream)
   133  	stream.StartLoop()
   134  }
   135  
   136  // Remove the stream with the given pid from the stream manager
   137  // func (sm *StreamManager) Remove(pid peer.ID) {
   138  
   139  // 	sm.mu.Lock()
   140  // 	defer sm.mu.Unlock()
   141  
   142  // 	logging.VLog().WithFields(logrus.Fields{
   143  // 		"pid": pid.Pretty(),
   144  // 	}).Debug("Removing a stream.")
   145  
   146  // 	if _, ok := sm.allStreams.Load(pid.Pretty()); !ok {
   147  // 		// caused by close in AddStream
   148  // 		return
   149  // 	}
   150  
   151  // 	sm.activePeersCount--
   152  // 	sm.allStreams.Delete(pid.Pretty())
   153  // }
   154  
   155  // RemoveStream from the stream manager
   156  func (sm *StreamManager) RemoveStream(s *Stream) {
   157  
   158  	sm.mu.Lock()
   159  	defer sm.mu.Unlock()
   160  
   161  	v, ok := sm.allStreams.Load(s.pid.Pretty())
   162  	if !ok {
   163  		return
   164  	}
   165  
   166  	exist, _ := v.(*Stream)
   167  	if s != exist {
   168  		return
   169  	}
   170  
   171  	logging.VLog().WithFields(logrus.Fields{
   172  		"pid": s.pid.Pretty(),
   173  	}).Debug("Removing a stream.")
   174  
   175  	sm.activePeersCount--
   176  	sm.allStreams.Delete(s.pid.Pretty())
   177  }
   178  
   179  // FindByPeerID find the stream with the given peerID
   180  func (sm *StreamManager) FindByPeerID(peerID string) *Stream {
   181  	v, _ := sm.allStreams.Load(peerID)
   182  	if v == nil {
   183  		return nil
   184  	}
   185  	return v.(*Stream)
   186  }
   187  
   188  // Find the stream with the given pid
   189  func (sm *StreamManager) Find(pid peer.ID) *Stream {
   190  	return sm.FindByPeerID(pid.Pretty())
   191  }
   192  
   193  func (sm *StreamManager) loop() {
   194  	logging.CLog().Info("Started NebService StreamManager.")
   195  
   196  	ticker := time.NewTicker(CleanupInterval)
   197  	for {
   198  		select {
   199  		case <-sm.quitCh:
   200  			logging.CLog().Info("Stopped Stream Manager Loop.")
   201  			return
   202  		case <-ticker.C:
   203  			sm.cleanup()
   204  		}
   205  	}
   206  }
   207  
   208  // BroadcastMessage broadcast the message
   209  func (sm *StreamManager) BroadcastMessage(messageName string, messageContent Serializable, priority int) {
   210  	pb, _ := messageContent.ToProto()
   211  	data, err := proto.Marshal(pb)
   212  	if err != nil {
   213  		return
   214  	}
   215  
   216  	dataCheckSum := crc32.ChecksumIEEE(data)
   217  
   218  	sm.allStreams.Range(func(key, value interface{}) bool {
   219  		stream := value.(*Stream)
   220  		if stream.IsHandshakeSucceed() && !HasRecvMessage(stream, dataCheckSum) {
   221  			stream.SendMessage(messageName, data, priority)
   222  		}
   223  		return true
   224  	})
   225  }
   226  
   227  // RelayMessage relay the message
   228  func (sm *StreamManager) RelayMessage(messageName string, messageContent Serializable, priority int) {
   229  	pb, _ := messageContent.ToProto()
   230  	data, err := proto.Marshal(pb)
   231  	if err != nil {
   232  		return
   233  	}
   234  
   235  	dataCheckSum := crc32.ChecksumIEEE(data)
   236  
   237  	sm.allStreams.Range(func(key, value interface{}) bool {
   238  		stream := value.(*Stream)
   239  		if stream.IsHandshakeSucceed() && !HasRecvMessage(stream, dataCheckSum) {
   240  			stream.SendMessage(messageName, data, priority)
   241  		}
   242  		return true
   243  	})
   244  }
   245  
   246  // SendMessageToPeers send the message to the peers filtered by the filter algorithm
   247  func (sm *StreamManager) SendMessageToPeers(messageName string, data []byte, priority int, filter PeerFilterAlgorithm) []string {
   248  	allPeers := make(PeersSlice, 0)
   249  
   250  	sm.allStreams.Range(func(key, value interface{}) bool {
   251  		stream := value.(*Stream)
   252  		if stream.IsHandshakeSucceed() {
   253  			allPeers = append(allPeers, value)
   254  		}
   255  		return true
   256  	})
   257  
   258  	selectedPeers := filter.Filter(allPeers)
   259  	selectedPeersPrettyID := make([]string, 0)
   260  
   261  	for _, v := range selectedPeers {
   262  		stream := v.(*Stream)
   263  		if err := stream.SendMessage(messageName, data, priority); err == nil {
   264  			selectedPeersPrettyID = append(selectedPeersPrettyID, stream.pid.Pretty())
   265  		}
   266  	}
   267  
   268  	return selectedPeersPrettyID
   269  }
   270  
   271  // CloseStream with the given pid and reason
   272  func (sm *StreamManager) CloseStream(peerID string, reason error) {
   273  	stream := sm.FindByPeerID(peerID)
   274  	if stream != nil {
   275  		stream.close(reason)
   276  	}
   277  }
   278  
   279  // cleanup eliminating low value streams if reaching the limit
   280  func (sm *StreamManager) cleanup() {
   281  
   282  	if sm.activePeersCount < sm.maxStreamNum {
   283  		logging.VLog().WithFields(logrus.Fields{
   284  			"maxNum":      sm.maxStreamNum,
   285  			"reservedNum": sm.reservedStreamNum,
   286  			"currentNum":  sm.activePeersCount,
   287  		}).Debug("No need for streams cleanup.")
   288  		return
   289  	}
   290  
   291  	// total number of each msg type
   292  	msgTotal := make(map[string]int)
   293  
   294  	// weight of each msg type
   295  	msgWeight := make(map[string]MessageWeight)
   296  	msgWeight[ROUTETABLE] = MessageWeightRouteTable
   297  
   298  	svs := make(StreamValueSlice, 0)
   299  
   300  	sm.allStreams.Range(func(key, value interface{}) bool {
   301  		stream := value.(*Stream)
   302  
   303  		// t type, c count
   304  		for t, c := range stream.msgCount {
   305  			msgTotal[t] += c
   306  			if _, ok := msgWeight[t]; ok {
   307  				continue
   308  			}
   309  
   310  			v, _ := stream.node.netService.dispatcher.subscribersMap.Load(t)
   311  			if m, ok := v.(*sync.Map); ok {
   312  				m.Range(func(key, value interface{}) bool {
   313  					msgWeight[t] = key.(*Subscriber).MessageWeight()
   314  					return false
   315  				})
   316  			}
   317  		}
   318  
   319  		svs = append(svs, &StreamValue{
   320  			stream: stream,
   321  		})
   322  
   323  		return true
   324  	})
   325  
   326  	// check length
   327  	if len(svs) <= int(sm.maxStreamNum-sm.reservedStreamNum) {
   328  		logging.CLog().WithFields(logrus.Fields{
   329  			"streamValueSliceLength": len(svs),
   330  		}).Debug("StreamValueSlice length is not enough, return directly.")
   331  		return
   332  	}
   333  
   334  	for _, sv := range svs {
   335  		for t, c := range sv.stream.msgCount {
   336  			w, _ := msgWeight[t]
   337  			sv.value += float64(c) * float64(w) / float64(msgTotal[t])
   338  		}
   339  	}
   340  
   341  	sort.Sort(sort.Reverse(svs))
   342  	logging.VLog().WithFields(logrus.Fields{
   343  		"maxNum":           sm.maxStreamNum,
   344  		"reservedNum":      sm.reservedStreamNum,
   345  		"currentNum":       sm.activePeersCount,
   346  		"msgTotal":         msgTotal,
   347  		"msgWeight":        msgWeight,
   348  		"streamValueSlice": svs,
   349  	}).Debug("Sorting streams before the cleanup.")
   350  
   351  	eliminated := svs[sm.maxStreamNum-sm.reservedStreamNum:]
   352  	for _, sv := range eliminated {
   353  		sv.stream.close(ErrElimination)
   354  	}
   355  
   356  	svs = svs[:sm.maxStreamNum-sm.reservedStreamNum]
   357  	logging.VLog().WithFields(logrus.Fields{
   358  		"eliminatedNum": len(eliminated),
   359  		"retained":      svs,
   360  	}).Debug("Streams cleanup is done.")
   361  }
   362  
// StreamValue pairs a stream with the score StreamManager.cleanup computes
// for it from the stream's per-message-type counts.
//
// stream is the scored stream and value is its score; cleanup retains the
// streams with the highest values. The counts presumably cover the past
// CleanupInterval — confirm where msgCount is reset.
type StreamValue struct {
	stream *Stream
	value  float64
}
   368  
   369  // StreamValueSlice StreamValue slice
   370  type StreamValueSlice []*StreamValue
   371  
   372  func (s StreamValueSlice) Len() int           { return len(s) }
   373  func (s StreamValueSlice) Less(i, j int) bool { return s[i].value < s[j].value }
   374  func (s StreamValueSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
   375  func (s *StreamValue) String() string {
   376  	return s.stream.addr.String() + ":" +
   377  		strconv.FormatFloat(s.value, 'f', 3, 64) + ":" +
   378  		fmt.Sprintf("%v", s.stream.msgCount)
   379  }