github.com/cryptogateway/go-paymex@v0.0.0-20210204174735-96277fb1e602/les/server_handler.go (about)

     1  // Copyright 2019 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package les
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"encoding/binary"
    22  	"encoding/json"
    23  	"errors"
    24  	"sync"
    25  	"sync/atomic"
    26  	"time"
    27  
    28  	"github.com/cryptogateway/go-paymex/common"
    29  	"github.com/cryptogateway/go-paymex/common/mclock"
    30  	"github.com/cryptogateway/go-paymex/core"
    31  	"github.com/cryptogateway/go-paymex/core/forkid"
    32  	"github.com/cryptogateway/go-paymex/core/rawdb"
    33  	"github.com/cryptogateway/go-paymex/core/state"
    34  	"github.com/cryptogateway/go-paymex/core/types"
    35  	"github.com/cryptogateway/go-paymex/ethdb"
    36  	lps "github.com/cryptogateway/go-paymex/les/lespay/server"
    37  	"github.com/cryptogateway/go-paymex/light"
    38  	"github.com/cryptogateway/go-paymex/log"
    39  	"github.com/cryptogateway/go-paymex/metrics"
    40  	"github.com/cryptogateway/go-paymex/p2p"
    41  	"github.com/cryptogateway/go-paymex/p2p/enode"
    42  	"github.com/cryptogateway/go-paymex/p2p/nodestate"
    43  	"github.com/cryptogateway/go-paymex/rlp"
    44  	"github.com/cryptogateway/go-paymex/trie"
    45  )
    46  
    47  const (
    48  	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
    49  	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
    50  	ethVersion        = 64              // equivalent eth version for the downloader
    51  
    52  	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
    53  	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
    54  	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
    55  	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
    56  	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
    57  	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
    58  	MaxTxSend                = 64  // Amount of transactions to be send per request
    59  	MaxTxStatus              = 256 // Amount of transactions to queried per request
    60  )
    61  
    62  var (
    63  	errTooManyInvalidRequest = errors.New("too many invalid requests made")
    64  	errFullClientPool        = errors.New("client pool is full")
    65  )
    66  
    67  // serverHandler is responsible for serving light client and process
    68  // all incoming light requests.
    69  type serverHandler struct {
    70  	forkFilter forkid.Filter
    71  	blockchain *core.BlockChain
    72  	chainDb    ethdb.Database
    73  	txpool     *core.TxPool
    74  	server     *LesServer
    75  
    76  	closeCh chan struct{}  // Channel used to exit all background routines of handler.
    77  	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
    78  	synced  func() bool    // Callback function used to determine whether local node is synced.
    79  
    80  	// Testing fields
    81  	addTxsSync bool
    82  }
    83  
    84  func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
    85  	handler := &serverHandler{
    86  		forkFilter: forkid.NewFilter(blockchain),
    87  		server:     server,
    88  		blockchain: blockchain,
    89  		chainDb:    chainDb,
    90  		txpool:     txpool,
    91  		closeCh:    make(chan struct{}),
    92  		synced:     synced,
    93  	}
    94  	return handler
    95  }
    96  
    97  // start starts the server handler.
    98  func (h *serverHandler) start() {
    99  	h.wg.Add(1)
   100  	go h.broadcastLoop()
   101  }
   102  
   103  // stop stops the server handler.
   104  func (h *serverHandler) stop() {
   105  	close(h.closeCh)
   106  	h.wg.Wait()
   107  }
   108  
   109  // runPeer is the p2p protocol run function for the given version.
   110  func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
   111  	peer := newClientPeer(int(version), h.server.config.NetworkId, p, newMeteredMsgWriter(rw, int(version)))
   112  	defer peer.close()
   113  	h.wg.Add(1)
   114  	defer h.wg.Done()
   115  	return h.handle(peer)
   116  }
   117  
   118  func (h *serverHandler) handle(p *clientPeer) error {
   119  	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
   120  
   121  	// Execute the LES handshake
   122  	var (
   123  		head   = h.blockchain.CurrentHeader()
   124  		hash   = head.Hash()
   125  		number = head.Number.Uint64()
   126  		td     = h.blockchain.GetTd(hash, number)
   127  		forkID = forkid.NewID(h.blockchain.Config(), h.blockchain.Genesis().Hash(), h.blockchain.CurrentBlock().NumberU64())
   128  	)
   129  	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), forkID, h.forkFilter, h.server); err != nil {
   130  		p.Log().Debug("Light Ethereum handshake failed", "err", err)
   131  		return err
   132  	}
   133  	// Reject the duplicated peer, otherwise register it to peerset.
   134  	var registered bool
   135  	if err := h.server.ns.Operation(func() {
   136  		if h.server.ns.GetField(p.Node(), clientPeerField) != nil {
   137  			registered = true
   138  		} else {
   139  			h.server.ns.SetFieldSub(p.Node(), clientPeerField, p)
   140  		}
   141  	}); err != nil {
   142  		return err
   143  	}
   144  	if registered {
   145  		return errAlreadyRegistered
   146  	}
   147  
   148  	defer func() {
   149  		h.server.ns.SetField(p.Node(), clientPeerField, nil)
   150  		if p.fcClient != nil { // is nil when connecting another server
   151  			p.fcClient.Disconnect()
   152  		}
   153  	}()
   154  	if p.server {
   155  		// connected to another server, no messages expected, just wait for disconnection
   156  		_, err := p.rw.ReadMsg()
   157  		return err
   158  	}
   159  	// Reject light clients if server is not synced.
   160  	//
   161  	// Put this checking here, so that "non-synced" les-server peers are still allowed
   162  	// to keep the connection.
   163  	if !h.synced() {
   164  		p.Log().Debug("Light server not synced, rejecting peer")
   165  		return p2p.DiscRequested
   166  	}
   167  	// Disconnect the inbound peer if it's rejected by clientPool
   168  	if cap, err := h.server.clientPool.connect(p); cap != p.fcParams.MinRecharge || err != nil {
   169  		p.Log().Debug("Light Ethereum peer rejected", "err", errFullClientPool)
   170  		return errFullClientPool
   171  	}
   172  	p.balance, _ = h.server.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance)
   173  	if p.balance == nil {
   174  		return p2p.DiscRequested
   175  	}
   176  	activeCount, _ := h.server.clientPool.pp.Active()
   177  	clientConnectionGauge.Update(int64(activeCount))
   178  
   179  	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
   180  
   181  	connectedAt := mclock.Now()
   182  	defer func() {
   183  		wg.Wait() // Ensure all background task routines have exited.
   184  		h.server.clientPool.disconnect(p)
   185  		p.balance = nil
   186  		activeCount, _ := h.server.clientPool.pp.Active()
   187  		clientConnectionGauge.Update(int64(activeCount))
   188  		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
   189  	}()
   190  	// Mark the peer starts to be served.
   191  	atomic.StoreUint32(&p.serving, 1)
   192  	defer atomic.StoreUint32(&p.serving, 0)
   193  
   194  	// Spawn a main loop to handle all incoming messages.
   195  	for {
   196  		select {
   197  		case err := <-p.errCh:
   198  			p.Log().Debug("Failed to send light ethereum response", "err", err)
   199  			return err
   200  		default:
   201  		}
   202  		if err := h.handleMsg(p, &wg); err != nil {
   203  			p.Log().Debug("Light Ethereum message handling failed", "err", err)
   204  			return err
   205  		}
   206  	}
   207  }
   208  
   209  // handleMsg is invoked whenever an inbound message is received from a remote
   210  // peer. The remote connection is torn down upon returning any error.
   211  func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
   212  	// Read the next message from the remote peer, and ensure it's fully consumed
   213  	msg, err := p.rw.ReadMsg()
   214  	if err != nil {
   215  		return err
   216  	}
   217  	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
   218  
   219  	// Discard large message which exceeds the limitation.
   220  	if msg.Size > ProtocolMaxMsgSize {
   221  		clientErrorMeter.Mark(1)
   222  		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
   223  	}
   224  	defer msg.Discard()
   225  
   226  	var (
   227  		maxCost uint64
   228  		task    *servingTask
   229  	)
   230  	p.responseCount++
   231  	responseCount := p.responseCount
   232  	// accept returns an indicator whether the request can be served.
   233  	// If so, deduct the max cost from the flow control buffer.
   234  	accept := func(reqID, reqCnt, maxCnt uint64) bool {
   235  		// Short circuit if the peer is already frozen or the request is invalid.
   236  		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
   237  		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
   238  			p.fcClient.OneTimeCost(inSizeCost)
   239  			return false
   240  		}
   241  		// Prepaid max cost units before request been serving.
   242  		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
   243  		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
   244  		if !accepted {
   245  			p.freeze()
   246  			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
   247  			p.fcClient.OneTimeCost(inSizeCost)
   248  			return false
   249  		}
   250  		// Create a multi-stage task, estimate the time it takes for the task to
   251  		// execute, and cache it in the request service queue.
   252  		factor := h.server.costTracker.globalFactor()
   253  		if factor < 0.001 {
   254  			factor = 1
   255  			p.Log().Error("Invalid global cost factor", "factor", factor)
   256  		}
   257  		maxTime := uint64(float64(maxCost) / factor)
   258  		task = h.server.servingQueue.newTask(p, maxTime, priority)
   259  		if task.start() {
   260  			return true
   261  		}
   262  		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
   263  		return false
   264  	}
   265  	// sendResponse sends back the response and updates the flow control statistic.
   266  	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
   267  		p.responseLock.Lock()
   268  		defer p.responseLock.Unlock()
   269  
   270  		// Short circuit if the client is already frozen.
   271  		if p.isFrozen() {
   272  			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
   273  			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   274  			return
   275  		}
   276  		// Positive correction buffer value with real cost.
   277  		var replySize uint32
   278  		if reply != nil {
   279  			replySize = reply.size()
   280  		}
   281  		var realCost uint64
   282  		if h.server.costTracker.testing {
   283  			realCost = maxCost // Assign a fake cost for testing purpose
   284  		} else {
   285  			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
   286  			if realCost > maxCost {
   287  				realCost = maxCost
   288  			}
   289  		}
   290  		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   291  		if amount != 0 {
   292  			// Feed cost tracker request serving statistic.
   293  			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
   294  			// Reduce priority "balance" for the specific peer.
   295  			p.balance.RequestServed(realCost)
   296  		}
   297  		if reply != nil {
   298  			p.queueSend(func() {
   299  				if err := reply.send(bv); err != nil {
   300  					select {
   301  					case p.errCh <- err:
   302  					default:
   303  					}
   304  				}
   305  			})
   306  		}
   307  	}
   308  	switch msg.Code {
   309  	case GetBlockHeadersMsg:
   310  		p.Log().Trace("Received block header request")
   311  		if metrics.EnabledExpensive {
   312  			miscInHeaderPacketsMeter.Mark(1)
   313  			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
   314  		}
   315  		var req struct {
   316  			ReqID uint64
   317  			Query getBlockHeadersData
   318  		}
   319  		if err := msg.Decode(&req); err != nil {
   320  			clientErrorMeter.Mark(1)
   321  			return errResp(ErrDecode, "%v: %v", msg, err)
   322  		}
   323  		query := req.Query
   324  		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
   325  			wg.Add(1)
   326  			go func() {
   327  				defer wg.Done()
   328  				hashMode := query.Origin.Hash != (common.Hash{})
   329  				first := true
   330  				maxNonCanonical := uint64(100)
   331  
   332  				// Gather headers until the fetch or network limits is reached
   333  				var (
   334  					bytes   common.StorageSize
   335  					headers []*types.Header
   336  					unknown bool
   337  				)
   338  				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
   339  					if !first && !task.waitOrStop() {
   340  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   341  						return
   342  					}
   343  					// Retrieve the next header satisfying the query
   344  					var origin *types.Header
   345  					if hashMode {
   346  						if first {
   347  							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
   348  							if origin != nil {
   349  								query.Origin.Number = origin.Number.Uint64()
   350  							}
   351  						} else {
   352  							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
   353  						}
   354  					} else {
   355  						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
   356  					}
   357  					if origin == nil {
   358  						break
   359  					}
   360  					headers = append(headers, origin)
   361  					bytes += estHeaderRlpSize
   362  
   363  					// Advance to the next header of the query
   364  					switch {
   365  					case hashMode && query.Reverse:
   366  						// Hash based traversal towards the genesis block
   367  						ancestor := query.Skip + 1
   368  						if ancestor == 0 {
   369  							unknown = true
   370  						} else {
   371  							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
   372  							unknown = query.Origin.Hash == common.Hash{}
   373  						}
   374  					case hashMode && !query.Reverse:
   375  						// Hash based traversal towards the leaf block
   376  						var (
   377  							current = origin.Number.Uint64()
   378  							next    = current + query.Skip + 1
   379  						)
   380  						if next <= current {
   381  							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
   382  							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
   383  							unknown = true
   384  						} else {
   385  							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
   386  								nextHash := header.Hash()
   387  								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
   388  								if expOldHash == query.Origin.Hash {
   389  									query.Origin.Hash, query.Origin.Number = nextHash, next
   390  								} else {
   391  									unknown = true
   392  								}
   393  							} else {
   394  								unknown = true
   395  							}
   396  						}
   397  					case query.Reverse:
   398  						// Number based traversal towards the genesis block
   399  						if query.Origin.Number >= query.Skip+1 {
   400  							query.Origin.Number -= query.Skip + 1
   401  						} else {
   402  							unknown = true
   403  						}
   404  
   405  					case !query.Reverse:
   406  						// Number based traversal towards the leaf block
   407  						query.Origin.Number += query.Skip + 1
   408  					}
   409  					first = false
   410  				}
   411  				reply := p.replyBlockHeaders(req.ReqID, headers)
   412  				sendResponse(req.ReqID, query.Amount, reply, task.done())
   413  				if metrics.EnabledExpensive {
   414  					miscOutHeaderPacketsMeter.Mark(1)
   415  					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
   416  					miscServingTimeHeaderTimer.Update(time.Duration(task.servingTime))
   417  				}
   418  			}()
   419  		}
   420  
   421  	case GetBlockBodiesMsg:
   422  		p.Log().Trace("Received block bodies request")
   423  		if metrics.EnabledExpensive {
   424  			miscInBodyPacketsMeter.Mark(1)
   425  			miscInBodyTrafficMeter.Mark(int64(msg.Size))
   426  		}
   427  		var req struct {
   428  			ReqID  uint64
   429  			Hashes []common.Hash
   430  		}
   431  		if err := msg.Decode(&req); err != nil {
   432  			clientErrorMeter.Mark(1)
   433  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   434  		}
   435  		var (
   436  			bytes  int
   437  			bodies []rlp.RawValue
   438  		)
   439  		reqCnt := len(req.Hashes)
   440  		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
   441  			wg.Add(1)
   442  			go func() {
   443  				defer wg.Done()
   444  				for i, hash := range req.Hashes {
   445  					if i != 0 && !task.waitOrStop() {
   446  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   447  						return
   448  					}
   449  					if bytes >= softResponseLimit {
   450  						break
   451  					}
   452  					body := h.blockchain.GetBodyRLP(hash)
   453  					if body == nil {
   454  						p.bumpInvalid()
   455  						continue
   456  					}
   457  					bodies = append(bodies, body)
   458  					bytes += len(body)
   459  				}
   460  				reply := p.replyBlockBodiesRLP(req.ReqID, bodies)
   461  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   462  				if metrics.EnabledExpensive {
   463  					miscOutBodyPacketsMeter.Mark(1)
   464  					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
   465  					miscServingTimeBodyTimer.Update(time.Duration(task.servingTime))
   466  				}
   467  			}()
   468  		}
   469  
   470  	case GetCodeMsg:
   471  		p.Log().Trace("Received code request")
   472  		if metrics.EnabledExpensive {
   473  			miscInCodePacketsMeter.Mark(1)
   474  			miscInCodeTrafficMeter.Mark(int64(msg.Size))
   475  		}
   476  		var req struct {
   477  			ReqID uint64
   478  			Reqs  []CodeReq
   479  		}
   480  		if err := msg.Decode(&req); err != nil {
   481  			clientErrorMeter.Mark(1)
   482  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   483  		}
   484  		var (
   485  			bytes int
   486  			data  [][]byte
   487  		)
   488  		reqCnt := len(req.Reqs)
   489  		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
   490  			wg.Add(1)
   491  			go func() {
   492  				defer wg.Done()
   493  				for i, request := range req.Reqs {
   494  					if i != 0 && !task.waitOrStop() {
   495  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   496  						return
   497  					}
   498  					// Look up the root hash belonging to the request
   499  					header := h.blockchain.GetHeaderByHash(request.BHash)
   500  					if header == nil {
   501  						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
   502  						p.bumpInvalid()
   503  						continue
   504  					}
   505  					// Refuse to search stale state data in the database since looking for
   506  					// a non-exist key is kind of expensive.
   507  					local := h.blockchain.CurrentHeader().Number.Uint64()
   508  					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   509  						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
   510  						p.bumpInvalid()
   511  						continue
   512  					}
   513  					triedb := h.blockchain.StateCache().TrieDB()
   514  
   515  					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
   516  					if err != nil {
   517  						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   518  						p.bumpInvalid()
   519  						continue
   520  					}
   521  					code, err := h.blockchain.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash))
   522  					if err != nil {
   523  						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
   524  						continue
   525  					}
   526  					// Accumulate the code and abort if enough data was retrieved
   527  					data = append(data, code)
   528  					if bytes += len(code); bytes >= softResponseLimit {
   529  						break
   530  					}
   531  				}
   532  				reply := p.replyCode(req.ReqID, data)
   533  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   534  				if metrics.EnabledExpensive {
   535  					miscOutCodePacketsMeter.Mark(1)
   536  					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
   537  					miscServingTimeCodeTimer.Update(time.Duration(task.servingTime))
   538  				}
   539  			}()
   540  		}
   541  
   542  	case GetReceiptsMsg:
   543  		p.Log().Trace("Received receipts request")
   544  		if metrics.EnabledExpensive {
   545  			miscInReceiptPacketsMeter.Mark(1)
   546  			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
   547  		}
   548  		var req struct {
   549  			ReqID  uint64
   550  			Hashes []common.Hash
   551  		}
   552  		if err := msg.Decode(&req); err != nil {
   553  			clientErrorMeter.Mark(1)
   554  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   555  		}
   556  		var (
   557  			bytes    int
   558  			receipts []rlp.RawValue
   559  		)
   560  		reqCnt := len(req.Hashes)
   561  		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
   562  			wg.Add(1)
   563  			go func() {
   564  				defer wg.Done()
   565  				for i, hash := range req.Hashes {
   566  					if i != 0 && !task.waitOrStop() {
   567  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   568  						return
   569  					}
   570  					if bytes >= softResponseLimit {
   571  						break
   572  					}
   573  					// Retrieve the requested block's receipts, skipping if unknown to us
   574  					results := h.blockchain.GetReceiptsByHash(hash)
   575  					if results == nil {
   576  						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
   577  							p.bumpInvalid()
   578  							continue
   579  						}
   580  					}
   581  					// If known, encode and queue for response packet
   582  					if encoded, err := rlp.EncodeToBytes(results); err != nil {
   583  						log.Error("Failed to encode receipt", "err", err)
   584  					} else {
   585  						receipts = append(receipts, encoded)
   586  						bytes += len(encoded)
   587  					}
   588  				}
   589  				reply := p.replyReceiptsRLP(req.ReqID, receipts)
   590  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   591  				if metrics.EnabledExpensive {
   592  					miscOutReceiptPacketsMeter.Mark(1)
   593  					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
   594  					miscServingTimeReceiptTimer.Update(time.Duration(task.servingTime))
   595  				}
   596  			}()
   597  		}
   598  
   599  	case GetProofsV2Msg:
   600  		p.Log().Trace("Received les/2 proofs request")
   601  		if metrics.EnabledExpensive {
   602  			miscInTrieProofPacketsMeter.Mark(1)
   603  			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
   604  		}
   605  		var req struct {
   606  			ReqID uint64
   607  			Reqs  []ProofReq
   608  		}
   609  		if err := msg.Decode(&req); err != nil {
   610  			clientErrorMeter.Mark(1)
   611  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   612  		}
   613  		// Gather state data until the fetch or network limits is reached
   614  		var (
   615  			lastBHash common.Hash
   616  			root      common.Hash
   617  			header    *types.Header
   618  		)
   619  		reqCnt := len(req.Reqs)
   620  		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
   621  			wg.Add(1)
   622  			go func() {
   623  				defer wg.Done()
   624  				nodes := light.NewNodeSet()
   625  
   626  				for i, request := range req.Reqs {
   627  					if i != 0 && !task.waitOrStop() {
   628  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   629  						return
   630  					}
   631  					// Look up the root hash belonging to the request
   632  					if request.BHash != lastBHash {
   633  						root, lastBHash = common.Hash{}, request.BHash
   634  
   635  						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
   636  							p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash)
   637  							p.bumpInvalid()
   638  							continue
   639  						}
   640  						// Refuse to search stale state data in the database since looking for
   641  						// a non-exist key is kind of expensive.
   642  						local := h.blockchain.CurrentHeader().Number.Uint64()
   643  						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   644  							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
   645  							p.bumpInvalid()
   646  							continue
   647  						}
   648  						root = header.Root
   649  					}
   650  					// If a header lookup failed (non existent), ignore subsequent requests for the same header
   651  					if root == (common.Hash{}) {
   652  						p.bumpInvalid()
   653  						continue
   654  					}
   655  					// Open the account or storage trie for the request
   656  					statedb := h.blockchain.StateCache()
   657  
   658  					var trie state.Trie
   659  					switch len(request.AccKey) {
   660  					case 0:
   661  						// No account key specified, open an account trie
   662  						trie, err = statedb.OpenTrie(root)
   663  						if trie == nil || err != nil {
   664  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
   665  							continue
   666  						}
   667  					default:
   668  						// Account key specified, open a storage trie
   669  						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
   670  						if err != nil {
   671  							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   672  							p.bumpInvalid()
   673  							continue
   674  						}
   675  						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
   676  						if trie == nil || err != nil {
   677  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
   678  							continue
   679  						}
   680  					}
   681  					// Prove the user's request from the account or stroage trie
   682  					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
   683  						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
   684  						continue
   685  					}
   686  					if nodes.DataSize() >= softResponseLimit {
   687  						break
   688  					}
   689  				}
   690  				reply := p.replyProofsV2(req.ReqID, nodes.NodeList())
   691  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   692  				if metrics.EnabledExpensive {
   693  					miscOutTrieProofPacketsMeter.Mark(1)
   694  					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
   695  					miscServingTimeTrieProofTimer.Update(time.Duration(task.servingTime))
   696  				}
   697  			}()
   698  		}
   699  
   700  	case GetHelperTrieProofsMsg:
   701  		p.Log().Trace("Received helper trie proof request")
   702  		if metrics.EnabledExpensive {
   703  			miscInHelperTriePacketsMeter.Mark(1)
   704  			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
   705  		}
   706  		var req struct {
   707  			ReqID uint64
   708  			Reqs  []HelperTrieReq
   709  		}
   710  		if err := msg.Decode(&req); err != nil {
   711  			clientErrorMeter.Mark(1)
   712  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   713  		}
   714  		// Gather state data until the fetch or network limits is reached
   715  		var (
   716  			auxBytes int
   717  			auxData  [][]byte
   718  		)
   719  		reqCnt := len(req.Reqs)
   720  		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
   721  			wg.Add(1)
   722  			go func() {
   723  				defer wg.Done()
   724  				var (
   725  					lastIdx  uint64
   726  					lastType uint
   727  					root     common.Hash
   728  					auxTrie  *trie.Trie
   729  				)
   730  				nodes := light.NewNodeSet()
   731  				for i, request := range req.Reqs {
   732  					if i != 0 && !task.waitOrStop() {
   733  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   734  						return
   735  					}
   736  					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
   737  						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
   738  
   739  						var prefix string
   740  						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
   741  							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
   742  						}
   743  					}
   744  					if auxTrie == nil {
   745  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   746  						return
   747  					}
   748  					// TODO(rjl493456442) short circuit if the proving is failed.
   749  					// The original client side code has a dirty hack to retrieve
   750  					// the headers with no valid proof. Keep the compatibility for
   751  					// legacy les protocol and drop this hack when the les2/3 are
   752  					// not supported.
   753  					err := auxTrie.Prove(request.Key, request.FromLevel, nodes)
   754  					if p.version >= lpv4 && err != nil {
   755  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   756  						return
   757  					}
   758  					if request.AuxReq == htAuxHeader {
   759  						data := h.getAuxiliaryHeaders(request)
   760  						auxData = append(auxData, data)
   761  						auxBytes += len(data)
   762  					}
   763  					if nodes.DataSize()+auxBytes >= softResponseLimit {
   764  						break
   765  					}
   766  				}
   767  				reply := p.replyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
   768  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   769  				if metrics.EnabledExpensive {
   770  					miscOutHelperTriePacketsMeter.Mark(1)
   771  					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
   772  					miscServingTimeHelperTrieTimer.Update(time.Duration(task.servingTime))
   773  				}
   774  			}()
   775  		}
   776  
   777  	case SendTxV2Msg:
   778  		p.Log().Trace("Received new transactions")
   779  		if metrics.EnabledExpensive {
   780  			miscInTxsPacketsMeter.Mark(1)
   781  			miscInTxsTrafficMeter.Mark(int64(msg.Size))
   782  		}
   783  		var req struct {
   784  			ReqID uint64
   785  			Txs   []*types.Transaction
   786  		}
   787  		if err := msg.Decode(&req); err != nil {
   788  			clientErrorMeter.Mark(1)
   789  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   790  		}
   791  		reqCnt := len(req.Txs)
   792  		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
   793  			wg.Add(1)
   794  			go func() {
   795  				defer wg.Done()
   796  				stats := make([]light.TxStatus, len(req.Txs))
   797  				for i, tx := range req.Txs {
   798  					if i != 0 && !task.waitOrStop() {
   799  						return
   800  					}
   801  					hash := tx.Hash()
   802  					stats[i] = h.txStatus(hash)
   803  					if stats[i].Status == core.TxStatusUnknown {
   804  						addFn := h.txpool.AddRemotes
   805  						// Add txs synchronously for testing purpose
   806  						if h.addTxsSync {
   807  							addFn = h.txpool.AddRemotesSync
   808  						}
   809  						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
   810  							stats[i].Error = errs[0].Error()
   811  							continue
   812  						}
   813  						stats[i] = h.txStatus(hash)
   814  					}
   815  				}
   816  				reply := p.replyTxStatus(req.ReqID, stats)
   817  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   818  				if metrics.EnabledExpensive {
   819  					miscOutTxsPacketsMeter.Mark(1)
   820  					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
   821  					miscServingTimeTxTimer.Update(time.Duration(task.servingTime))
   822  				}
   823  			}()
   824  		}
   825  
   826  	case GetTxStatusMsg:
   827  		p.Log().Trace("Received transaction status query request")
   828  		if metrics.EnabledExpensive {
   829  			miscInTxStatusPacketsMeter.Mark(1)
   830  			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
   831  		}
   832  		var req struct {
   833  			ReqID  uint64
   834  			Hashes []common.Hash
   835  		}
   836  		if err := msg.Decode(&req); err != nil {
   837  			clientErrorMeter.Mark(1)
   838  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   839  		}
   840  		reqCnt := len(req.Hashes)
   841  		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
   842  			wg.Add(1)
   843  			go func() {
   844  				defer wg.Done()
   845  				stats := make([]light.TxStatus, len(req.Hashes))
   846  				for i, hash := range req.Hashes {
   847  					if i != 0 && !task.waitOrStop() {
   848  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   849  						return
   850  					}
   851  					stats[i] = h.txStatus(hash)
   852  				}
   853  				reply := p.replyTxStatus(req.ReqID, stats)
   854  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   855  				if metrics.EnabledExpensive {
   856  					miscOutTxStatusPacketsMeter.Mark(1)
   857  					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
   858  					miscServingTimeTxStatusTimer.Update(time.Duration(task.servingTime))
   859  				}
   860  			}()
   861  		}
   862  
   863  	default:
   864  		p.Log().Trace("Received invalid message", "code", msg.Code)
   865  		clientErrorMeter.Mark(1)
   866  		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
   867  	}
   868  	// If the client has made too much invalid request(e.g. request a non-existent data),
   869  	// reject them to prevent SPAM attack.
   870  	if p.getInvalid() > maxRequestErrors {
   871  		clientErrorMeter.Mark(1)
   872  		return errTooManyInvalidRequest
   873  	}
   874  	return nil
   875  }
   876  
   877  // getAccount retrieves an account from the state based on root.
   878  func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
   879  	trie, err := trie.New(root, triedb)
   880  	if err != nil {
   881  		return state.Account{}, err
   882  	}
   883  	blob, err := trie.TryGet(hash[:])
   884  	if err != nil {
   885  		return state.Account{}, err
   886  	}
   887  	var account state.Account
   888  	if err = rlp.DecodeBytes(blob, &account); err != nil {
   889  		return state.Account{}, err
   890  	}
   891  	return account, nil
   892  }
   893  
   894  // getHelperTrie returns the post-processed trie root for the given trie ID and section index
   895  func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
   896  	switch typ {
   897  	case htCanonical:
   898  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
   899  		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
   900  	case htBloomBits:
   901  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
   902  		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
   903  	}
   904  	return common.Hash{}, ""
   905  }
   906  
   907  // getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
   908  func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
   909  	if req.Type == htCanonical && req.AuxReq == htAuxHeader && len(req.Key) == 8 {
   910  		blockNum := binary.BigEndian.Uint64(req.Key)
   911  		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
   912  		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
   913  	}
   914  	return nil
   915  }
   916  
   917  // txStatus returns the status of a specified transaction.
   918  func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
   919  	var stat light.TxStatus
   920  	// Looking the transaction in txpool first.
   921  	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
   922  
   923  	// If the transaction is unknown to the pool, try looking it up locally.
   924  	if stat.Status == core.TxStatusUnknown {
   925  		lookup := h.blockchain.GetTransactionLookup(hash)
   926  		if lookup != nil {
   927  			stat.Status = core.TxStatusIncluded
   928  			stat.Lookup = lookup
   929  		}
   930  	}
   931  	return stat
   932  }
   933  
   934  // broadcastLoop broadcasts new block information to all connected light
   935  // clients. According to the agreement between client and server, server should
   936  // only broadcast new announcement if the total difficulty is higher than the
   937  // last one. Besides server will add the signature if client requires.
   938  func (h *serverHandler) broadcastLoop() {
   939  	defer h.wg.Done()
   940  
   941  	headCh := make(chan core.ChainHeadEvent, 10)
   942  	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
   943  	defer headSub.Unsubscribe()
   944  
   945  	var (
   946  		lastHead *types.Header
   947  		lastTd   = common.Big0
   948  	)
   949  	for {
   950  		select {
   951  		case ev := <-headCh:
   952  			header := ev.Block.Header()
   953  			hash, number := header.Hash(), header.Number.Uint64()
   954  			td := h.blockchain.GetTd(hash, number)
   955  			if td == nil || td.Cmp(lastTd) <= 0 {
   956  				continue
   957  			}
   958  			var reorg uint64
   959  			if lastHead != nil {
   960  				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
   961  			}
   962  			lastHead, lastTd = header, td
   963  			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
   964  			h.server.broadcaster.broadcast(announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg})
   965  		case <-h.closeCh:
   966  			return
   967  		}
   968  	}
   969  }
   970  
   971  // broadcaster sends new header announcements to active client peers
   972  type broadcaster struct {
   973  	ns                           *nodestate.NodeStateMachine
   974  	privateKey                   *ecdsa.PrivateKey
   975  	lastAnnounce, signedAnnounce announceData
   976  }
   977  
   978  // newBroadcaster creates a new broadcaster
   979  func newBroadcaster(ns *nodestate.NodeStateMachine) *broadcaster {
   980  	b := &broadcaster{ns: ns}
   981  	ns.SubscribeState(priorityPoolSetup.ActiveFlag, func(node *enode.Node, oldState, newState nodestate.Flags) {
   982  		if newState.Equals(priorityPoolSetup.ActiveFlag) {
   983  			// send last announcement to activated peers
   984  			b.sendTo(node)
   985  		}
   986  	})
   987  	return b
   988  }
   989  
   990  // setSignerKey sets the signer key for signed announcements. Should be called before
   991  // starting the protocol handler.
   992  func (b *broadcaster) setSignerKey(privateKey *ecdsa.PrivateKey) {
   993  	b.privateKey = privateKey
   994  }
   995  
   996  // broadcast sends the given announcements to all active peers
   997  func (b *broadcaster) broadcast(announce announceData) {
   998  	b.ns.Operation(func() {
   999  		// iterate in an Operation to ensure that the active set does not change while iterating
  1000  		b.lastAnnounce = announce
  1001  		b.ns.ForEach(priorityPoolSetup.ActiveFlag, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
  1002  			b.sendTo(node)
  1003  		})
  1004  	})
  1005  }
  1006  
  1007  // sendTo sends the most recent announcement to the given node unless the same or higher Td
  1008  // announcement has already been sent.
  1009  func (b *broadcaster) sendTo(node *enode.Node) {
  1010  	if b.lastAnnounce.Td == nil {
  1011  		return
  1012  	}
  1013  	if p, _ := b.ns.GetField(node, clientPeerField).(*clientPeer); p != nil {
  1014  		if p.headInfo.Td == nil || b.lastAnnounce.Td.Cmp(p.headInfo.Td) > 0 {
  1015  			announce := b.lastAnnounce
  1016  			switch p.announceType {
  1017  			case announceTypeSimple:
  1018  				if !p.queueSend(func() { p.sendAnnounce(announce) }) {
  1019  					log.Debug("Drop announcement because queue is full", "number", announce.Number, "hash", announce.Hash)
  1020  				} else {
  1021  					log.Debug("Sent announcement", "number", announce.Number, "hash", announce.Hash)
  1022  				}
  1023  			case announceTypeSigned:
  1024  				if b.signedAnnounce.Hash != b.lastAnnounce.Hash {
  1025  					b.signedAnnounce = b.lastAnnounce
  1026  					b.signedAnnounce.sign(b.privateKey)
  1027  				}
  1028  				announce := b.signedAnnounce
  1029  				if !p.queueSend(func() { p.sendAnnounce(announce) }) {
  1030  					log.Debug("Drop announcement because queue is full", "number", announce.Number, "hash", announce.Hash)
  1031  				} else {
  1032  					log.Debug("Sent announcement", "number", announce.Number, "hash", announce.Hash)
  1033  				}
  1034  			}
  1035  			p.headInfo = blockInfo{b.lastAnnounce.Hash, b.lastAnnounce.Number, b.lastAnnounce.Td}
  1036  		}
  1037  	}
  1038  }