github.com/ethereum-optimism/optimism/l2geth@v0.0.0-20230612200230-50b04ade19e3/les/server_handler.go (about)

     1  // Copyright 2019 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package les
    18  
    19  import (
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"sync"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"github.com/ethereum-optimism/optimism/l2geth/common"
    28  	"github.com/ethereum-optimism/optimism/l2geth/common/mclock"
    29  	"github.com/ethereum-optimism/optimism/l2geth/core"
    30  	"github.com/ethereum-optimism/optimism/l2geth/core/rawdb"
    31  	"github.com/ethereum-optimism/optimism/l2geth/core/state"
    32  	"github.com/ethereum-optimism/optimism/l2geth/core/types"
    33  	"github.com/ethereum-optimism/optimism/l2geth/ethdb"
    34  	"github.com/ethereum-optimism/optimism/l2geth/light"
    35  	"github.com/ethereum-optimism/optimism/l2geth/log"
    36  	"github.com/ethereum-optimism/optimism/l2geth/metrics"
    37  	"github.com/ethereum-optimism/optimism/l2geth/p2p"
    38  	"github.com/ethereum-optimism/optimism/l2geth/rlp"
    39  	"github.com/ethereum-optimism/optimism/l2geth/trie"
    40  )
    41  
const (
	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
	ethVersion        = 63              // equivalent eth version for the downloader

	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
	MaxTxSend                = 64  // Amount of transactions to be sent per request
	MaxTxStatus              = 256 // Amount of transactions to be queried per request
)
    56  
// Errors returned by the server handler when rejecting peers or requests.
var (
	errTooManyInvalidRequest = errors.New("too many invalid requests made")
	errFullClientPool        = errors.New("client pool is full")
)
    61  
// serverHandler is responsible for serving light clients and processing
// all incoming light requests.
type serverHandler struct {
	blockchain *core.BlockChain // Canonical chain data served to light clients
	chainDb    ethdb.Database   // Database backing the blockchain and helper tries
	txpool     *core.TxPool     // Transaction pool used to relay client transactions
	server     *LesServer       // Backing LES server instance (config, peers, flow control)

	closeCh chan struct{}  // Channel used to exit all background routines of handler.
	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
	synced  func() bool    // Callback function used to determine whether local node is synced.

	// Testing fields
	addTxsSync bool // If set, received transactions are added to the pool synchronously.
}
    77  
    78  func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
    79  	handler := &serverHandler{
    80  		server:     server,
    81  		blockchain: blockchain,
    82  		chainDb:    chainDb,
    83  		txpool:     txpool,
    84  		closeCh:    make(chan struct{}),
    85  		synced:     synced,
    86  	}
    87  	return handler
    88  }
    89  
// start starts the server handler by launching the header broadcasting
// background routine, tracked by the handler's wait group.
func (h *serverHandler) start() {
	h.wg.Add(1)
	go h.broadcastHeaders()
}
    95  
// stop stops the server handler: it signals all background routines to
// exit via closeCh and blocks until they have all terminated.
func (h *serverHandler) stop() {
	close(h.closeCh)
	h.wg.Wait()
}
   101  
   102  // runPeer is the p2p protocol run function for the given version.
   103  func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
   104  	peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version)))
   105  	h.wg.Add(1)
   106  	defer h.wg.Done()
   107  	return h.handle(peer)
   108  }
   109  
