github.com/JFJun/bsc@v1.0.0/les/server_handler.go (about)

     1  // Copyright 2019 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package les
    18  
    19  import (
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"sync"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"github.com/JFJun/bsc/common"
    28  	"github.com/JFJun/bsc/common/mclock"
    29  	"github.com/JFJun/bsc/core"
    30  	"github.com/JFJun/bsc/core/rawdb"
    31  	"github.com/JFJun/bsc/core/state"
    32  	"github.com/JFJun/bsc/core/types"
    33  	"github.com/JFJun/bsc/ethdb"
    34  	"github.com/JFJun/bsc/light"
    35  	"github.com/JFJun/bsc/log"
    36  	"github.com/JFJun/bsc/metrics"
    37  	"github.com/JFJun/bsc/p2p"
    38  	"github.com/JFJun/bsc/rlp"
    39  	"github.com/JFJun/bsc/trie"
    40  )
    41  
    42  const (
    43  	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
    44  	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
    45  	ethVersion        = 63              // equivalent eth version for the downloader
    46  
    47  	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
    48  	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
    49  	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
    50  	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
    51  	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
    52  	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
    53  	MaxTxSend                = 64  // Amount of transactions to be send per request
    54  	MaxTxStatus              = 256 // Amount of transactions to queried per request
    55  )
    56  
    57  var (
    58  	errTooManyInvalidRequest = errors.New("too many invalid requests made")
    59  	errFullClientPool        = errors.New("client pool is full")
    60  )
    61  
    62  // serverHandler is responsible for serving light client and process
    63  // all incoming light requests.
    64  type serverHandler struct {
    65  	blockchain *core.BlockChain
    66  	chainDb    ethdb.Database
    67  	txpool     *core.TxPool
    68  	server     *LesServer
    69  
    70  	closeCh chan struct{}  // Channel used to exit all background routines of handler.
    71  	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
    72  	synced  func() bool    // Callback function used to determine whether local node is synced.
    73  
    74  	// Testing fields
    75  	addTxsSync bool
    76  }
    77  
    78  func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
    79  	handler := &serverHandler{
    80  		server:     server,
    81  		blockchain: blockchain,
    82  		chainDb:    chainDb,
    83  		txpool:     txpool,
    84  		closeCh:    make(chan struct{}),
    85  		synced:     synced,
    86  	}
    87  	return handler
    88  }
    89  
    90  // start starts the server handler.
    91  func (h *serverHandler) start() {
    92  	h.wg.Add(1)
    93  	go h.broadcastHeaders()
    94  }
    95  
    96  // stop stops the server handler.
    97  func (h *serverHandler) stop() {
    98  	close(h.closeCh)
    99  	h.wg.Wait()
   100  }
   101  
   102  // runPeer is the p2p protocol run function for the given version.
   103  func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
   104  	peer := newClientPeer(int(version), h.server.config.NetworkId, p, newMeteredMsgWriter(rw, int(version)))
   105  	defer peer.close()
   106  	h.wg.Add(1)
   107  	defer h.wg.Done()
   108  	return h.handle(peer)
   109  }
   110  
   111  func (h *serverHandler) handle(p *clientPeer) error {
   112  	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
   113  
   114  	// Execute the LES handshake
   115  	var (
   116  		head   = h.blockchain.CurrentHeader()
   117  		hash   = head.Hash()
   118  		number = head.Number.Uint64()
   119  		td     = h.blockchain.GetTd(hash, number)
   120  	)
   121  	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil {
   122  		p.Log().Debug("Light Ethereum handshake failed", "err", err)
   123  		return err
   124  	}
   125  	if p.server {
   126  		// connected to another server, no messages expected, just wait for disconnection
   127  		_, err := p.rw.ReadMsg()
   128  		return err
   129  	}
   130  	// Reject light clients if server is not synced.
   131  	if !h.synced() {
   132  		p.Log().Debug("Light server not synced, rejecting peer")
   133  		return p2p.DiscRequested
   134  	}
   135  	defer p.fcClient.Disconnect()
   136  
   137  	// Disconnect the inbound peer if it's rejected by clientPool
   138  	if !h.server.clientPool.connect(p, 0) {
   139  		p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
   140  		return errFullClientPool
   141  	}
   142  	// Register the peer locally
   143  	if err := h.server.peers.register(p); err != nil {
   144  		h.server.clientPool.disconnect(p)
   145  		p.Log().Error("Light Ethereum peer registration failed", "err", err)
   146  		return err
   147  	}
   148  	clientConnectionGauge.Update(int64(h.server.peers.len()))
   149  
   150  	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
   151  
   152  	connectedAt := mclock.Now()
   153  	defer func() {
   154  		wg.Wait() // Ensure all background task routines have exited.
   155  		h.server.peers.unregister(p.id)
   156  		h.server.clientPool.disconnect(p)
   157  		clientConnectionGauge.Update(int64(h.server.peers.len()))
   158  		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
   159  	}()
   160  	// Mark the peer starts to be served.
   161  	atomic.StoreUint32(&p.serving, 1)
   162  	defer atomic.StoreUint32(&p.serving, 0)
   163  
   164  	// Spawn a main loop to handle all incoming messages.
   165  	for {
   166  		select {
   167  		case err := <-p.errCh:
   168  			p.Log().Debug("Failed to send light ethereum response", "err", err)
   169  			return err
   170  		default:
   171  		}
   172  		if err := h.handleMsg(p, &wg); err != nil {
   173  			p.Log().Debug("Light Ethereum message handling failed", "err", err)
   174  			return err
   175  		}
   176  	}
   177  }
   178  
   179  // handleMsg is invoked whenever an inbound message is received from a remote
   180  // peer. The remote connection is torn down upon returning any error.
   181  func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
   182  	// Read the next message from the remote peer, and ensure it's fully consumed
   183  	msg, err := p.rw.ReadMsg()
   184  	if err != nil {
   185  		return err
   186  	}
   187  	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
   188  
   189  	// Discard large message which exceeds the limitation.
   190  	if msg.Size > ProtocolMaxMsgSize {
   191  		clientErrorMeter.Mark(1)
   192  		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
   193  	}
   194  	defer msg.Discard()
   195  
   196  	var (
   197  		maxCost uint64
   198  		task    *servingTask
   199  	)
   200  	p.responseCount++
   201  	responseCount := p.responseCount
   202  	// accept returns an indicator whether the request can be served.
   203  	// If so, deduct the max cost from the flow control buffer.
   204  	accept := func(reqID, reqCnt, maxCnt uint64) bool {
   205  		// Short circuit if the peer is already frozen or the request is invalid.
   206  		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
   207  		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
   208  			p.fcClient.OneTimeCost(inSizeCost)
   209  			return false
   210  		}
   211  		// Prepaid max cost units before request been serving.
   212  		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
   213  		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
   214  		if !accepted {
   215  			p.freeze()
   216  			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
   217  			p.fcClient.OneTimeCost(inSizeCost)
   218  			return false
   219  		}
   220  		// Create a multi-stage task, estimate the time it takes for the task to
   221  		// execute, and cache it in the request service queue.
   222  		factor := h.server.costTracker.globalFactor()
   223  		if factor < 0.001 {
   224  			factor = 1
   225  			p.Log().Error("Invalid global cost factor", "factor", factor)
   226  		}
   227  		maxTime := uint64(float64(maxCost) / factor)
   228  		task = h.server.servingQueue.newTask(p, maxTime, priority)
   229  		if task.start() {
   230  			return true
   231  		}
   232  		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
   233  		return false
   234  	}
   235  	// sendResponse sends back the response and updates the flow control statistic.
   236  	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
   237  		p.responseLock.Lock()
   238  		defer p.responseLock.Unlock()
   239  
   240  		// Short circuit if the client is already frozen.
   241  		if p.isFrozen() {
   242  			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
   243  			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   244  			return
   245  		}
   246  		// Positive correction buffer value with real cost.
   247  		var replySize uint32
   248  		if reply != nil {
   249  			replySize = reply.size()
   250  		}
   251  		var realCost uint64
   252  		if h.server.costTracker.testing {
   253  			realCost = maxCost // Assign a fake cost for testing purpose
   254  		} else {
   255  			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
   256  		}
   257  		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   258  		if amount != 0 {
   259  			// Feed cost tracker request serving statistic.
   260  			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
   261  			// Reduce priority "balance" for the specific peer.
   262  			h.server.clientPool.requestCost(p, realCost)
   263  		}
   264  		if reply != nil {
   265  			p.mustQueueSend(func() {
   266  				if err := reply.send(bv); err != nil {
   267  					select {
   268  					case p.errCh <- err:
   269  					default:
   270  					}
   271  				}
   272  			})
   273  		}
   274  	}
   275  	switch msg.Code {
   276  	case GetBlockHeadersMsg:
   277  		p.Log().Trace("Received block header request")
   278  		if metrics.EnabledExpensive {
   279  			miscInHeaderPacketsMeter.Mark(1)
   280  			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
   281  		}
   282  		var req struct {
   283  			ReqID uint64
   284  			Query getBlockHeadersData
   285  		}
   286  		if err := msg.Decode(&req); err != nil {
   287  			clientErrorMeter.Mark(1)
   288  			return errResp(ErrDecode, "%v: %v", msg, err)
   289  		}
   290  		query := req.Query
   291  		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
   292  			wg.Add(1)
   293  			go func() {
   294  				defer wg.Done()
   295  				hashMode := query.Origin.Hash != (common.Hash{})
   296  				first := true
   297  				maxNonCanonical := uint64(100)
   298  
   299  				// Gather headers until the fetch or network limits is reached
   300  				var (
   301  					bytes   common.StorageSize
   302  					headers []*types.Header
   303  					unknown bool
   304  				)
   305  				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
   306  					if !first && !task.waitOrStop() {
   307  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   308  						return
   309  					}
   310  					// Retrieve the next header satisfying the query
   311  					var origin *types.Header
   312  					if hashMode {
   313  						if first {
   314  							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
   315  							if origin != nil {
   316  								query.Origin.Number = origin.Number.Uint64()
   317  							}
   318  						} else {
   319  							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
   320  						}
   321  					} else {
   322  						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
   323  					}
   324  					if origin == nil {
   325  						atomic.AddUint32(&p.invalidCount, 1)
   326  						break
   327  					}
   328  					headers = append(headers, origin)
   329  					bytes += estHeaderRlpSize
   330  
   331  					// Advance to the next header of the query
   332  					switch {
   333  					case hashMode && query.Reverse:
   334  						// Hash based traversal towards the genesis block
   335  						ancestor := query.Skip + 1
   336  						if ancestor == 0 {
   337  							unknown = true
   338  						} else {
   339  							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
   340  							unknown = query.Origin.Hash == common.Hash{}
   341  						}
   342  					case hashMode && !query.Reverse:
   343  						// Hash based traversal towards the leaf block
   344  						var (
   345  							current = origin.Number.Uint64()
   346  							next    = current + query.Skip + 1
   347  						)
   348  						if next <= current {
   349  							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
   350  							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
   351  							unknown = true
   352  						} else {
   353  							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
   354  								nextHash := header.Hash()
   355  								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
   356  								if expOldHash == query.Origin.Hash {
   357  									query.Origin.Hash, query.Origin.Number = nextHash, next
   358  								} else {
   359  									unknown = true
   360  								}
   361  							} else {
   362  								unknown = true
   363  							}
   364  						}
   365  					case query.Reverse:
   366  						// Number based traversal towards the genesis block
   367  						if query.Origin.Number >= query.Skip+1 {
   368  							query.Origin.Number -= query.Skip + 1
   369  						} else {
   370  							unknown = true
   371  						}
   372  
   373  					case !query.Reverse:
   374  						// Number based traversal towards the leaf block
   375  						query.Origin.Number += query.Skip + 1
   376  					}
   377  					first = false
   378  				}
   379  				reply := p.replyBlockHeaders(req.ReqID, headers)
   380  				sendResponse(req.ReqID, query.Amount, p.replyBlockHeaders(req.ReqID, headers), task.done())
   381  				if metrics.EnabledExpensive {
   382  					miscOutHeaderPacketsMeter.Mark(1)
   383  					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
   384  					miscServingTimeHeaderTimer.Update(time.Duration(task.servingTime))
   385  				}
   386  			}()
   387  		}
   388  
   389  	case GetBlockBodiesMsg:
   390  		p.Log().Trace("Received block bodies request")
   391  		if metrics.EnabledExpensive {
   392  			miscInBodyPacketsMeter.Mark(1)
   393  			miscInBodyTrafficMeter.Mark(int64(msg.Size))
   394  		}
   395  		var req struct {
   396  			ReqID  uint64
   397  			Hashes []common.Hash
   398  		}
   399  		if err := msg.Decode(&req); err != nil {
   400  			clientErrorMeter.Mark(1)
   401  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   402  		}
   403  		var (
   404  			bytes  int
   405  			bodies []rlp.RawValue
   406  		)
   407  		reqCnt := len(req.Hashes)
   408  		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
   409  			wg.Add(1)
   410  			go func() {
   411  				defer wg.Done()
   412  				for i, hash := range req.Hashes {
   413  					if i != 0 && !task.waitOrStop() {
   414  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   415  						return
   416  					}
   417  					if bytes >= softResponseLimit {
   418  						break
   419  					}
   420  					body := h.blockchain.GetBodyRLP(hash)
   421  					if body == nil {
   422  						atomic.AddUint32(&p.invalidCount, 1)
   423  						continue
   424  					}
   425  					bodies = append(bodies, body)
   426  					bytes += len(body)
   427  				}
   428  				reply := p.replyBlockBodiesRLP(req.ReqID, bodies)
   429  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   430  				if metrics.EnabledExpensive {
   431  					miscOutBodyPacketsMeter.Mark(1)
   432  					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
   433  					miscServingTimeBodyTimer.Update(time.Duration(task.servingTime))
   434  				}
   435  			}()
   436  		}
   437  
   438  	case GetCodeMsg:
   439  		p.Log().Trace("Received code request")
   440  		if metrics.EnabledExpensive {
   441  			miscInCodePacketsMeter.Mark(1)
   442  			miscInCodeTrafficMeter.Mark(int64(msg.Size))
   443  		}
   444  		var req struct {
   445  			ReqID uint64
   446  			Reqs  []CodeReq
   447  		}
   448  		if err := msg.Decode(&req); err != nil {
   449  			clientErrorMeter.Mark(1)
   450  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   451  		}
   452  		var (
   453  			bytes int
   454  			data  [][]byte
   455  		)
   456  		reqCnt := len(req.Reqs)
   457  		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
   458  			wg.Add(1)
   459  			go func() {
   460  				defer wg.Done()
   461  				for i, request := range req.Reqs {
   462  					if i != 0 && !task.waitOrStop() {
   463  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   464  						return
   465  					}
   466  					// Look up the root hash belonging to the request
   467  					header := h.blockchain.GetHeaderByHash(request.BHash)
   468  					if header == nil {
   469  						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
   470  						atomic.AddUint32(&p.invalidCount, 1)
   471  						continue
   472  					}
   473  					// Refuse to search stale state data in the database since looking for
   474  					// a non-exist key is kind of expensive.
   475  					local := h.blockchain.CurrentHeader().Number.Uint64()
   476  					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   477  						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
   478  						atomic.AddUint32(&p.invalidCount, 1)
   479  						continue
   480  					}
   481  					triedb := h.blockchain.StateCache().TrieDB()
   482  
   483  					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
   484  					if err != nil {
   485  						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   486  						atomic.AddUint32(&p.invalidCount, 1)
   487  						continue
   488  					}
   489  					code, err := triedb.Node(common.BytesToHash(account.CodeHash))
   490  					if err != nil {
   491  						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
   492  						continue
   493  					}
   494  					// Accumulate the code and abort if enough data was retrieved
   495  					data = append(data, code)
   496  					if bytes += len(code); bytes >= softResponseLimit {
   497  						break
   498  					}
   499  				}
   500  				reply := p.replyCode(req.ReqID, data)
   501  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   502  				if metrics.EnabledExpensive {
   503  					miscOutCodePacketsMeter.Mark(1)
   504  					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
   505  					miscServingTimeCodeTimer.Update(time.Duration(task.servingTime))
   506  				}
   507  			}()
   508  		}
   509  
   510  	case GetReceiptsMsg:
   511  		p.Log().Trace("Received receipts request")
   512  		if metrics.EnabledExpensive {
   513  			miscInReceiptPacketsMeter.Mark(1)
   514  			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
   515  		}
   516  		var req struct {
   517  			ReqID  uint64
   518  			Hashes []common.Hash
   519  		}
   520  		if err := msg.Decode(&req); err != nil {
   521  			clientErrorMeter.Mark(1)
   522  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   523  		}
   524  		var (
   525  			bytes    int
   526  			receipts []rlp.RawValue
   527  		)
   528  		reqCnt := len(req.Hashes)
   529  		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
   530  			wg.Add(1)
   531  			go func() {
   532  				defer wg.Done()
   533  				for i, hash := range req.Hashes {
   534  					if i != 0 && !task.waitOrStop() {
   535  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   536  						return
   537  					}
   538  					if bytes >= softResponseLimit {
   539  						break
   540  					}
   541  					// Retrieve the requested block's receipts, skipping if unknown to us
   542  					results := h.blockchain.GetReceiptsByHash(hash)
   543  					if results == nil {
   544  						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
   545  							atomic.AddUint32(&p.invalidCount, 1)
   546  							continue
   547  						}
   548  					}
   549  					// If known, encode and queue for response packet
   550  					if encoded, err := rlp.EncodeToBytes(results); err != nil {
   551  						log.Error("Failed to encode receipt", "err", err)
   552  					} else {
   553  						receipts = append(receipts, encoded)
   554  						bytes += len(encoded)
   555  					}
   556  				}
   557  				reply := p.replyReceiptsRLP(req.ReqID, receipts)
   558  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   559  				if metrics.EnabledExpensive {
   560  					miscOutReceiptPacketsMeter.Mark(1)
   561  					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
   562  					miscServingTimeReceiptTimer.Update(time.Duration(task.servingTime))
   563  				}
   564  			}()
   565  		}
   566  
   567  	case GetProofsV2Msg:
   568  		p.Log().Trace("Received les/2 proofs request")
   569  		if metrics.EnabledExpensive {
   570  			miscInTrieProofPacketsMeter.Mark(1)
   571  			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
   572  		}
   573  		var req struct {
   574  			ReqID uint64
   575  			Reqs  []ProofReq
   576  		}
   577  		if err := msg.Decode(&req); err != nil {
   578  			clientErrorMeter.Mark(1)
   579  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   580  		}
   581  		// Gather state data until the fetch or network limits is reached
   582  		var (
   583  			lastBHash common.Hash
   584  			root      common.Hash
   585  		)
   586  		reqCnt := len(req.Reqs)
   587  		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
   588  			wg.Add(1)
   589  			go func() {
   590  				defer wg.Done()
   591  				nodes := light.NewNodeSet()
   592  
   593  				for i, request := range req.Reqs {
   594  					if i != 0 && !task.waitOrStop() {
   595  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   596  						return
   597  					}
   598  					// Look up the root hash belonging to the request
   599  					var (
   600  						header *types.Header
   601  						trie   state.Trie
   602  					)
   603  					if request.BHash != lastBHash {
   604  						root, lastBHash = common.Hash{}, request.BHash
   605  
   606  						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
   607  							p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash)
   608  							atomic.AddUint32(&p.invalidCount, 1)
   609  							continue
   610  						}
   611  						// Refuse to search stale state data in the database since looking for
   612  						// a non-exist key is kind of expensive.
   613  						local := h.blockchain.CurrentHeader().Number.Uint64()
   614  						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   615  							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
   616  							atomic.AddUint32(&p.invalidCount, 1)
   617  							continue
   618  						}
   619  						root = header.Root
   620  					}
   621  					// If a header lookup failed (non existent), ignore subsequent requests for the same header
   622  					if root == (common.Hash{}) {
   623  						atomic.AddUint32(&p.invalidCount, 1)
   624  						continue
   625  					}
   626  					// Open the account or storage trie for the request
   627  					statedb := h.blockchain.StateCache()
   628  
   629  					switch len(request.AccKey) {
   630  					case 0:
   631  						// No account key specified, open an account trie
   632  						trie, err = statedb.OpenTrie(root)
   633  						if trie == nil || err != nil {
   634  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
   635  							continue
   636  						}
   637  					default:
   638  						// Account key specified, open a storage trie
   639  						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
   640  						if err != nil {
   641  							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   642  							atomic.AddUint32(&p.invalidCount, 1)
   643  							continue
   644  						}
   645  						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
   646  						if trie == nil || err != nil {
   647  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
   648  							continue
   649  						}
   650  					}
   651  					// Prove the user's request from the account or stroage trie
   652  					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
   653  						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
   654  						continue
   655  					}
   656  					if nodes.DataSize() >= softResponseLimit {
   657  						break
   658  					}
   659  				}
   660  				reply := p.replyProofsV2(req.ReqID, nodes.NodeList())
   661  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   662  				if metrics.EnabledExpensive {
   663  					miscOutTrieProofPacketsMeter.Mark(1)
   664  					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
   665  					miscServingTimeTrieProofTimer.Update(time.Duration(task.servingTime))
   666  				}
   667  			}()
   668  		}
   669  
   670  	case GetHelperTrieProofsMsg:
   671  		p.Log().Trace("Received helper trie proof request")
   672  		if metrics.EnabledExpensive {
   673  			miscInHelperTriePacketsMeter.Mark(1)
   674  			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
   675  		}
   676  		var req struct {
   677  			ReqID uint64
   678  			Reqs  []HelperTrieReq
   679  		}
   680  		if err := msg.Decode(&req); err != nil {
   681  			clientErrorMeter.Mark(1)
   682  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   683  		}
   684  		// Gather state data until the fetch or network limits is reached
   685  		var (
   686  			auxBytes int
   687  			auxData  [][]byte
   688  		)
   689  		reqCnt := len(req.Reqs)
   690  		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
   691  			wg.Add(1)
   692  			go func() {
   693  				defer wg.Done()
   694  				var (
   695  					lastIdx  uint64
   696  					lastType uint
   697  					root     common.Hash
   698  					auxTrie  *trie.Trie
   699  				)
   700  				nodes := light.NewNodeSet()
   701  				for i, request := range req.Reqs {
   702  					if i != 0 && !task.waitOrStop() {
   703  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   704  						return
   705  					}
   706  					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
   707  						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
   708  
   709  						var prefix string
   710  						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
   711  							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
   712  						}
   713  					}
   714  					if request.AuxReq == auxRoot {
   715  						var data []byte
   716  						if root != (common.Hash{}) {
   717  							data = root[:]
   718  						}
   719  						auxData = append(auxData, data)
   720  						auxBytes += len(data)
   721  					} else {
   722  						if auxTrie != nil {
   723  							auxTrie.Prove(request.Key, request.FromLevel, nodes)
   724  						}
   725  						if request.AuxReq != 0 {
   726  							data := h.getAuxiliaryHeaders(request)
   727  							auxData = append(auxData, data)
   728  							auxBytes += len(data)
   729  						}
   730  					}
   731  					if nodes.DataSize()+auxBytes >= softResponseLimit {
   732  						break
   733  					}
   734  				}
   735  				reply := p.replyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
   736  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   737  				if metrics.EnabledExpensive {
   738  					miscOutHelperTriePacketsMeter.Mark(1)
   739  					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
   740  					miscServingTimeHelperTrieTimer.Update(time.Duration(task.servingTime))
   741  				}
   742  			}()
   743  		}
   744  
   745  	case SendTxV2Msg:
   746  		p.Log().Trace("Received new transactions")
   747  		if metrics.EnabledExpensive {
   748  			miscInTxsPacketsMeter.Mark(1)
   749  			miscInTxsTrafficMeter.Mark(int64(msg.Size))
   750  		}
   751  		var req struct {
   752  			ReqID uint64
   753  			Txs   []*types.Transaction
   754  		}
   755  		if err := msg.Decode(&req); err != nil {
   756  			clientErrorMeter.Mark(1)
   757  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   758  		}
   759  		reqCnt := len(req.Txs)
   760  		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
   761  			wg.Add(1)
   762  			go func() {
   763  				defer wg.Done()
   764  				stats := make([]light.TxStatus, len(req.Txs))
   765  				for i, tx := range req.Txs {
   766  					if i != 0 && !task.waitOrStop() {
   767  						return
   768  					}
   769  					hash := tx.Hash()
   770  					stats[i] = h.txStatus(hash)
   771  					if stats[i].Status == core.TxStatusUnknown {
   772  						addFn := h.txpool.AddRemotes
   773  						// Add txs synchronously for testing purpose
   774  						if h.addTxsSync {
   775  							addFn = h.txpool.AddRemotesSync
   776  						}
   777  						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
   778  							stats[i].Error = errs[0].Error()
   779  							continue
   780  						}
   781  						stats[i] = h.txStatus(hash)
   782  					}
   783  				}
   784  				reply := p.replyTxStatus(req.ReqID, stats)
   785  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   786  				if metrics.EnabledExpensive {
   787  					miscOutTxsPacketsMeter.Mark(1)
   788  					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
   789  					miscServingTimeTxTimer.Update(time.Duration(task.servingTime))
   790  				}
   791  			}()
   792  		}
   793  
   794  	case GetTxStatusMsg:
   795  		p.Log().Trace("Received transaction status query request")
   796  		if metrics.EnabledExpensive {
   797  			miscInTxStatusPacketsMeter.Mark(1)
   798  			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
   799  		}
   800  		var req struct {
   801  			ReqID  uint64
   802  			Hashes []common.Hash
   803  		}
   804  		if err := msg.Decode(&req); err != nil {
   805  			clientErrorMeter.Mark(1)
   806  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   807  		}
   808  		reqCnt := len(req.Hashes)
   809  		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
   810  			wg.Add(1)
   811  			go func() {
   812  				defer wg.Done()
   813  				stats := make([]light.TxStatus, len(req.Hashes))
   814  				for i, hash := range req.Hashes {
   815  					if i != 0 && !task.waitOrStop() {
   816  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   817  						return
   818  					}
   819  					stats[i] = h.txStatus(hash)
   820  				}
   821  				reply := p.replyTxStatus(req.ReqID, stats)
   822  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   823  				if metrics.EnabledExpensive {
   824  					miscOutTxStatusPacketsMeter.Mark(1)
   825  					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
   826  					miscServingTimeTxStatusTimer.Update(time.Duration(task.servingTime))
   827  				}
   828  			}()
   829  		}
   830  
   831  	default:
   832  		p.Log().Trace("Received invalid message", "code", msg.Code)
   833  		clientErrorMeter.Mark(1)
   834  		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
   835  	}
   836  	// If the client has made too much invalid request(e.g. request a non-exist data),
   837  	// reject them to prevent SPAM attack.
   838  	if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
   839  		clientErrorMeter.Mark(1)
   840  		return errTooManyInvalidRequest
   841  	}
   842  	return nil
   843  }
   844  
   845  // getAccount retrieves an account from the state based on root.
   846  func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
   847  	trie, err := trie.New(root, triedb)
   848  	if err != nil {
   849  		return state.Account{}, err
   850  	}
   851  	blob, err := trie.TryGet(hash[:])
   852  	if err != nil {
   853  		return state.Account{}, err
   854  	}
   855  	var account state.Account
   856  	if err = rlp.DecodeBytes(blob, &account); err != nil {
   857  		return state.Account{}, err
   858  	}
   859  	return account, nil
   860  }
   861  
   862  // getHelperTrie returns the post-processed trie root for the given trie ID and section index
   863  func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
   864  	switch typ {
   865  	case htCanonical:
   866  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
   867  		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
   868  	case htBloomBits:
   869  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
   870  		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
   871  	}
   872  	return common.Hash{}, ""
   873  }
   874  
   875  // getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
   876  func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
   877  	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
   878  		blockNum := binary.BigEndian.Uint64(req.Key)
   879  		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
   880  		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
   881  	}
   882  	return nil
   883  }
   884  
   885  // txStatus returns the status of a specified transaction.
   886  func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
   887  	var stat light.TxStatus
   888  	// Looking the transaction in txpool first.
   889  	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
   890  
   891  	// If the transaction is unknown to the pool, try looking it up locally.
   892  	if stat.Status == core.TxStatusUnknown {
   893  		lookup := h.blockchain.GetTransactionLookup(hash)
   894  		if lookup != nil {
   895  			stat.Status = core.TxStatusIncluded
   896  			stat.Lookup = lookup
   897  		}
   898  	}
   899  	return stat
   900  }
   901  
   902  // broadcastHeaders broadcasts new block information to all connected light
   903  // clients. According to the agreement between client and server, server should
   904  // only broadcast new announcement if the total difficulty is higher than the
   905  // last one. Besides server will add the signature if client requires.
   906  func (h *serverHandler) broadcastHeaders() {
   907  	defer h.wg.Done()
   908  
   909  	headCh := make(chan core.ChainHeadEvent, 10)
   910  	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
   911  	defer headSub.Unsubscribe()
   912  
   913  	var (
   914  		lastHead *types.Header
   915  		lastTd   = common.Big0
   916  	)
   917  	for {
   918  		select {
   919  		case ev := <-headCh:
   920  			peers := h.server.peers.allPeers()
   921  			if len(peers) == 0 {
   922  				continue
   923  			}
   924  			header := ev.Block.Header()
   925  			hash, number := header.Hash(), header.Number.Uint64()
   926  			td := h.blockchain.GetTd(hash, number)
   927  			if td == nil || td.Cmp(lastTd) <= 0 {
   928  				continue
   929  			}
   930  			var reorg uint64
   931  			if lastHead != nil {
   932  				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
   933  			}
   934  			lastHead, lastTd = header, td
   935  
   936  			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
   937  			var (
   938  				signed         bool
   939  				signedAnnounce announceData
   940  			)
   941  			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
   942  			for _, p := range peers {
   943  				p := p
   944  				switch p.announceType {
   945  				case announceTypeSimple:
   946  					if !p.queueSend(func() { p.sendAnnounce(announce) }) {
   947  						log.Debug("Drop announcement because queue is full", "number", number, "hash", hash)
   948  					}
   949  				case announceTypeSigned:
   950  					if !signed {
   951  						signedAnnounce = announce
   952  						signedAnnounce.sign(h.server.privateKey)
   953  						signed = true
   954  					}
   955  					if !p.queueSend(func() { p.sendAnnounce(signedAnnounce) }) {
   956  						log.Debug("Drop announcement because queue is full", "number", number, "hash", hash)
   957  					}
   958  				}
   959  			}
   960  		case <-h.closeCh:
   961  			return
   962  		}
   963  	}
   964  }