github.com/core-coin/go-core/v2@v2.1.9/les/server_handler.go (about)

     1  // Copyright 2019 by the Authors
     2  // This file is part of the go-core library.
     3  //
     4  // The go-core library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-core library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-core library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package les
    18  
    19  import (
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"sync"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"github.com/core-coin/go-core/v2/crypto"
    28  	"github.com/core-coin/go-core/v2/xcbdb"
    29  
    30  	"github.com/core-coin/go-core/v2/common"
    31  	"github.com/core-coin/go-core/v2/common/mclock"
    32  	"github.com/core-coin/go-core/v2/core"
    33  	"github.com/core-coin/go-core/v2/core/forkid"
    34  	"github.com/core-coin/go-core/v2/core/rawdb"
    35  	"github.com/core-coin/go-core/v2/core/state"
    36  	"github.com/core-coin/go-core/v2/core/types"
    37  	lps "github.com/core-coin/go-core/v2/les/lespay/server"
    38  	"github.com/core-coin/go-core/v2/light"
    39  	"github.com/core-coin/go-core/v2/log"
    40  	"github.com/core-coin/go-core/v2/metrics"
    41  	"github.com/core-coin/go-core/v2/p2p"
    42  	"github.com/core-coin/go-core/v2/p2p/enode"
    43  	"github.com/core-coin/go-core/v2/p2p/nodestate"
    44  	"github.com/core-coin/go-core/v2/rlp"
    45  	"github.com/core-coin/go-core/v2/trie"
    46  )
    47  
    48  const (
    49  	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
    50  	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
    51  	xcbVersion        = 63              // equivalent xcb version for the downloader
    52  
    53  	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
    54  	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
    55  	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
    56  	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
    57  	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
    58  	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
    59  	MaxTxSend                = 64  // Amount of transactions to be send per request
    60  	MaxTxStatus              = 256 // Amount of transactions to queried per request
    61  )
    62  
    63  var (
    64  	errTooManyInvalidRequest = errors.New("too many invalid requests made")
    65  	errFullClientPool        = errors.New("client pool is full")
    66  )
    67  
    68  // serverHandler is responsible for serving light client and process
    69  // all incoming light requests.
    70  type serverHandler struct {
    71  	forkFilter forkid.Filter
    72  	blockchain *core.BlockChain
    73  	chainDb    xcbdb.Database
    74  	txpool     *core.TxPool
    75  	server     *LesServer
    76  
    77  	closeCh chan struct{}  // Channel used to exit all background routines of handler.
    78  	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
    79  	synced  func() bool    // Callback function used to determine whether local node is synced.
    80  
    81  	// Testing fields
    82  	addTxsSync bool
    83  }
    84  
    85  func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb xcbdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
    86  	handler := &serverHandler{
    87  		forkFilter: forkid.NewFilter(blockchain),
    88  		server:     server,
    89  		blockchain: blockchain,
    90  		chainDb:    chainDb,
    91  		txpool:     txpool,
    92  		closeCh:    make(chan struct{}),
    93  		synced:     synced,
    94  	}
    95  	return handler
    96  }
    97  
    98  // start starts the server handler.
    99  func (h *serverHandler) start() {
   100  	h.wg.Add(1)
   101  	go h.broadcastLoop()
   102  }
   103  
   104  // stop stops the server handler.
   105  func (h *serverHandler) stop() {
   106  	close(h.closeCh)
   107  	h.wg.Wait()
   108  }
   109  
   110  // runPeer is the p2p protocol run function for the given version.
   111  func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
   112  	peer := newClientPeer(int(version), h.server.config.NetworkId, p, newMeteredMsgWriter(rw, int(version)))
   113  	defer peer.close()
   114  	h.wg.Add(1)
   115  	defer h.wg.Done()
   116  	return h.handle(peer)
   117  }
   118  
   119  func (h *serverHandler) handle(p *clientPeer) error {
   120  	p.Log().Debug("Light Core peer connected", "name", p.Name())
   121  
   122  	// Execute the LES handshake
   123  	var (
   124  		head   = h.blockchain.CurrentHeader()
   125  		hash   = head.Hash()
   126  		number = head.Number.Uint64()
   127  		td     = h.blockchain.GetTd(hash, number)
   128  		forkID = forkid.NewID(h.blockchain.Config(), h.blockchain.Genesis().Hash(), h.blockchain.CurrentBlock().NumberU64())
   129  	)
   130  	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), forkID, h.forkFilter, h.server); err != nil {
   131  		p.Log().Debug("Light Core handshake failed", "err", err)
   132  		return err
   133  	}
   134  	// Reject the duplicated peer, otherwise register it to peerset.
   135  	var registered bool
   136  	if err := h.server.ns.Operation(func() {
   137  		if h.server.ns.GetField(p.Node(), clientPeerField) != nil {
   138  			registered = true
   139  		} else {
   140  			h.server.ns.SetFieldSub(p.Node(), clientPeerField, p)
   141  		}
   142  	}); err != nil {
   143  		return err
   144  	}
   145  	if registered {
   146  		return errAlreadyRegistered
   147  	}
   148  
   149  	defer func() {
   150  		h.server.ns.SetField(p.Node(), clientPeerField, nil)
   151  		if p.fcClient != nil { // is nil when connecting another server
   152  			p.fcClient.Disconnect()
   153  		}
   154  	}()
   155  	if p.server {
   156  		// connected to another server, no messages expected, just wait for disconnection
   157  		_, err := p.rw.ReadMsg()
   158  		return err
   159  	}
   160  	// Reject light clients if server is not synced.
   161  	//
   162  	// Put this checking here, so that "non-synced" les-server peers are still allowed
   163  	// to keep the connection.
   164  	if !h.synced() {
   165  		p.Log().Debug("Light server not synced, rejecting peer")
   166  		return p2p.DiscRequested
   167  	}
   168  	// Disconnect the inbound peer if it's rejected by clientPool
   169  	if cap, err := h.server.clientPool.connect(p); cap != p.fcParams.MinRecharge || err != nil {
   170  		p.Log().Debug("Light Core peer rejected", "err", errFullClientPool)
   171  		return errFullClientPool
   172  	}
   173  	p.balance, _ = h.server.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance)
   174  	if p.balance == nil {
   175  		return p2p.DiscRequested
   176  	}
   177  	activeCount, _ := h.server.clientPool.pp.Active()
   178  	clientConnectionGauge.Update(int64(activeCount))
   179  
   180  	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
   181  
   182  	connectedAt := mclock.Now()
   183  	defer func() {
   184  		wg.Wait() // Ensure all background task routines have exited.
   185  		h.server.clientPool.disconnect(p)
   186  		p.balance = nil
   187  		activeCount, _ := h.server.clientPool.pp.Active()
   188  		clientConnectionGauge.Update(int64(activeCount))
   189  		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
   190  	}()
   191  	// Mark the peer starts to be served.
   192  	atomic.StoreUint32(&p.serving, 1)
   193  	defer atomic.StoreUint32(&p.serving, 0)
   194  
   195  	// Spawn a main loop to handle all incoming messages.
   196  	for {
   197  		select {
   198  		case err := <-p.errCh:
   199  			p.Log().Debug("Failed to send light core response", "err", err)
   200  			return err
   201  		default:
   202  		}
   203  		if err := h.handleMsg(p, &wg); err != nil {
   204  			p.Log().Debug("Light Core message handling failed", "err", err)
   205  			return err
   206  		}
   207  	}
   208  }
   209  
   210  // handleMsg is invoked whenever an inbound message is received from a remote
   211  // peer. The remote connection is torn down upon returning any error.
   212  func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
   213  	// Read the next message from the remote peer, and ensure it's fully consumed
   214  	msg, err := p.rw.ReadMsg()
   215  	if err != nil {
   216  		return err
   217  	}
   218  	p.Log().Trace("Light Core message arrived", "code", msg.Code, "bytes", msg.Size)
   219  
   220  	// Discard large message which exceeds the limitation.
   221  	if msg.Size > ProtocolMaxMsgSize {
   222  		clientErrorMeter.Mark(1)
   223  		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
   224  	}
   225  	defer msg.Discard()
   226  
   227  	var (
   228  		maxCost uint64
   229  		task    *servingTask
   230  	)
   231  	p.responseCount++
   232  	responseCount := p.responseCount
   233  	// accept returns an indicator whether the request can be served.
   234  	// If so, deduct the max cost from the flow control buffer.
   235  	accept := func(reqID, reqCnt, maxCnt uint64) bool {
   236  		// Short circuit if the peer is already frozen or the request is invalid.
   237  		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
   238  		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
   239  			p.fcClient.OneTimeCost(inSizeCost)
   240  			return false
   241  		}
   242  		// Prepaid max cost units before request been serving.
   243  		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
   244  		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
   245  		if !accepted {
   246  			p.freeze()
   247  			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
   248  			p.fcClient.OneTimeCost(inSizeCost)
   249  			return false
   250  		}
   251  		// Create a multi-stage task, estimate the time it takes for the task to
   252  		// execute, and cache it in the request service queue.
   253  		factor := h.server.costTracker.globalFactor()
   254  		if factor < 0.001 {
   255  			factor = 1
   256  			p.Log().Error("Invalid global cost factor", "factor", factor)
   257  		}
   258  		maxTime := uint64(float64(maxCost) / factor)
   259  		task = h.server.servingQueue.newTask(p, maxTime, priority)
   260  		if task.start() {
   261  			return true
   262  		}
   263  		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
   264  		return false
   265  	}
   266  	// sendResponse sends back the response and updates the flow control statistic.
   267  	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
   268  		p.responseLock.Lock()
   269  		defer p.responseLock.Unlock()
   270  
   271  		// Short circuit if the client is already frozen.
   272  		if p.isFrozen() {
   273  			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
   274  			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   275  			return
   276  		}
   277  		// Positive correction buffer value with real cost.
   278  		var replySize uint32
   279  		if reply != nil {
   280  			replySize = reply.size()
   281  		}
   282  		var realCost uint64
   283  		if h.server.costTracker.testing {
   284  			realCost = maxCost // Assign a fake cost for testing purpose
   285  		} else {
   286  			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
   287  			if realCost > maxCost {
   288  				realCost = maxCost
   289  			}
   290  		}
   291  		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   292  		if amount != 0 {
   293  			// Feed cost tracker request serving statistic.
   294  			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
   295  			// Reduce priority "balance" for the specific peer.
   296  			p.balance.RequestServed(realCost)
   297  		}
   298  		if reply != nil {
   299  			p.queueSend(func() {
   300  				if err := reply.send(bv); err != nil {
   301  					select {
   302  					case p.errCh <- err:
   303  					default:
   304  					}
   305  				}
   306  			})
   307  		}
   308  	}
   309  	switch msg.Code {
   310  	case GetBlockHeadersMsg:
   311  		p.Log().Trace("Received block header request")
   312  		if metrics.EnabledExpensive {
   313  			miscInHeaderPacketsMeter.Mark(1)
   314  			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
   315  		}
   316  		var req struct {
   317  			ReqID uint64
   318  			Query getBlockHeadersData
   319  		}
   320  		if err := msg.Decode(&req); err != nil {
   321  			clientErrorMeter.Mark(1)
   322  			return errResp(ErrDecode, "%v: %v", msg, err)
   323  		}
   324  		query := req.Query
   325  		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
   326  			wg.Add(1)
   327  			go func() {
   328  				defer wg.Done()
   329  				hashMode := query.Origin.Hash != (common.Hash{})
   330  				first := true
   331  				maxNonCanonical := uint64(100)
   332  
   333  				// Gather headers until the fetch or network limits is reached
   334  				var (
   335  					bytes   common.StorageSize
   336  					headers []*types.Header
   337  					unknown bool
   338  				)
   339  				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
   340  					if !first && !task.waitOrStop() {
   341  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   342  						return
   343  					}
   344  					// Retrieve the next header satisfying the query
   345  					var origin *types.Header
   346  					if hashMode {
   347  						if first {
   348  							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
   349  							if origin != nil {
   350  								query.Origin.Number = origin.Number.Uint64()
   351  							}
   352  						} else {
   353  							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
   354  						}
   355  					} else {
   356  						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
   357  					}
   358  					if origin == nil {
   359  						break
   360  					}
   361  					headers = append(headers, origin)
   362  					bytes += estHeaderRlpSize
   363  
   364  					// Advance to the next header of the query
   365  					switch {
   366  					case hashMode && query.Reverse:
   367  						// Hash based traversal towards the genesis block
   368  						ancestor := query.Skip + 1
   369  						if ancestor == 0 {
   370  							unknown = true
   371  						} else {
   372  							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
   373  							unknown = query.Origin.Hash == common.Hash{}
   374  						}
   375  					case hashMode && !query.Reverse:
   376  						// Hash based traversal towards the leaf block
   377  						var (
   378  							current = origin.Number.Uint64()
   379  							next    = current + query.Skip + 1
   380  						)
   381  						if next <= current {
   382  							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
   383  							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
   384  							unknown = true
   385  						} else {
   386  							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
   387  								nextHash := header.Hash()
   388  								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
   389  								if expOldHash == query.Origin.Hash {
   390  									query.Origin.Hash, query.Origin.Number = nextHash, next
   391  								} else {
   392  									unknown = true
   393  								}
   394  							} else {
   395  								unknown = true
   396  							}
   397  						}
   398  					case query.Reverse:
   399  						// Number based traversal towards the genesis block
   400  						if query.Origin.Number >= query.Skip+1 {
   401  							query.Origin.Number -= query.Skip + 1
   402  						} else {
   403  							unknown = true
   404  						}
   405  
   406  					case !query.Reverse:
   407  						// Number based traversal towards the leaf block
   408  						query.Origin.Number += query.Skip + 1
   409  					}
   410  					first = false
   411  				}
   412  				reply := p.replyBlockHeaders(req.ReqID, headers)
   413  				sendResponse(req.ReqID, query.Amount, reply, task.done())
   414  				if metrics.EnabledExpensive {
   415  					miscOutHeaderPacketsMeter.Mark(1)
   416  					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
   417  					miscServingTimeHeaderTimer.Update(time.Duration(task.servingTime))
   418  				}
   419  			}()
   420  		}
   421  
   422  	case GetBlockBodiesMsg:
   423  		p.Log().Trace("Received block bodies request")
   424  		if metrics.EnabledExpensive {
   425  			miscInBodyPacketsMeter.Mark(1)
   426  			miscInBodyTrafficMeter.Mark(int64(msg.Size))
   427  		}
   428  		var req struct {
   429  			ReqID  uint64
   430  			Hashes []common.Hash
   431  		}
   432  		if err := msg.Decode(&req); err != nil {
   433  			clientErrorMeter.Mark(1)
   434  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   435  		}
   436  		var (
   437  			bytes  int
   438  			bodies []rlp.RawValue
   439  		)
   440  		reqCnt := len(req.Hashes)
   441  		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
   442  			wg.Add(1)
   443  			go func() {
   444  				defer wg.Done()
   445  				for i, hash := range req.Hashes {
   446  					if i != 0 && !task.waitOrStop() {
   447  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   448  						return
   449  					}
   450  					if bytes >= softResponseLimit {
   451  						break
   452  					}
   453  					body := h.blockchain.GetBodyRLP(hash)
   454  					if body == nil {
   455  						p.bumpInvalid()
   456  						continue
   457  					}
   458  					bodies = append(bodies, body)
   459  					bytes += len(body)
   460  				}
   461  				reply := p.replyBlockBodiesRLP(req.ReqID, bodies)
   462  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   463  				if metrics.EnabledExpensive {
   464  					miscOutBodyPacketsMeter.Mark(1)
   465  					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
   466  					miscServingTimeBodyTimer.Update(time.Duration(task.servingTime))
   467  				}
   468  			}()
   469  		}
   470  
   471  	case GetCodeMsg:
   472  		p.Log().Trace("Received code request")
   473  		if metrics.EnabledExpensive {
   474  			miscInCodePacketsMeter.Mark(1)
   475  			miscInCodeTrafficMeter.Mark(int64(msg.Size))
   476  		}
   477  		var req struct {
   478  			ReqID uint64
   479  			Reqs  []CodeReq
   480  		}
   481  		if err := msg.Decode(&req); err != nil {
   482  			clientErrorMeter.Mark(1)
   483  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   484  		}
   485  		var (
   486  			bytes int
   487  			data  [][]byte
   488  		)
   489  		reqCnt := len(req.Reqs)
   490  		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
   491  			wg.Add(1)
   492  			go func() {
   493  				defer wg.Done()
   494  				for i, request := range req.Reqs {
   495  					if i != 0 && !task.waitOrStop() {
   496  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   497  						return
   498  					}
   499  					// Look up the root hash belonging to the request
   500  					header := h.blockchain.GetHeaderByHash(request.BHash)
   501  					if header == nil {
   502  						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
   503  						p.bumpInvalid()
   504  						continue
   505  					}
   506  					// Refuse to search stale state data in the database since looking for
   507  					// a non-exist key is kind of expensive.
   508  					local := h.blockchain.CurrentHeader().Number.Uint64()
   509  					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   510  						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
   511  						p.bumpInvalid()
   512  						continue
   513  					}
   514  					triedb := h.blockchain.StateCache().TrieDB()
   515  
   516  					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
   517  					if err != nil {
   518  						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   519  						p.bumpInvalid()
   520  						continue
   521  					}
   522  					code, err := h.blockchain.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash))
   523  					if err != nil {
   524  						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
   525  						continue
   526  					}
   527  					// Accumulate the code and abort if enough data was retrieved
   528  					data = append(data, code)
   529  					if bytes += len(code); bytes >= softResponseLimit {
   530  						break
   531  					}
   532  				}
   533  				reply := p.replyCode(req.ReqID, data)
   534  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   535  				if metrics.EnabledExpensive {
   536  					miscOutCodePacketsMeter.Mark(1)
   537  					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
   538  					miscServingTimeCodeTimer.Update(time.Duration(task.servingTime))
   539  				}
   540  			}()
   541  		}
   542  
   543  	case GetReceiptsMsg:
   544  		p.Log().Trace("Received receipts request")
   545  		if metrics.EnabledExpensive {
   546  			miscInReceiptPacketsMeter.Mark(1)
   547  			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
   548  		}
   549  		var req struct {
   550  			ReqID  uint64
   551  			Hashes []common.Hash
   552  		}
   553  		if err := msg.Decode(&req); err != nil {
   554  			clientErrorMeter.Mark(1)
   555  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   556  		}
   557  		var (
   558  			bytes    int
   559  			receipts []rlp.RawValue
   560  		)
   561  		reqCnt := len(req.Hashes)
   562  		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
   563  			wg.Add(1)
   564  			go func() {
   565  				defer wg.Done()
   566  				for i, hash := range req.Hashes {
   567  					if i != 0 && !task.waitOrStop() {
   568  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   569  						return
   570  					}
   571  					if bytes >= softResponseLimit {
   572  						break
   573  					}
   574  					// Retrieve the requested block's receipts, skipping if unknown to us
   575  					results := h.blockchain.GetReceiptsByHash(hash)
   576  					if results == nil {
   577  						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
   578  							p.bumpInvalid()
   579  							continue
   580  						}
   581  					}
   582  					// If known, encode and queue for response packet
   583  					if encoded, err := rlp.EncodeToBytes(results); err != nil {
   584  						log.Error("Failed to encode receipt", "err", err)
   585  					} else {
   586  						receipts = append(receipts, encoded)
   587  						bytes += len(encoded)
   588  					}
   589  				}
   590  				reply := p.replyReceiptsRLP(req.ReqID, receipts)
   591  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   592  				if metrics.EnabledExpensive {
   593  					miscOutReceiptPacketsMeter.Mark(1)
   594  					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
   595  					miscServingTimeReceiptTimer.Update(time.Duration(task.servingTime))
   596  				}
   597  			}()
   598  		}
   599  
   600  	case GetProofsV2Msg:
   601  		p.Log().Trace("Received les/2 proofs request")
   602  		if metrics.EnabledExpensive {
   603  			miscInTrieProofPacketsMeter.Mark(1)
   604  			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
   605  		}
   606  		var req struct {
   607  			ReqID uint64
   608  			Reqs  []ProofReq
   609  		}
   610  		if err := msg.Decode(&req); err != nil {
   611  			clientErrorMeter.Mark(1)
   612  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   613  		}
   614  		// Gather state data until the fetch or network limits is reached
   615  		var (
   616  			lastBHash common.Hash
   617  			root      common.Hash
   618  			header    *types.Header
   619  		)
   620  		reqCnt := len(req.Reqs)
   621  		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
   622  			wg.Add(1)
   623  			go func() {
   624  				defer wg.Done()
   625  				nodes := light.NewNodeSet()
   626  
   627  				for i, request := range req.Reqs {
   628  					if i != 0 && !task.waitOrStop() {
   629  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   630  						return
   631  					}
   632  					// Look up the root hash belonging to the request
   633  					if request.BHash != lastBHash {
   634  						root, lastBHash = common.Hash{}, request.BHash
   635  
   636  						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
   637  							p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash)
   638  							p.bumpInvalid()
   639  							continue
   640  						}
   641  						// Refuse to search stale state data in the database since looking for
   642  						// a non-exist key is kind of expensive.
   643  						local := h.blockchain.CurrentHeader().Number.Uint64()
   644  						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   645  							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
   646  							p.bumpInvalid()
   647  							continue
   648  						}
   649  						root = header.Root
   650  					}
   651  					// If a header lookup failed (non existent), ignore subsequent requests for the same header
   652  					if root == (common.Hash{}) {
   653  						p.bumpInvalid()
   654  						continue
   655  					}
   656  					// Open the account or storage trie for the request
   657  					statedb := h.blockchain.StateCache()
   658  
   659  					var trie state.Trie
   660  					switch len(request.AccKey) {
   661  					case 0:
   662  						// No account key specified, open an account trie
   663  						trie, err = statedb.OpenTrie(root)
   664  						if trie == nil || err != nil {
   665  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
   666  							continue
   667  						}
   668  					default:
   669  						// Account key specified, open a storage trie
   670  						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
   671  						if err != nil {
   672  							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   673  							p.bumpInvalid()
   674  							continue
   675  						}
   676  						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
   677  						if trie == nil || err != nil {
   678  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
   679  							continue
   680  						}
   681  					}
   682  					// Prove the user's request from the account or stroage trie
   683  					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
   684  						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
   685  						continue
   686  					}
   687  					if nodes.DataSize() >= softResponseLimit {
   688  						break
   689  					}
   690  				}
   691  				reply := p.replyProofsV2(req.ReqID, nodes.NodeList())
   692  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   693  				if metrics.EnabledExpensive {
   694  					miscOutTrieProofPacketsMeter.Mark(1)
   695  					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
   696  					miscServingTimeTrieProofTimer.Update(time.Duration(task.servingTime))
   697  				}
   698  			}()
   699  		}
   700  
   701  	case GetHelperTrieProofsMsg:
   702  		p.Log().Trace("Received helper trie proof request")
   703  		if metrics.EnabledExpensive {
   704  			miscInHelperTriePacketsMeter.Mark(1)
   705  			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
   706  		}
   707  		var req struct {
   708  			ReqID uint64
   709  			Reqs  []HelperTrieReq
   710  		}
   711  		if err := msg.Decode(&req); err != nil {
   712  			clientErrorMeter.Mark(1)
   713  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   714  		}
   715  		// Gather state data until the fetch or network limits is reached
   716  		var (
   717  			auxBytes int
   718  			auxData  [][]byte
   719  		)
   720  		reqCnt := len(req.Reqs)
   721  		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
   722  			wg.Add(1)
   723  			go func() {
   724  				defer wg.Done()
   725  				var (
   726  					lastIdx  uint64
   727  					lastType uint
   728  					root     common.Hash
   729  					auxTrie  *trie.Trie
   730  				)
   731  				nodes := light.NewNodeSet()
   732  				for i, request := range req.Reqs {
   733  					if i != 0 && !task.waitOrStop() {
   734  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   735  						return
   736  					}
   737  					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
   738  						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
   739  
   740  						var prefix string
   741  						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
   742  							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
   743  						}
   744  					}
   745  					if request.AuxReq == auxRoot {
   746  						var data []byte
   747  						if root != (common.Hash{}) {
   748  							data = root[:]
   749  						}
   750  						auxData = append(auxData, data)
   751  						auxBytes += len(data)
   752  					} else {
   753  						if auxTrie != nil {
   754  							auxTrie.Prove(request.Key, request.FromLevel, nodes)
   755  						}
   756  						if request.AuxReq != 0 {
   757  							data := h.getAuxiliaryHeaders(request)
   758  							auxData = append(auxData, data)
   759  							auxBytes += len(data)
   760  						}
   761  					}
   762  					if nodes.DataSize()+auxBytes >= softResponseLimit {
   763  						break
   764  					}
   765  				}
   766  				reply := p.replyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
   767  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   768  				if metrics.EnabledExpensive {
   769  					miscOutHelperTriePacketsMeter.Mark(1)
   770  					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
   771  					miscServingTimeHelperTrieTimer.Update(time.Duration(task.servingTime))
   772  				}
   773  			}()
   774  		}
   775  
   776  	case SendTxV2Msg:
   777  		p.Log().Trace("Received new transactions")
   778  		if metrics.EnabledExpensive {
   779  			miscInTxsPacketsMeter.Mark(1)
   780  			miscInTxsTrafficMeter.Mark(int64(msg.Size))
   781  		}
   782  		var req struct {
   783  			ReqID uint64
   784  			Txs   []*types.Transaction
   785  		}
   786  		if err := msg.Decode(&req); err != nil {
   787  			clientErrorMeter.Mark(1)
   788  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   789  		}
   790  		reqCnt := len(req.Txs)
   791  		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
   792  			wg.Add(1)
   793  			go func() {
   794  				defer wg.Done()
   795  				stats := make([]light.TxStatus, len(req.Txs))
   796  				for i, tx := range req.Txs {
   797  					if i != 0 && !task.waitOrStop() {
   798  						return
   799  					}
   800  					hash := tx.Hash()
   801  					stats[i] = h.txStatus(hash)
   802  					if stats[i].Status == core.TxStatusUnknown {
   803  						addFn := h.txpool.AddRemotes
   804  						// Add txs synchronously for testing purpose
   805  						if h.addTxsSync {
   806  							addFn = h.txpool.AddRemotesSync
   807  						}
   808  						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
   809  							stats[i].Error = errs[0].Error()
   810  							continue
   811  						}
   812  						stats[i] = h.txStatus(hash)
   813  					}
   814  				}
   815  				reply := p.replyTxStatus(req.ReqID, stats)
   816  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   817  				if metrics.EnabledExpensive {
   818  					miscOutTxsPacketsMeter.Mark(1)
   819  					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
   820  					miscServingTimeTxTimer.Update(time.Duration(task.servingTime))
   821  				}
   822  			}()
   823  		}
   824  
   825  	case GetTxStatusMsg:
   826  		p.Log().Trace("Received transaction status query request")
   827  		if metrics.EnabledExpensive {
   828  			miscInTxStatusPacketsMeter.Mark(1)
   829  			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
   830  		}
   831  		var req struct {
   832  			ReqID  uint64
   833  			Hashes []common.Hash
   834  		}
   835  		if err := msg.Decode(&req); err != nil {
   836  			clientErrorMeter.Mark(1)
   837  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   838  		}
   839  		reqCnt := len(req.Hashes)
   840  		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
   841  			wg.Add(1)
   842  			go func() {
   843  				defer wg.Done()
   844  				stats := make([]light.TxStatus, len(req.Hashes))
   845  				for i, hash := range req.Hashes {
   846  					if i != 0 && !task.waitOrStop() {
   847  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   848  						return
   849  					}
   850  					stats[i] = h.txStatus(hash)
   851  				}
   852  				reply := p.replyTxStatus(req.ReqID, stats)
   853  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   854  				if metrics.EnabledExpensive {
   855  					miscOutTxStatusPacketsMeter.Mark(1)
   856  					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
   857  					miscServingTimeTxStatusTimer.Update(time.Duration(task.servingTime))
   858  				}
   859  			}()
   860  		}
   861  
   862  	default:
   863  		p.Log().Trace("Received invalid message", "code", msg.Code)
   864  		clientErrorMeter.Mark(1)
   865  		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
   866  	}
   867  	// If the client has made too much invalid request(e.g. request a non-existent data),
   868  	// reject them to prevent SPAM attack.
   869  	if p.getInvalid() > maxRequestErrors {
   870  		clientErrorMeter.Mark(1)
   871  		return errTooManyInvalidRequest
   872  	}
   873  	return nil
   874  }
   875  
   876  // getAccount retrieves an account from the state based on root.
   877  func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
   878  	trie, err := trie.New(root, triedb)
   879  	if err != nil {
   880  		return state.Account{}, err
   881  	}
   882  	blob, err := trie.TryGet(hash[:])
   883  	if err != nil {
   884  		return state.Account{}, err
   885  	}
   886  	var account state.Account
   887  	if err = rlp.DecodeBytes(blob, &account); err != nil {
   888  		return state.Account{}, err
   889  	}
   890  	return account, nil
   891  }
   892  
   893  // getHelperTrie returns the post-processed trie root for the given trie ID and section index
   894  func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
   895  	switch typ {
   896  	case htCanonical:
   897  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
   898  		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
   899  	case htBloomBits:
   900  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
   901  		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
   902  	}
   903  	return common.Hash{}, ""
   904  }
   905  
   906  // getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
   907  func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
   908  	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
   909  		blockNum := binary.BigEndian.Uint64(req.Key)
   910  		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
   911  		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
   912  	}
   913  	return nil
   914  }
   915  
   916  // txStatus returns the status of a specified transaction.
   917  func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
   918  	var stat light.TxStatus
   919  	// Looking the transaction in txpool first.
   920  	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
   921  
   922  	// If the transaction is unknown to the pool, try looking it up locally.
   923  	if stat.Status == core.TxStatusUnknown {
   924  		lookup := h.blockchain.GetTransactionLookup(hash)
   925  		if lookup != nil {
   926  			stat.Status = core.TxStatusIncluded
   927  			stat.Lookup = lookup
   928  		}
   929  	}
   930  	return stat
   931  }
   932  
   933  // broadcastLoop broadcasts new block information to all connected light
   934  // clients. According to the agreement between client and server, server should
   935  // only broadcast new announcement if the total difficulty is higher than the
   936  // last one. Besides server will add the signature if client requires.
   937  func (h *serverHandler) broadcastLoop() {
   938  	defer h.wg.Done()
   939  
   940  	headCh := make(chan core.ChainHeadEvent, 10)
   941  	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
   942  	defer headSub.Unsubscribe()
   943  
   944  	var (
   945  		lastHead *types.Header
   946  		lastTd   = common.Big0
   947  	)
   948  	for {
   949  		select {
   950  		case ev := <-headCh:
   951  			header := ev.Block.Header()
   952  			hash, number := header.Hash(), header.Number.Uint64()
   953  			td := h.blockchain.GetTd(hash, number)
   954  			if td == nil || td.Cmp(lastTd) <= 0 {
   955  				continue
   956  			}
   957  			var reorg uint64
   958  			if lastHead != nil {
   959  				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
   960  			}
   961  			lastHead, lastTd = header, td
   962  			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
   963  			h.server.broadcaster.broadcast(announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg})
   964  		case <-h.closeCh:
   965  			return
   966  		}
   967  	}
   968  }
   969  
   970  // broadcaster sends new header announcements to active client peers
   971  type broadcaster struct {
   972  	ns                           *nodestate.NodeStateMachine
   973  	privateKey                   *crypto.PrivateKey
   974  	lastAnnounce, signedAnnounce announceData
   975  }
   976  
   977  // newBroadcaster creates a new broadcaster
   978  func newBroadcaster(ns *nodestate.NodeStateMachine) *broadcaster {
   979  	b := &broadcaster{ns: ns}
   980  	ns.SubscribeState(priorityPoolSetup.ActiveFlag, func(node *enode.Node, oldState, newState nodestate.Flags) {
   981  		if newState.Equals(priorityPoolSetup.ActiveFlag) {
   982  			// send last announcement to activated peers
   983  			b.sendTo(node)
   984  		}
   985  	})
   986  	return b
   987  }
   988  
   989  // setSignerKey sets the signer key for signed announcements. Should be called before
   990  // starting the protocol handler.
   991  func (b *broadcaster) setSignerKey(privateKey *crypto.PrivateKey) {
   992  	b.privateKey = privateKey
   993  }
   994  
   995  // broadcast sends the given announcements to all active peers
   996  func (b *broadcaster) broadcast(announce announceData) {
   997  	b.ns.Operation(func() {
   998  		// iterate in an Operation to ensure that the active set does not change while iterating
   999  		b.lastAnnounce = announce
  1000  		b.ns.ForEach(priorityPoolSetup.ActiveFlag, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
  1001  			b.sendTo(node)
  1002  		})
  1003  	})
  1004  }
  1005  
  1006  // sendTo sends the most recent announcement to the given node unless the same or higher Td
  1007  // announcement has already been sent.
  1008  func (b *broadcaster) sendTo(node *enode.Node) {
  1009  	if b.lastAnnounce.Td == nil {
  1010  		return
  1011  	}
  1012  	if p, _ := b.ns.GetField(node, clientPeerField).(*clientPeer); p != nil {
  1013  		if p.headInfo.Td == nil || b.lastAnnounce.Td.Cmp(p.headInfo.Td) > 0 {
  1014  			announce := b.lastAnnounce
  1015  			switch p.announceType {
  1016  			case announceTypeSimple:
  1017  				if !p.queueSend(func() { p.sendAnnounce(announce) }) {
  1018  					log.Debug("Drop announcement because queue is full", "number", announce.Number, "hash", announce.Hash)
  1019  				} else {
  1020  					log.Debug("Sent announcement", "number", announce.Number, "hash", announce.Hash)
  1021  				}
  1022  			case announceTypeSigned:
  1023  				if b.signedAnnounce.Hash != b.lastAnnounce.Hash {
  1024  					b.signedAnnounce = b.lastAnnounce
  1025  					b.signedAnnounce.sign(b.privateKey)
  1026  				}
  1027  				announce := b.signedAnnounce
  1028  				if !p.queueSend(func() { p.sendAnnounce(announce) }) {
  1029  					log.Debug("Drop announcement because queue is full", "number", announce.Number, "hash", announce.Hash)
  1030  				} else {
  1031  					log.Debug("Sent announcement", "number", announce.Number, "hash", announce.Hash)
  1032  				}
  1033  			}
  1034  			p.headInfo = blockInfo{b.lastAnnounce.Hash, b.lastAnnounce.Number, b.lastAnnounce.Td}
  1035  		}
  1036  	}
  1037  }