// handle manages the life cycle of a single LES peer: it performs the
// protocol handshake, admits the peer into the client pool, registers it
// in the peer set and then loops serving incoming messages until the
// connection fails. All registration steps are unwound (in reverse order,
// via defers and the explicit error paths) before the method returns.
func (h *serverHandler) handle(p *peer) error {
	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())

	// Execute the LES handshake, advertising our current head and total difficulty.
	var (
		head   = h.blockchain.CurrentHeader()
		hash   = head.Hash()
		number = head.Number.Uint64()
		td     = h.blockchain.GetTd(hash, number)
	)
	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil {
		p.Log().Debug("Light Ethereum handshake failed", "err", err)
		return err
	}
	if p.server {
		// connected to another server, no messages expected, just wait for disconnection
		_, err := p.rw.ReadMsg()
		return err
	}
	// Reject light clients if server is not synced.
	if !h.synced() {
		return p2p.DiscRequested
	}
	defer p.fcClient.Disconnect()

	// Disconnect the inbound peer if it's rejected by clientPool
	if !h.server.clientPool.connect(p, 0) {
		p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
		return errFullClientPool
	}
	// Register the peer locally; on failure roll back the client pool admission.
	if err := h.server.peers.Register(p); err != nil {
		h.server.clientPool.disconnect(p)
		p.Log().Error("Light Ethereum peer registration failed", "err", err)
		return err
	}
	clientConnectionGauge.Update(int64(h.server.peers.Len()))

	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.

	connectedAt := mclock.Now()
	defer func() {
		wg.Wait() // Ensure all background task routines have exited.
		h.server.peers.Unregister(p.id)
		h.server.clientPool.disconnect(p)
		clientConnectionGauge.Update(int64(h.server.peers.Len()))
		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
	}()

	// Spawn a main loop to handle all incoming messages.
	for {
		select {
		case err := <-p.errCh:
			// An asynchronous response send failed earlier; drop the peer.
			p.Log().Debug("Failed to send light ethereum response", "err", err)
			return err
		default:
		}
		if err := h.handleMsg(p, &wg); err != nil {
			p.Log().Debug("Light Ethereum message handling failed", "err", err)
			return err
		}
	}
}
   173  
   174  // handleMsg is invoked whenever an inbound message is received from a remote
   175  // peer. The remote connection is torn down upon returning any error.
   176  func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error {
   177  	// Read the next message from the remote peer, and ensure it's fully consumed
   178  	msg, err := p.rw.ReadMsg()
   179  	if err != nil {
   180  		return err
   181  	}
   182  	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
   183  
   184  	// Discard large message which exceeds the limitation.
   185  	if msg.Size > ProtocolMaxMsgSize {
   186  		clientErrorMeter.Mark(1)
   187  		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
   188  	}
   189  	defer msg.Discard()
   190  
   191  	var (
   192  		maxCost uint64
   193  		task    *servingTask
   194  	)
   195  	p.responseCount++
   196  	responseCount := p.responseCount
   197  	// accept returns an indicator whether the request can be served.
   198  	// If so, deduct the max cost from the flow control buffer.
   199  	accept := func(reqID, reqCnt, maxCnt uint64) bool {
   200  		// Short circuit if the peer is already frozen or the request is invalid.
   201  		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
   202  		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
   203  			p.fcClient.OneTimeCost(inSizeCost)
   204  			return false
   205  		}
   206  		// Prepaid max cost units before request been serving.
   207  		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
   208  		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
   209  		if !accepted {
   210  			p.freezeClient()
   211  			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
   212  			p.fcClient.OneTimeCost(inSizeCost)
   213  			return false
   214  		}
   215  		// Create a multi-stage task, estimate the time it takes for the task to
   216  		// execute, and cache it in the request service queue.
   217  		factor := h.server.costTracker.globalFactor()
   218  		if factor < 0.001 {
   219  			factor = 1
   220  			p.Log().Error("Invalid global cost factor", "factor", factor)
   221  		}
   222  		maxTime := uint64(float64(maxCost) / factor)
   223  		task = h.server.servingQueue.newTask(p, maxTime, priority)
   224  		if task.start() {
   225  			return true
   226  		}
   227  		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
   228  		return false
   229  	}
   230  	// sendResponse sends back the response and updates the flow control statistic.
   231  	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
   232  		p.responseLock.Lock()
   233  		defer p.responseLock.Unlock()
   234  
   235  		// Short circuit if the client is already frozen.
   236  		if p.isFrozen() {
   237  			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
   238  			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   239  			return
   240  		}
   241  		// Positive correction buffer value with real cost.
   242  		var replySize uint32
   243  		if reply != nil {
   244  			replySize = reply.size()
   245  		}
   246  		var realCost uint64
   247  		if h.server.costTracker.testing {
   248  			realCost = maxCost // Assign a fake cost for testing purpose
   249  		} else {
   250  			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
   251  		}
   252  		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   253  		if amount != 0 {
   254  			// Feed cost tracker request serving statistic.
   255  			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
   256  			// Reduce priority "balance" for the specific peer.
   257  			h.server.clientPool.requestCost(p, realCost)
   258  		}
   259  		if reply != nil {
   260  			p.queueSend(func() {
   261  				if err := reply.send(bv); err != nil {
   262  					select {
   263  					case p.errCh <- err:
   264  					default:
   265  					}
   266  				}
   267  			})
   268  		}
   269  	}
   270  	switch msg.Code {
   271  	case GetBlockHeadersMsg:
   272  		p.Log().Trace("Received block header request")
   273  		if metrics.EnabledExpensive {
   274  			miscInHeaderPacketsMeter.Mark(1)
   275  			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
   276  		}
   277  		var req struct {
   278  			ReqID uint64
   279  			Query getBlockHeadersData
   280  		}
   281  		if err := msg.Decode(&req); err != nil {
   282  			clientErrorMeter.Mark(1)
   283  			return errResp(ErrDecode, "%v: %v", msg, err)
   284  		}
   285  		query := req.Query
   286  		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
   287  			wg.Add(1)
   288  			go func() {
   289  				defer wg.Done()
   290  				hashMode := query.Origin.Hash != (common.Hash{})
   291  				first := true
   292  				maxNonCanonical := uint64(100)
   293  
   294  				// Gather headers until the fetch or network limits is reached
   295  				var (
   296  					bytes   common.StorageSize
   297  					headers []*types.Header
   298  					unknown bool
   299  				)
   300  				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
   301  					if !first && !task.waitOrStop() {
   302  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   303  						return
   304  					}
   305  					// Retrieve the next header satisfying the query
   306  					var origin *types.Header
   307  					if hashMode {
   308  						if first {
   309  							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
   310  							if origin != nil {
   311  								query.Origin.Number = origin.Number.Uint64()
   312  							}
   313  						} else {
   314  							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
   315  						}
   316  					} else {
   317  						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
   318  					}
   319  					if origin == nil {
   320  						atomic.AddUint32(&p.invalidCount, 1)
   321  						break
   322  					}
   323  					headers = append(headers, origin)
   324  					bytes += estHeaderRlpSize
   325  
   326  					// Advance to the next header of the query
   327  					switch {
   328  					case hashMode && query.Reverse:
   329  						// Hash based traversal towards the genesis block
   330  						ancestor := query.Skip + 1
   331  						if ancestor == 0 {
   332  							unknown = true
   333  						} else {
   334  							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
   335  							unknown = query.Origin.Hash == common.Hash{}
   336  						}
   337  					case hashMode && !query.Reverse:
   338  						// Hash based traversal towards the leaf block
   339  						var (
   340  							current = origin.Number.Uint64()
   341  							next    = current + query.Skip + 1
   342  						)
   343  						if next <= current {
   344  							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
   345  							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
   346  							unknown = true
   347  						} else {
   348  							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
   349  								nextHash := header.Hash()
   350  								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
   351  								if expOldHash == query.Origin.Hash {
   352  									query.Origin.Hash, query.Origin.Number = nextHash, next
   353  								} else {
   354  									unknown = true
   355  								}
   356  							} else {
   357  								unknown = true
   358  							}
   359  						}
   360  					case query.Reverse:
   361  						// Number based traversal towards the genesis block
   362  						if query.Origin.Number >= query.Skip+1 {
   363  							query.Origin.Number -= query.Skip + 1
   364  						} else {
   365  							unknown = true
   366  						}
   367  
   368  					case !query.Reverse:
   369  						// Number based traversal towards the leaf block
   370  						query.Origin.Number += query.Skip + 1
   371  					}
   372  					first = false
   373  				}
   374  				reply := p.ReplyBlockHeaders(req.ReqID, headers)
   375  				sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
   376  				if metrics.EnabledExpensive {
   377  					miscOutHeaderPacketsMeter.Mark(1)
   378  					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
   379  					miscServingTimeHeaderTimer.Update(time.Duration(task.servingTime))
   380  				}
   381  			}()
   382  		}
   383  
   384  	case GetBlockBodiesMsg:
   385  		p.Log().Trace("Received block bodies request")
   386  		if metrics.EnabledExpensive {
   387  			miscInBodyPacketsMeter.Mark(1)
   388  			miscInBodyTrafficMeter.Mark(int64(msg.Size))
   389  		}
   390  		var req struct {
   391  			ReqID  uint64
   392  			Hashes []common.Hash
   393  		}
   394  		if err := msg.Decode(&req); err != nil {
   395  			clientErrorMeter.Mark(1)
   396  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   397  		}
   398  		var (
   399  			bytes  int
   400  			bodies []rlp.RawValue
   401  		)
   402  		reqCnt := len(req.Hashes)
   403  		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
   404  			wg.Add(1)
   405  			go func() {
   406  				defer wg.Done()
   407  				for i, hash := range req.Hashes {
   408  					if i != 0 && !task.waitOrStop() {
   409  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   410  						return
   411  					}
   412  					if bytes >= softResponseLimit {
   413  						break
   414  					}
   415  					body := h.blockchain.GetBodyRLP(hash)
   416  					if body == nil {
   417  						atomic.AddUint32(&p.invalidCount, 1)
   418  						continue
   419  					}
   420  					bodies = append(bodies, body)
   421  					bytes += len(body)
   422  				}
   423  				reply := p.ReplyBlockBodiesRLP(req.ReqID, bodies)
   424  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   425  				if metrics.EnabledExpensive {
   426  					miscOutBodyPacketsMeter.Mark(1)
   427  					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
   428  					miscServingTimeBodyTimer.Update(time.Duration(task.servingTime))
   429  				}
   430  			}()
   431  		}
   432  
   433  	case GetCodeMsg:
   434  		p.Log().Trace("Received code request")
   435  		if metrics.EnabledExpensive {
   436  			miscInCodePacketsMeter.Mark(1)
   437  			miscInCodeTrafficMeter.Mark(int64(msg.Size))
   438  		}
   439  		var req struct {
   440  			ReqID uint64
   441  			Reqs  []CodeReq
   442  		}
   443  		if err := msg.Decode(&req); err != nil {
   444  			clientErrorMeter.Mark(1)
   445  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   446  		}
   447  		var (
   448  			bytes int
   449  			data  [][]byte
   450  		)
   451  		reqCnt := len(req.Reqs)
   452  		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
   453  			wg.Add(1)
   454  			go func() {
   455  				defer wg.Done()
   456  				for i, request := range req.Reqs {
   457  					if i != 0 && !task.waitOrStop() {
   458  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   459  						return
   460  					}
   461  					// Look up the root hash belonging to the request
   462  					header := h.blockchain.GetHeaderByHash(request.BHash)
   463  					if header == nil {
   464  						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
   465  						atomic.AddUint32(&p.invalidCount, 1)
   466  						continue
   467  					}
   468  					// Refuse to search stale state data in the database since looking for
   469  					// a non-exist key is kind of expensive.
   470  					local := h.blockchain.CurrentHeader().Number.Uint64()
   471  					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   472  						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
   473  						atomic.AddUint32(&p.invalidCount, 1)
   474  						continue
   475  					}
   476  					triedb := h.blockchain.StateCache().TrieDB()
   477  
   478  					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
   479  					if err != nil {
   480  						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   481  						atomic.AddUint32(&p.invalidCount, 1)
   482  						continue
   483  					}
   484  					code, err := triedb.Node(common.BytesToHash(account.CodeHash))
   485  					if err != nil {
   486  						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
   487  						continue
   488  					}
   489  					// Accumulate the code and abort if enough data was retrieved
   490  					data = append(data, code)
   491  					if bytes += len(code); bytes >= softResponseLimit {
   492  						break
   493  					}
   494  				}
   495  				reply := p.ReplyCode(req.ReqID, data)
   496  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   497  				if metrics.EnabledExpensive {
   498  					miscOutCodePacketsMeter.Mark(1)
   499  					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
   500  					miscServingTimeCodeTimer.Update(time.Duration(task.servingTime))
   501  				}
   502  			}()
   503  		}
   504  
   505  	case GetReceiptsMsg:
   506  		p.Log().Trace("Received receipts request")
   507  		if metrics.EnabledExpensive {
   508  			miscInReceiptPacketsMeter.Mark(1)
   509  			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
   510  		}
   511  		var req struct {
   512  			ReqID  uint64
   513  			Hashes []common.Hash
   514  		}
   515  		if err := msg.Decode(&req); err != nil {
   516  			clientErrorMeter.Mark(1)
   517  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   518  		}
   519  		var (
   520  			bytes    int
   521  			receipts []rlp.RawValue
   522  		)
   523  		reqCnt := len(req.Hashes)
   524  		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
   525  			wg.Add(1)
   526  			go func() {
   527  				defer wg.Done()
   528  				for i, hash := range req.Hashes {
   529  					if i != 0 && !task.waitOrStop() {
   530  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   531  						return
   532  					}
   533  					if bytes >= softResponseLimit {
   534  						break
   535  					}
   536  					// Retrieve the requested block's receipts, skipping if unknown to us
   537  					results := h.blockchain.GetReceiptsByHash(hash)
   538  					if results == nil {
   539  						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
   540  							atomic.AddUint32(&p.invalidCount, 1)
   541  							continue
   542  						}
   543  					}
   544  					// If known, encode and queue for response packet
   545  					if encoded, err := rlp.EncodeToBytes(results); err != nil {
   546  						log.Error("Failed to encode receipt", "err", err)
   547  					} else {
   548  						receipts = append(receipts, encoded)
   549  						bytes += len(encoded)
   550  					}
   551  				}
   552  				reply := p.ReplyReceiptsRLP(req.ReqID, receipts)
   553  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   554  				if metrics.EnabledExpensive {
   555  					miscOutReceiptPacketsMeter.Mark(1)
   556  					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
   557  					miscServingTimeReceiptTimer.Update(time.Duration(task.servingTime))
   558  				}
   559  			}()
   560  		}
   561  
   562  	case GetProofsV2Msg:
   563  		p.Log().Trace("Received les/2 proofs request")
   564  		if metrics.EnabledExpensive {
   565  			miscInTrieProofPacketsMeter.Mark(1)
   566  			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
   567  		}
   568  		var req struct {
   569  			ReqID uint64
   570  			Reqs  []ProofReq
   571  		}
   572  		if err := msg.Decode(&req); err != nil {
   573  			clientErrorMeter.Mark(1)
   574  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   575  		}
   576  		// Gather state data until the fetch or network limits is reached
   577  		var (
   578  			lastBHash common.Hash
   579  			root      common.Hash
   580  		)
   581  		reqCnt := len(req.Reqs)
   582  		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
   583  			wg.Add(1)
   584  			go func() {
   585  				defer wg.Done()
   586  				nodes := light.NewNodeSet()
   587  
   588  				for i, request := range req.Reqs {
   589  					if i != 0 && !task.waitOrStop() {
   590  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   591  						return
   592  					}
   593  					// Look up the root hash belonging to the request
   594  					var (
   595  						header *types.Header
   596  						trie   state.Trie
   597  					)
   598  					if request.BHash != lastBHash {
   599  						root, lastBHash = common.Hash{}, request.BHash
   600  
   601  						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
   602  							p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash)
   603  							atomic.AddUint32(&p.invalidCount, 1)
   604  							continue
   605  						}
   606  						// Refuse to search stale state data in the database since looking for
   607  						// a non-exist key is kind of expensive.
   608  						local := h.blockchain.CurrentHeader().Number.Uint64()
   609  						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   610  							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
   611  							atomic.AddUint32(&p.invalidCount, 1)
   612  							continue
   613  						}
   614  						root = header.Root
   615  					}
   616  					// If a header lookup failed (non existent), ignore subsequent requests for the same header
   617  					if root == (common.Hash{}) {
   618  						atomic.AddUint32(&p.invalidCount, 1)
   619  						continue
   620  					}
   621  					// Open the account or storage trie for the request
   622  					statedb := h.blockchain.StateCache()
   623  
   624  					switch len(request.AccKey) {
   625  					case 0:
   626  						// No account key specified, open an account trie
   627  						trie, err = statedb.OpenTrie(root)
   628  						if trie == nil || err != nil {
   629  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
   630  							continue
   631  						}
   632  					default:
   633  						// Account key specified, open a storage trie
   634  						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
   635  						if err != nil {
   636  							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   637  							atomic.AddUint32(&p.invalidCount, 1)
   638  							continue
   639  						}
   640  						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
   641  						if trie == nil || err != nil {
   642  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
   643  							continue
   644  						}
   645  					}
   646  					// Prove the user's request from the account or stroage trie
   647  					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
   648  						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
   649  						continue
   650  					}
   651  					if nodes.DataSize() >= softResponseLimit {
   652  						break
   653  					}
   654  				}
   655  				reply := p.ReplyProofsV2(req.ReqID, nodes.NodeList())
   656  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   657  				if metrics.EnabledExpensive {
   658  					miscOutTrieProofPacketsMeter.Mark(1)
   659  					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
   660  					miscServingTimeTrieProofTimer.Update(time.Duration(task.servingTime))
   661  				}
   662  			}()
   663  		}
   664  
   665  	case GetHelperTrieProofsMsg:
   666  		p.Log().Trace("Received helper trie proof request")
   667  		if metrics.EnabledExpensive {
   668  			miscInHelperTriePacketsMeter.Mark(1)
   669  			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
   670  		}
   671  		var req struct {
   672  			ReqID uint64
   673  			Reqs  []HelperTrieReq
   674  		}
   675  		if err := msg.Decode(&req); err != nil {
   676  			clientErrorMeter.Mark(1)
   677  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   678  		}
   679  		// Gather state data until the fetch or network limits is reached
   680  		var (
   681  			auxBytes int
   682  			auxData  [][]byte
   683  		)
   684  		reqCnt := len(req.Reqs)
   685  		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
   686  			wg.Add(1)
   687  			go func() {
   688  				defer wg.Done()
   689  				var (
   690  					lastIdx  uint64
   691  					lastType uint
   692  					root     common.Hash
   693  					auxTrie  *trie.Trie
   694  				)
   695  				nodes := light.NewNodeSet()
   696  				for i, request := range req.Reqs {
   697  					if i != 0 && !task.waitOrStop() {
   698  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   699  						return
   700  					}
   701  					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
   702  						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
   703  
   704  						var prefix string
   705  						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
   706  							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
   707  						}
   708  					}
   709  					if request.AuxReq == auxRoot {
   710  						var data []byte
   711  						if root != (common.Hash{}) {
   712  							data = root[:]
   713  						}
   714  						auxData = append(auxData, data)
   715  						auxBytes += len(data)
   716  					} else {
   717  						if auxTrie != nil {
   718  							auxTrie.Prove(request.Key, request.FromLevel, nodes)
   719  						}
   720  						if request.AuxReq != 0 {
   721  							data := h.getAuxiliaryHeaders(request)
   722  							auxData = append(auxData, data)
   723  							auxBytes += len(data)
   724  						}
   725  					}
   726  					if nodes.DataSize()+auxBytes >= softResponseLimit {
   727  						break
   728  					}
   729  				}
   730  				reply := p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
   731  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   732  				if metrics.EnabledExpensive {
   733  					miscOutHelperTriePacketsMeter.Mark(1)
   734  					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
   735  					miscServingTimeHelperTrieTimer.Update(time.Duration(task.servingTime))
   736  				}
   737  			}()
   738  		}
   739  
   740  	case SendTxV2Msg:
   741  		p.Log().Trace("Received new transactions")
   742  		if metrics.EnabledExpensive {
   743  			miscInTxsPacketsMeter.Mark(1)
   744  			miscInTxsTrafficMeter.Mark(int64(msg.Size))
   745  		}
   746  		var req struct {
   747  			ReqID uint64
   748  			Txs   []*types.Transaction
   749  		}
   750  		if err := msg.Decode(&req); err != nil {
   751  			clientErrorMeter.Mark(1)
   752  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   753  		}
   754  		reqCnt := len(req.Txs)
   755  		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
   756  			wg.Add(1)
   757  			go func() {
   758  				defer wg.Done()
   759  				stats := make([]light.TxStatus, len(req.Txs))
   760  				for i, tx := range req.Txs {
   761  					if i != 0 && !task.waitOrStop() {
   762  						return
   763  					}
   764  					hash := tx.Hash()
   765  					stats[i] = h.txStatus(hash)
   766  					if stats[i].Status == core.TxStatusUnknown {
   767  						addFn := h.txpool.AddRemotes
   768  						// Add txs synchronously for testing purpose
   769  						if h.addTxsSync {
   770  							addFn = h.txpool.AddRemotesSync
   771  						}
   772  						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
   773  							stats[i].Error = errs[0].Error()
   774  							continue
   775  						}
   776  						stats[i] = h.txStatus(hash)
   777  					}
   778  				}
   779  				reply := p.ReplyTxStatus(req.ReqID, stats)
   780  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   781  				if metrics.EnabledExpensive {
   782  					miscOutTxsPacketsMeter.Mark(1)
   783  					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
   784  					miscServingTimeTxTimer.Update(time.Duration(task.servingTime))
   785  				}
   786  			}()
   787  		}
   788  
   789  	case GetTxStatusMsg:
   790  		p.Log().Trace("Received transaction status query request")
   791  		if metrics.EnabledExpensive {
   792  			miscInTxStatusPacketsMeter.Mark(1)
   793  			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
   794  		}
   795  		var req struct {
   796  			ReqID  uint64
   797  			Hashes []common.Hash
   798  		}
   799  		if err := msg.Decode(&req); err != nil {
   800  			clientErrorMeter.Mark(1)
   801  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   802  		}
   803  		reqCnt := len(req.Hashes)
   804  		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
   805  			wg.Add(1)
   806  			go func() {
   807  				defer wg.Done()
   808  				stats := make([]light.TxStatus, len(req.Hashes))
   809  				for i, hash := range req.Hashes {
   810  					if i != 0 && !task.waitOrStop() {
   811  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   812  						return
   813  					}
   814  					stats[i] = h.txStatus(hash)
   815  				}
   816  				reply := p.ReplyTxStatus(req.ReqID, stats)
   817  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   818  				if metrics.EnabledExpensive {
   819  					miscOutTxStatusPacketsMeter.Mark(1)
   820  					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
   821  					miscServingTimeTxStatusTimer.Update(time.Duration(task.servingTime))
   822  				}
   823  			}()
   824  		}
   825  
   826  	default:
   827  		p.Log().Trace("Received invalid message", "code", msg.Code)
   828  		clientErrorMeter.Mark(1)
   829  		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
   830  	}
   831  	// If the client has made too much invalid request(e.g. request a non-exist data),
   832  	// reject them to prevent SPAM attack.
   833  	if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
   834  		clientErrorMeter.Mark(1)
   835  		return errTooManyInvalidRequest
   836  	}
   837  	return nil
   838  }
   839  
   840  // getAccount retrieves an account from the state based on root.
   841  func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
   842  	trie, err := trie.New(root, triedb)
   843  	if err != nil {
   844  		return state.Account{}, err
   845  	}
   846  	blob, err := trie.TryGet(hash[:])
   847  	if err != nil {
   848  		return state.Account{}, err
   849  	}
   850  	var account state.Account
   851  	if err = rlp.DecodeBytes(blob, &account); err != nil {
   852  		return state.Account{}, err
   853  	}
   854  	return account, nil
   855  }
   856  
   857  // getHelperTrie returns the post-processed trie root for the given trie ID and section index
   858  func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
   859  	switch typ {
   860  	case htCanonical:
   861  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
   862  		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
   863  	case htBloomBits:
   864  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
   865  		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
   866  	}
   867  	return common.Hash{}, ""
   868  }
   869  
   870  // getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
   871  func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
   872  	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
   873  		blockNum := binary.BigEndian.Uint64(req.Key)
   874  		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
   875  		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
   876  	}
   877  	return nil
   878  }
   879  
   880  // txStatus returns the status of a specified transaction.
   881  func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
   882  	var stat light.TxStatus
   883  	// Looking the transaction in txpool first.
   884  	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
   885  
   886  	// If the transaction is unknown to the pool, try looking it up locally.
   887  	if stat.Status == core.TxStatusUnknown {
   888  		lookup := h.blockchain.GetTransactionLookup(hash)
   889  		if lookup != nil {
   890  			stat.Status = core.TxStatusIncluded
   891  			stat.Lookup = lookup
   892  		}
   893  	}
   894  	return stat
   895  }
   896  
   897  // broadcastHeaders broadcasts new block information to all connected light
   898  // clients. According to the agreement between client and server, server should
   899  // only broadcast new announcement if the total difficulty is higher than the
   900  // last one. Besides server will add the signature if client requires.
   901  func (h *serverHandler) broadcastHeaders() {
   902  	defer h.wg.Done()
   903  
   904  	headCh := make(chan core.ChainHeadEvent, 10)
   905  	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
   906  	defer headSub.Unsubscribe()
   907  
   908  	var (
   909  		lastHead *types.Header
   910  		lastTd   = common.Big0
   911  	)
   912  	for {
   913  		select {
   914  		case ev := <-headCh:
   915  			peers := h.server.peers.AllPeers()
   916  			if len(peers) == 0 {
   917  				continue
   918  			}
   919  			header := ev.Block.Header()
   920  			hash, number := header.Hash(), header.Number.Uint64()
   921  			td := h.blockchain.GetTd(hash, number)
   922  			if td == nil || td.Cmp(lastTd) <= 0 {
   923  				continue
   924  			}
   925  			var reorg uint64
   926  			if lastHead != nil {
   927  				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
   928  			}
   929  			lastHead, lastTd = header, td
   930  
   931  			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
   932  			var (
   933  				signed         bool
   934  				signedAnnounce announceData
   935  			)
   936  			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
   937  			for _, p := range peers {
   938  				p := p
   939  				switch p.announceType {
   940  				case announceTypeSimple:
   941  					p.queueSend(func() { p.SendAnnounce(announce) })
   942  				case announceTypeSigned:
   943  					if !signed {
   944  						signedAnnounce = announce
   945  						signedAnnounce.sign(h.server.privateKey)
   946  						signed = true
   947  					}
   948  					p.queueSend(func() { p.SendAnnounce(signedAnnounce) })
   949  				}
   950  			}
   951  		case <-h.closeCh:
   952  			return
   953  		}
   954  	}
   955  }