github.com/coltonfike/e2c@v21.1.0+incompatible/les/server_handler.go (about)

     1  // Copyright 2019 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package les
    18  
    19  import (
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"sync"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"github.com/ethereum/go-ethereum/common"
    28  	"github.com/ethereum/go-ethereum/common/mclock"
    29  	"github.com/ethereum/go-ethereum/core"
    30  	"github.com/ethereum/go-ethereum/core/rawdb"
    31  	"github.com/ethereum/go-ethereum/core/state"
    32  	"github.com/ethereum/go-ethereum/core/types"
    33  	"github.com/ethereum/go-ethereum/ethdb"
    34  	"github.com/ethereum/go-ethereum/light"
    35  	"github.com/ethereum/go-ethereum/log"
    36  	"github.com/ethereum/go-ethereum/metrics"
    37  	"github.com/ethereum/go-ethereum/p2p"
    38  	"github.com/ethereum/go-ethereum/rlp"
    39  	"github.com/ethereum/go-ethereum/trie"
    40  )
    41  
    42  const (
    43  	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
    44  	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
    45  	ethVersion        = 63              // equivalent eth version for the downloader
    46  
    47  	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
    48  	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
    49  	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
    50  	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
    51  	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
    52  	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
    53  	MaxTxSend                = 64  // Amount of transactions to be send per request
    54  	MaxTxStatus              = 256 // Amount of transactions to queried per request
    55  )
    56  
    57  var (
    58  	errTooManyInvalidRequest = errors.New("too many invalid requests made")
    59  	errFullClientPool        = errors.New("client pool is full")
    60  )
    61  
    62  // serverHandler is responsible for serving light client and process
    63  // all incoming light requests.
    64  type serverHandler struct {
    65  	blockchain *core.BlockChain
    66  	chainDb    ethdb.Database
    67  	txpool     *core.TxPool
    68  	server     *LesServer
    69  
    70  	closeCh chan struct{}  // Channel used to exit all background routines of handler.
    71  	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
    72  	synced  func() bool    // Callback function used to determine whether local node is synced.
    73  
    74  	// Testing fields
    75  	addTxsSync bool
    76  }
    77  
    78  func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
    79  	handler := &serverHandler{
    80  		server:     server,
    81  		blockchain: blockchain,
    82  		chainDb:    chainDb,
    83  		txpool:     txpool,
    84  		closeCh:    make(chan struct{}),
    85  		synced:     synced,
    86  	}
    87  	return handler
    88  }
    89  
    90  // start starts the server handler.
    91  func (h *serverHandler) start() {
    92  	h.wg.Add(1)
    93  	go h.broadcastHeaders()
    94  }
    95  
    96  // stop stops the server handler.
    97  func (h *serverHandler) stop() {
    98  	close(h.closeCh)
    99  	h.wg.Wait()
   100  }
   101  
   102  // runPeer is the p2p protocol run function for the given version.
   103  func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
   104  	peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version)))
   105  	h.wg.Add(1)
   106  	defer h.wg.Done()
   107  	return h.handle(peer)
   108  }
   109  
   110  func (h *serverHandler) handle(p *peer) error {
   111  	// Reject light clients if server is not synced.
   112  	if !h.synced() {
   113  		return p2p.DiscRequested
   114  	}
   115  	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
   116  
   117  	// Execute the LES handshake
   118  	var (
   119  		head   = h.blockchain.CurrentHeader()
   120  		hash   = head.Hash()
   121  		number = head.Number.Uint64()
   122  		td     = h.blockchain.GetTd(hash, number)
   123  	)
   124  	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil {
   125  		p.Log().Debug("Light Ethereum handshake failed", "err", err)
   126  		return err
   127  	}
   128  	defer p.fcClient.Disconnect()
   129  
   130  	// Disconnect the inbound peer if it's rejected by clientPool
   131  	if !h.server.clientPool.connect(p, 0) {
   132  		p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
   133  		return errFullClientPool
   134  	}
   135  	// Register the peer locally
   136  	if err := h.server.peers.Register(p); err != nil {
   137  		h.server.clientPool.disconnect(p)
   138  		p.Log().Error("Light Ethereum peer registration failed", "err", err)
   139  		return err
   140  	}
   141  	clientConnectionGauge.Update(int64(h.server.peers.Len()))
   142  
   143  	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
   144  
   145  	connectedAt := mclock.Now()
   146  	defer func() {
   147  		wg.Wait() // Ensure all background task routines have exited.
   148  		h.server.peers.Unregister(p.id)
   149  		h.server.clientPool.disconnect(p)
   150  		clientConnectionGauge.Update(int64(h.server.peers.Len()))
   151  		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
   152  	}()
   153  
   154  	// Spawn a main loop to handle all incoming messages.
   155  	for {
   156  		select {
   157  		case err := <-p.errCh:
   158  			p.Log().Debug("Failed to send light ethereum response", "err", err)
   159  			return err
   160  		default:
   161  		}
   162  		if err := h.handleMsg(p, &wg); err != nil {
   163  			p.Log().Debug("Light Ethereum message handling failed", "err", err)
   164  			return err
   165  		}
   166  	}
   167  }
   168  
   169  // handleMsg is invoked whenever an inbound message is received from a remote
   170  // peer. The remote connection is torn down upon returning any error.
   171  func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error {
   172  	// Read the next message from the remote peer, and ensure it's fully consumed
   173  	msg, err := p.rw.ReadMsg()
   174  	if err != nil {
   175  		return err
   176  	}
   177  	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
   178  
   179  	// Discard large message which exceeds the limitation.
   180  	if msg.Size > ProtocolMaxMsgSize {
   181  		clientErrorMeter.Mark(1)
   182  		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
   183  	}
   184  	defer msg.Discard()
   185  
   186  	var (
   187  		maxCost uint64
   188  		task    *servingTask
   189  	)
   190  	p.responseCount++
   191  	responseCount := p.responseCount
   192  	// accept returns an indicator whether the request can be served.
   193  	// If so, deduct the max cost from the flow control buffer.
   194  	accept := func(reqID, reqCnt, maxCnt uint64) bool {
   195  		// Short circuit if the peer is already frozen or the request is invalid.
   196  		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
   197  		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
   198  			p.fcClient.OneTimeCost(inSizeCost)
   199  			return false
   200  		}
   201  		// Prepaid max cost units before request been serving.
   202  		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
   203  		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
   204  		if !accepted {
   205  			p.freezeClient()
   206  			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
   207  			p.fcClient.OneTimeCost(inSizeCost)
   208  			return false
   209  		}
   210  		// Create a multi-stage task, estimate the time it takes for the task to
   211  		// execute, and cache it in the request service queue.
   212  		factor := h.server.costTracker.globalFactor()
   213  		if factor < 0.001 {
   214  			factor = 1
   215  			p.Log().Error("Invalid global cost factor", "factor", factor)
   216  		}
   217  		maxTime := uint64(float64(maxCost) / factor)
   218  		task = h.server.servingQueue.newTask(p, maxTime, priority)
   219  		if task.start() {
   220  			return true
   221  		}
   222  		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
   223  		return false
   224  	}
   225  	// sendResponse sends back the response and updates the flow control statistic.
   226  	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
   227  		p.responseLock.Lock()
   228  		defer p.responseLock.Unlock()
   229  
   230  		// Short circuit if the client is already frozen.
   231  		if p.isFrozen() {
   232  			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
   233  			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   234  			return
   235  		}
   236  		// Positive correction buffer value with real cost.
   237  		var replySize uint32
   238  		if reply != nil {
   239  			replySize = reply.size()
   240  		}
   241  		var realCost uint64
   242  		if h.server.costTracker.testing {
   243  			realCost = maxCost // Assign a fake cost for testing purpose
   244  		} else {
   245  			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
   246  		}
   247  		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   248  		if amount != 0 {
   249  			// Feed cost tracker request serving statistic.
   250  			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
   251  			// Reduce priority "balance" for the specific peer.
   252  			h.server.clientPool.requestCost(p, realCost)
   253  		}
   254  		if reply != nil {
   255  			p.queueSend(func() {
   256  				if err := reply.send(bv); err != nil {
   257  					select {
   258  					case p.errCh <- err:
   259  					default:
   260  					}
   261  				}
   262  			})
   263  		}
   264  	}
   265  	switch msg.Code {
   266  	case GetBlockHeadersMsg:
   267  		p.Log().Trace("Received block header request")
   268  		if metrics.EnabledExpensive {
   269  			miscInHeaderPacketsMeter.Mark(1)
   270  			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
   271  			defer func(start time.Time) { miscServingTimeHeaderTimer.UpdateSince(start) }(time.Now())
   272  		}
   273  		var req struct {
   274  			ReqID uint64
   275  			Query getBlockHeadersData
   276  		}
   277  		if err := msg.Decode(&req); err != nil {
   278  			clientErrorMeter.Mark(1)
   279  			return errResp(ErrDecode, "%v: %v", msg, err)
   280  		}
   281  		query := req.Query
   282  		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
   283  			wg.Add(1)
   284  			go func() {
   285  				defer wg.Done()
   286  				hashMode := query.Origin.Hash != (common.Hash{})
   287  				first := true
   288  				maxNonCanonical := uint64(100)
   289  
   290  				// Gather headers until the fetch or network limits is reached
   291  				var (
   292  					bytes   common.StorageSize
   293  					headers []*types.Header
   294  					unknown bool
   295  				)
   296  				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
   297  					if !first && !task.waitOrStop() {
   298  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   299  						return
   300  					}
   301  					// Retrieve the next header satisfying the query
   302  					var origin *types.Header
   303  					if hashMode {
   304  						if first {
   305  							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
   306  							if origin != nil {
   307  								query.Origin.Number = origin.Number.Uint64()
   308  							}
   309  						} else {
   310  							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
   311  						}
   312  					} else {
   313  						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
   314  					}
   315  					if origin == nil {
   316  						atomic.AddUint32(&p.invalidCount, 1)
   317  						break
   318  					}
   319  					headers = append(headers, origin)
   320  					bytes += estHeaderRlpSize
   321  
   322  					// Advance to the next header of the query
   323  					switch {
   324  					case hashMode && query.Reverse:
   325  						// Hash based traversal towards the genesis block
   326  						ancestor := query.Skip + 1
   327  						if ancestor == 0 {
   328  							unknown = true
   329  						} else {
   330  							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
   331  							unknown = query.Origin.Hash == common.Hash{}
   332  						}
   333  					case hashMode && !query.Reverse:
   334  						// Hash based traversal towards the leaf block
   335  						var (
   336  							current = origin.Number.Uint64()
   337  							next    = current + query.Skip + 1
   338  						)
   339  						if next <= current {
   340  							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
   341  							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
   342  							unknown = true
   343  						} else {
   344  							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
   345  								nextHash := header.Hash()
   346  								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
   347  								if expOldHash == query.Origin.Hash {
   348  									query.Origin.Hash, query.Origin.Number = nextHash, next
   349  								} else {
   350  									unknown = true
   351  								}
   352  							} else {
   353  								unknown = true
   354  							}
   355  						}
   356  					case query.Reverse:
   357  						// Number based traversal towards the genesis block
   358  						if query.Origin.Number >= query.Skip+1 {
   359  							query.Origin.Number -= query.Skip + 1
   360  						} else {
   361  							unknown = true
   362  						}
   363  
   364  					case !query.Reverse:
   365  						// Number based traversal towards the leaf block
   366  						query.Origin.Number += query.Skip + 1
   367  					}
   368  					first = false
   369  				}
   370  				reply := p.ReplyBlockHeaders(req.ReqID, headers)
   371  				sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
   372  				if metrics.EnabledExpensive {
   373  					miscOutHeaderPacketsMeter.Mark(1)
   374  					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
   375  				}
   376  			}()
   377  		}
   378  
   379  	case GetBlockBodiesMsg:
   380  		p.Log().Trace("Received block bodies request")
   381  		if metrics.EnabledExpensive {
   382  			miscInBodyPacketsMeter.Mark(1)
   383  			miscInBodyTrafficMeter.Mark(int64(msg.Size))
   384  			defer func(start time.Time) { miscServingTimeBodyTimer.UpdateSince(start) }(time.Now())
   385  		}
   386  		var req struct {
   387  			ReqID  uint64
   388  			Hashes []common.Hash
   389  		}
   390  		if err := msg.Decode(&req); err != nil {
   391  			clientErrorMeter.Mark(1)
   392  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   393  		}
   394  		var (
   395  			bytes  int
   396  			bodies []rlp.RawValue
   397  		)
   398  		reqCnt := len(req.Hashes)
   399  		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
   400  			wg.Add(1)
   401  			go func() {
   402  				defer wg.Done()
   403  				for i, hash := range req.Hashes {
   404  					if i != 0 && !task.waitOrStop() {
   405  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   406  						return
   407  					}
   408  					if bytes >= softResponseLimit {
   409  						break
   410  					}
   411  					body := h.blockchain.GetBodyRLP(hash)
   412  					if body == nil {
   413  						atomic.AddUint32(&p.invalidCount, 1)
   414  						continue
   415  					}
   416  					bodies = append(bodies, body)
   417  					bytes += len(body)
   418  				}
   419  				reply := p.ReplyBlockBodiesRLP(req.ReqID, bodies)
   420  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   421  				if metrics.EnabledExpensive {
   422  					miscOutBodyPacketsMeter.Mark(1)
   423  					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
   424  				}
   425  			}()
   426  		}
   427  
   428  	case GetCodeMsg:
   429  		p.Log().Trace("Received code request")
   430  		if metrics.EnabledExpensive {
   431  			miscInCodePacketsMeter.Mark(1)
   432  			miscInCodeTrafficMeter.Mark(int64(msg.Size))
   433  			defer func(start time.Time) { miscServingTimeCodeTimer.UpdateSince(start) }(time.Now())
   434  		}
   435  		var req struct {
   436  			ReqID uint64
   437  			Reqs  []CodeReq
   438  		}
   439  		if err := msg.Decode(&req); err != nil {
   440  			clientErrorMeter.Mark(1)
   441  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   442  		}
   443  		var (
   444  			bytes int
   445  			data  [][]byte
   446  		)
   447  		reqCnt := len(req.Reqs)
   448  		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
   449  			wg.Add(1)
   450  			go func() {
   451  				defer wg.Done()
   452  				for i, request := range req.Reqs {
   453  					if i != 0 && !task.waitOrStop() {
   454  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   455  						return
   456  					}
   457  					// Look up the root hash belonging to the request
   458  					header := h.blockchain.GetHeaderByHash(request.BHash)
   459  					if header == nil {
   460  						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
   461  						atomic.AddUint32(&p.invalidCount, 1)
   462  						continue
   463  					}
   464  					// Refuse to search stale state data in the database since looking for
   465  					// a non-exist key is kind of expensive.
   466  					local := h.blockchain.CurrentHeader().Number.Uint64()
   467  					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   468  						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
   469  						atomic.AddUint32(&p.invalidCount, 1)
   470  						continue
   471  					}
   472  					statedb, _ := h.blockchain.StateCache()
   473  					triedb := statedb.TrieDB()
   474  
   475  					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
   476  					if err != nil {
   477  						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   478  						atomic.AddUint32(&p.invalidCount, 1)
   479  						continue
   480  					}
   481  					code, err := triedb.Node(common.BytesToHash(account.CodeHash))
   482  					if err != nil {
   483  						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
   484  						continue
   485  					}
   486  					// Accumulate the code and abort if enough data was retrieved
   487  					data = append(data, code)
   488  					if bytes += len(code); bytes >= softResponseLimit {
   489  						break
   490  					}
   491  				}
   492  				reply := p.ReplyCode(req.ReqID, data)
   493  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   494  				if metrics.EnabledExpensive {
   495  					miscOutCodePacketsMeter.Mark(1)
   496  					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
   497  				}
   498  			}()
   499  		}
   500  
   501  	case GetReceiptsMsg:
   502  		p.Log().Trace("Received receipts request")
   503  		if metrics.EnabledExpensive {
   504  			miscInReceiptPacketsMeter.Mark(1)
   505  			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
   506  			defer func(start time.Time) { miscServingTimeReceiptTimer.UpdateSince(start) }(time.Now())
   507  		}
   508  		var req struct {
   509  			ReqID  uint64
   510  			Hashes []common.Hash
   511  		}
   512  		if err := msg.Decode(&req); err != nil {
   513  			clientErrorMeter.Mark(1)
   514  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   515  		}
   516  		var (
   517  			bytes    int
   518  			receipts []rlp.RawValue
   519  		)
   520  		reqCnt := len(req.Hashes)
   521  		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
   522  			wg.Add(1)
   523  			go func() {
   524  				defer wg.Done()
   525  				for i, hash := range req.Hashes {
   526  					if i != 0 && !task.waitOrStop() {
   527  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   528  						return
   529  					}
   530  					if bytes >= softResponseLimit {
   531  						break
   532  					}
   533  					// Retrieve the requested block's receipts, skipping if unknown to us
   534  					results := h.blockchain.GetReceiptsByHash(hash)
   535  					if results == nil {
   536  						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
   537  							atomic.AddUint32(&p.invalidCount, 1)
   538  							continue
   539  						}
   540  					}
   541  					// If known, encode and queue for response packet
   542  					if encoded, err := rlp.EncodeToBytes(results); err != nil {
   543  						log.Error("Failed to encode receipt", "err", err)
   544  					} else {
   545  						receipts = append(receipts, encoded)
   546  						bytes += len(encoded)
   547  					}
   548  				}
   549  				reply := p.ReplyReceiptsRLP(req.ReqID, receipts)
   550  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   551  				if metrics.EnabledExpensive {
   552  					miscOutReceiptPacketsMeter.Mark(1)
   553  					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
   554  				}
   555  			}()
   556  		}
   557  
   558  	case GetProofsV2Msg:
   559  		p.Log().Trace("Received les/2 proofs request")
   560  		if metrics.EnabledExpensive {
   561  			miscInTrieProofPacketsMeter.Mark(1)
   562  			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
   563  			defer func(start time.Time) { miscServingTimeTrieProofTimer.UpdateSince(start) }(time.Now())
   564  		}
   565  		var req struct {
   566  			ReqID uint64
   567  			Reqs  []ProofReq
   568  		}
   569  		if err := msg.Decode(&req); err != nil {
   570  			clientErrorMeter.Mark(1)
   571  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   572  		}
   573  		// Gather state data until the fetch or network limits is reached
   574  		var (
   575  			lastBHash common.Hash
   576  			root      common.Hash
   577  		)
   578  		reqCnt := len(req.Reqs)
   579  		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
   580  			wg.Add(1)
   581  			go func() {
   582  				defer wg.Done()
   583  				nodes := light.NewNodeSet()
   584  
   585  				for i, request := range req.Reqs {
   586  					if i != 0 && !task.waitOrStop() {
   587  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   588  						return
   589  					}
   590  					// Look up the root hash belonging to the request
   591  					var (
   592  						header *types.Header
   593  						trie   state.Trie
   594  					)
   595  					if request.BHash != lastBHash {
   596  						root, lastBHash = common.Hash{}, request.BHash
   597  
   598  						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
   599  							p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash)
   600  							atomic.AddUint32(&p.invalidCount, 1)
   601  							continue
   602  						}
   603  						// Refuse to search stale state data in the database since looking for
   604  						// a non-exist key is kind of expensive.
   605  						local := h.blockchain.CurrentHeader().Number.Uint64()
   606  						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   607  							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
   608  							atomic.AddUint32(&p.invalidCount, 1)
   609  							continue
   610  						}
   611  						root = header.Root
   612  					}
   613  					// If a header lookup failed (non existent), ignore subsequent requests for the same header
   614  					if root == (common.Hash{}) {
   615  						atomic.AddUint32(&p.invalidCount, 1)
   616  						continue
   617  					}
   618  					// Open the account or storage trie for the request
   619  					statedb, _ := h.blockchain.StateCache()
   620  
   621  					switch len(request.AccKey) {
   622  					case 0:
   623  						// No account key specified, open an account trie
   624  						trie, err = statedb.OpenTrie(root)
   625  						if trie == nil || err != nil {
   626  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
   627  							continue
   628  						}
   629  					default:
   630  						// Account key specified, open a storage trie
   631  						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
   632  						if err != nil {
   633  							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   634  							atomic.AddUint32(&p.invalidCount, 1)
   635  							continue
   636  						}
   637  						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
   638  						if trie == nil || err != nil {
   639  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
   640  							continue
   641  						}
   642  					}
   643  					// Prove the user's request from the account or stroage trie
   644  					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
   645  						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
   646  						continue
   647  					}
   648  					if nodes.DataSize() >= softResponseLimit {
   649  						break
   650  					}
   651  				}
   652  				reply := p.ReplyProofsV2(req.ReqID, nodes.NodeList())
   653  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   654  				if metrics.EnabledExpensive {
   655  					miscOutTrieProofPacketsMeter.Mark(1)
   656  					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
   657  				}
   658  			}()
   659  		}
   660  
   661  	case GetHelperTrieProofsMsg:
   662  		p.Log().Trace("Received helper trie proof request")
   663  		if metrics.EnabledExpensive {
   664  			miscInHelperTriePacketsMeter.Mark(1)
   665  			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
   666  			defer func(start time.Time) { miscServingTimeHelperTrieTimer.UpdateSince(start) }(time.Now())
   667  		}
   668  		var req struct {
   669  			ReqID uint64
   670  			Reqs  []HelperTrieReq
   671  		}
   672  		if err := msg.Decode(&req); err != nil {
   673  			clientErrorMeter.Mark(1)
   674  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   675  		}
   676  		// Gather state data until the fetch or network limits is reached
   677  		var (
   678  			auxBytes int
   679  			auxData  [][]byte
   680  		)
   681  		reqCnt := len(req.Reqs)
   682  		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
   683  			wg.Add(1)
   684  			go func() {
   685  				defer wg.Done()
   686  				var (
   687  					lastIdx  uint64
   688  					lastType uint
   689  					root     common.Hash
   690  					auxTrie  *trie.Trie
   691  				)
   692  				nodes := light.NewNodeSet()
   693  				for i, request := range req.Reqs {
   694  					if i != 0 && !task.waitOrStop() {
   695  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   696  						return
   697  					}
   698  					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
   699  						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
   700  
   701  						var prefix string
   702  						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
   703  							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
   704  						}
   705  					}
   706  					if request.AuxReq == auxRoot {
   707  						var data []byte
   708  						if root != (common.Hash{}) {
   709  							data = root[:]
   710  						}
   711  						auxData = append(auxData, data)
   712  						auxBytes += len(data)
   713  					} else {
   714  						if auxTrie != nil {
   715  							auxTrie.Prove(request.Key, request.FromLevel, nodes)
   716  						}
   717  						if request.AuxReq != 0 {
   718  							data := h.getAuxiliaryHeaders(request)
   719  							auxData = append(auxData, data)
   720  							auxBytes += len(data)
   721  						}
   722  					}
   723  					if nodes.DataSize()+auxBytes >= softResponseLimit {
   724  						break
   725  					}
   726  				}
   727  				reply := p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
   728  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   729  				if metrics.EnabledExpensive {
   730  					miscOutHelperTriePacketsMeter.Mark(1)
   731  					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
   732  				}
   733  			}()
   734  		}
   735  
   736  	case SendTxV2Msg:
   737  		p.Log().Trace("Received new transactions")
   738  		if metrics.EnabledExpensive {
   739  			miscInTxsPacketsMeter.Mark(1)
   740  			miscInTxsTrafficMeter.Mark(int64(msg.Size))
   741  			defer func(start time.Time) { miscServingTimeTxTimer.UpdateSince(start) }(time.Now())
   742  		}
   743  		var req struct {
   744  			ReqID uint64
   745  			Txs   []*types.Transaction
   746  		}
   747  		if err := msg.Decode(&req); err != nil {
   748  			clientErrorMeter.Mark(1)
   749  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   750  		}
   751  		reqCnt := len(req.Txs)
   752  		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
   753  			wg.Add(1)
   754  			go func() {
   755  				defer wg.Done()
   756  				stats := make([]light.TxStatus, len(req.Txs))
   757  				for i, tx := range req.Txs {
   758  					if i != 0 && !task.waitOrStop() {
   759  						return
   760  					}
   761  					hash := tx.Hash()
   762  					stats[i] = h.txStatus(hash)
   763  					if stats[i].Status == core.TxStatusUnknown {
   764  						addFn := h.txpool.AddRemotes
   765  						// Add txs synchronously for testing purpose
   766  						if h.addTxsSync {
   767  							addFn = h.txpool.AddRemotesSync
   768  						}
   769  						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
   770  							stats[i].Error = errs[0].Error()
   771  							continue
   772  						}
   773  						stats[i] = h.txStatus(hash)
   774  					}
   775  				}
   776  				reply := p.ReplyTxStatus(req.ReqID, stats)
   777  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   778  				if metrics.EnabledExpensive {
   779  					miscOutTxsPacketsMeter.Mark(1)
   780  					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
   781  				}
   782  			}()
   783  		}
   784  
   785  	case GetTxStatusMsg:
   786  		p.Log().Trace("Received transaction status query request")
   787  		if metrics.EnabledExpensive {
   788  			miscInTxStatusPacketsMeter.Mark(1)
   789  			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
   790  			defer func(start time.Time) { miscServingTimeTxStatusTimer.UpdateSince(start) }(time.Now())
   791  		}
   792  		var req struct {
   793  			ReqID  uint64
   794  			Hashes []common.Hash
   795  		}
   796  		if err := msg.Decode(&req); err != nil {
   797  			clientErrorMeter.Mark(1)
   798  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   799  		}
   800  		reqCnt := len(req.Hashes)
   801  		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
   802  			wg.Add(1)
   803  			go func() {
   804  				defer wg.Done()
   805  				stats := make([]light.TxStatus, len(req.Hashes))
   806  				for i, hash := range req.Hashes {
   807  					if i != 0 && !task.waitOrStop() {
   808  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   809  						return
   810  					}
   811  					stats[i] = h.txStatus(hash)
   812  				}
   813  				reply := p.ReplyTxStatus(req.ReqID, stats)
   814  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   815  				if metrics.EnabledExpensive {
   816  					miscOutTxStatusPacketsMeter.Mark(1)
   817  					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
   818  				}
   819  			}()
   820  		}
   821  
   822  	default:
   823  		p.Log().Trace("Received invalid message", "code", msg.Code)
   824  		clientErrorMeter.Mark(1)
   825  		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
   826  	}
   827  	// If the client has made too much invalid request(e.g. request a non-exist data),
   828  	// reject them to prevent SPAM attack.
   829  	if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
   830  		clientErrorMeter.Mark(1)
   831  		return errTooManyInvalidRequest
   832  	}
   833  	return nil
   834  }
   835  
   836  // getAccount retrieves an account from the state based on root.
   837  func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
   838  	trie, err := trie.New(root, triedb)
   839  	if err != nil {
   840  		return state.Account{}, err
   841  	}
   842  	blob, err := trie.TryGet(hash[:])
   843  	if err != nil {
   844  		return state.Account{}, err
   845  	}
   846  	var account state.Account
   847  	if err = rlp.DecodeBytes(blob, &account); err != nil {
   848  		return state.Account{}, err
   849  	}
   850  	return account, nil
   851  }
   852  
   853  // getHelperTrie returns the post-processed trie root for the given trie ID and section index
   854  func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
   855  	switch typ {
   856  	case htCanonical:
   857  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
   858  		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
   859  	case htBloomBits:
   860  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
   861  		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
   862  	}
   863  	return common.Hash{}, ""
   864  }
   865  
   866  // getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
   867  func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
   868  	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
   869  		blockNum := binary.BigEndian.Uint64(req.Key)
   870  		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
   871  		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
   872  	}
   873  	return nil
   874  }
   875  
   876  // txStatus returns the status of a specified transaction.
   877  func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
   878  	var stat light.TxStatus
   879  	// Looking the transaction in txpool first.
   880  	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
   881  
   882  	// If the transaction is unknown to the pool, try looking it up locally.
   883  	if stat.Status == core.TxStatusUnknown {
   884  		lookup := h.blockchain.GetTransactionLookup(hash)
   885  		if lookup != nil {
   886  			stat.Status = core.TxStatusIncluded
   887  			stat.Lookup = lookup
   888  		}
   889  	}
   890  	return stat
   891  }
   892  
   893  // broadcastHeaders broadcasts new block information to all connected light
   894  // clients. According to the agreement between client and server, server should
   895  // only broadcast new announcement if the total difficulty is higher than the
   896  // last one. Besides server will add the signature if client requires.
   897  func (h *serverHandler) broadcastHeaders() {
   898  	defer h.wg.Done()
   899  
   900  	headCh := make(chan core.ChainHeadEvent, 10)
   901  	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
   902  	defer headSub.Unsubscribe()
   903  
   904  	var (
   905  		lastHead *types.Header
   906  		lastTd   = common.Big0
   907  	)
   908  	for {
   909  		select {
   910  		case ev := <-headCh:
   911  			peers := h.server.peers.AllPeers()
   912  			if len(peers) == 0 {
   913  				continue
   914  			}
   915  			header := ev.Block.Header()
   916  			hash, number := header.Hash(), header.Number.Uint64()
   917  			td := h.blockchain.GetTd(hash, number)
   918  			if td == nil || td.Cmp(lastTd) <= 0 {
   919  				continue
   920  			}
   921  			var reorg uint64
   922  			if lastHead != nil {
   923  				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
   924  			}
   925  			lastHead, lastTd = header, td
   926  
   927  			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
   928  			var (
   929  				signed         bool
   930  				signedAnnounce announceData
   931  			)
   932  			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
   933  			for _, p := range peers {
   934  				p := p
   935  				switch p.announceType {
   936  				case announceTypeSimple:
   937  					p.queueSend(func() { p.SendAnnounce(announce) })
   938  				case announceTypeSigned:
   939  					if !signed {
   940  						signedAnnounce = announce
   941  						signedAnnounce.sign(h.server.privateKey)
   942  						signed = true
   943  					}
   944  					p.queueSend(func() { p.SendAnnounce(signedAnnounce) })
   945  				}
   946  			}
   947  		case <-h.closeCh:
   948  			return
   949  		}
   950  	}
   951  }