git.pirl.io/community/pirl@v0.0.0-20201111064343-9d3d31ff74be/les/server_handler.go (about)

     1  // Copyright 2019 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package les
    18  
    19  import (
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"sync"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"git.pirl.io/community/pirl/common"
    28  	"git.pirl.io/community/pirl/common/mclock"
    29  	"git.pirl.io/community/pirl/core"
    30  	"git.pirl.io/community/pirl/core/rawdb"
    31  	"git.pirl.io/community/pirl/core/state"
    32  	"git.pirl.io/community/pirl/core/types"
    33  	"git.pirl.io/community/pirl/ethdb"
    34  	"git.pirl.io/community/pirl/light"
    35  	"git.pirl.io/community/pirl/log"
    36  	"git.pirl.io/community/pirl/metrics"
    37  	"git.pirl.io/community/pirl/p2p"
    38  	"git.pirl.io/community/pirl/rlp"
    39  	"git.pirl.io/community/pirl/trie"
    40  )
    41  
const (
	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
	ethVersion        = 63              // equivalent eth version for the downloader

	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
	MaxTxSend                = 64  // Amount of transactions to be sent per request
	MaxTxStatus              = 256 // Amount of transactions to be queried per request
)
    56  
// Errors returned to misbehaving clients or when the server is out of capacity.
var (
	errTooManyInvalidRequest = errors.New("too many invalid requests made")
	errFullClientPool        = errors.New("client pool is full")
)
    61  
// serverHandler is responsible for serving light client and process
// all incoming light requests.
type serverHandler struct {
	blockchain *core.BlockChain // Canonical chain used to answer header/body/receipt/state queries.
	chainDb    ethdb.Database   // Backing database, used for helper-trie (CHT/bloom) lookups.
	txpool     *core.TxPool     // Transaction pool for relaying and status queries from light clients.
	server     *LesServer       // Owning LES server (client pool, flow control, serving queue).

	closeCh chan struct{}  // Channel used to exit all background routines of handler.
	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
	synced  func() bool    // Callback function used to determine whether local node is synced.

	// Testing fields
	addTxsSync bool // If set, transactions are added to the pool synchronously (tests only).
}
    77  
    78  func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
    79  	handler := &serverHandler{
    80  		server:     server,
    81  		blockchain: blockchain,
    82  		chainDb:    chainDb,
    83  		txpool:     txpool,
    84  		closeCh:    make(chan struct{}),
    85  		synced:     synced,
    86  	}
    87  	return handler
    88  }
    89  
// start launches the handler's background routines (header broadcasting).
// The broadcast routine is tracked by h.wg so stop can wait for it.
func (h *serverHandler) start() {
	h.wg.Add(1)
	go h.broadcastHeaders()
}
    95  
// stop signals all background routines to exit via closeCh and blocks until
// every routine tracked by the wait group has terminated.
func (h *serverHandler) stop() {
	close(h.closeCh)
	h.wg.Wait()
}
   101  
   102  // runPeer is the p2p protocol run function for the given version.
   103  func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
   104  	peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version)))
   105  	h.wg.Add(1)
   106  	defer h.wg.Done()
   107  	return h.handle(peer)
   108  }
   109  
// handle manages the life cycle of a single LES peer: handshake, admission
// (client pool + peer set registration), then the inbound message loop. When
// it returns, the peer is torn down in reverse order of admission via defers.
func (h *serverHandler) handle(p *peer) error {
	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())

	// Execute the LES handshake
	var (
		head   = h.blockchain.CurrentHeader()
		hash   = head.Hash()
		number = head.Number.Uint64()
		td     = h.blockchain.GetTd(hash, number)
	)
	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil {
		p.Log().Debug("Light Ethereum handshake failed", "err", err)
		return err
	}
	if p.server {
		// connected to another server, no messages expected, just wait for disconnection
		_, err := p.rw.ReadMsg()
		return err
	}
	// Reject light clients if server is not synced.
	if !h.synced() {
		p.Log().Debug("Light server not synced, rejecting peer")
		return p2p.DiscRequested
	}
	defer p.fcClient.Disconnect()

	// Disconnect the inbound peer if it's rejected by clientPool
	if !h.server.clientPool.connect(p, 0) {
		p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
		return errFullClientPool
	}
	// Register the peer locally; roll back the clientPool admission on failure.
	if err := h.server.peers.Register(p); err != nil {
		h.server.clientPool.disconnect(p)
		p.Log().Error("Light Ethereum peer registration failed", "err", err)
		return err
	}
	clientConnectionGauge.Update(int64(h.server.peers.Len()))

	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.

	connectedAt := mclock.Now()
	defer func() {
		wg.Wait() // Ensure all background task routines have exited.
		h.server.peers.Unregister(p.id)
		h.server.clientPool.disconnect(p)
		clientConnectionGauge.Update(int64(h.server.peers.Len()))
		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
	}()

	// Main loop: process incoming messages until an error occurs. Asynchronous
	// send failures queued by response routines are surfaced via p.errCh.
	for {
		select {
		case err := <-p.errCh:
			p.Log().Debug("Failed to send light ethereum response", "err", err)
			return err
		default:
		}
		if err := h.handleMsg(p, &wg); err != nil {
			p.Log().Debug("Light Ethereum message handling failed", "err", err)
			return err
		}
	}
}
   174  
   175  // handleMsg is invoked whenever an inbound message is received from a remote
   176  // peer. The remote connection is torn down upon returning any error.
   177  func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error {
   178  	// Read the next message from the remote peer, and ensure it's fully consumed
   179  	msg, err := p.rw.ReadMsg()
   180  	if err != nil {
   181  		return err
   182  	}
   183  	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
   184  
   185  	// Discard large message which exceeds the limitation.
   186  	if msg.Size > ProtocolMaxMsgSize {
   187  		clientErrorMeter.Mark(1)
   188  		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
   189  	}
   190  	defer msg.Discard()
   191  
   192  	var (
   193  		maxCost uint64
   194  		task    *servingTask
   195  	)
   196  	p.responseCount++
   197  	responseCount := p.responseCount
   198  	// accept returns an indicator whether the request can be served.
   199  	// If so, deduct the max cost from the flow control buffer.
   200  	accept := func(reqID, reqCnt, maxCnt uint64) bool {
   201  		// Short circuit if the peer is already frozen or the request is invalid.
   202  		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
   203  		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
   204  			p.fcClient.OneTimeCost(inSizeCost)
   205  			return false
   206  		}
   207  		// Prepaid max cost units before request been serving.
   208  		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
   209  		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
   210  		if !accepted {
   211  			p.freezeClient()
   212  			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
   213  			p.fcClient.OneTimeCost(inSizeCost)
   214  			return false
   215  		}
   216  		// Create a multi-stage task, estimate the time it takes for the task to
   217  		// execute, and cache it in the request service queue.
   218  		factor := h.server.costTracker.globalFactor()
   219  		if factor < 0.001 {
   220  			factor = 1
   221  			p.Log().Error("Invalid global cost factor", "factor", factor)
   222  		}
   223  		maxTime := uint64(float64(maxCost) / factor)
   224  		task = h.server.servingQueue.newTask(p, maxTime, priority)
   225  		if task.start() {
   226  			return true
   227  		}
   228  		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
   229  		return false
   230  	}
   231  	// sendResponse sends back the response and updates the flow control statistic.
   232  	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
   233  		p.responseLock.Lock()
   234  		defer p.responseLock.Unlock()
   235  
   236  		// Short circuit if the client is already frozen.
   237  		if p.isFrozen() {
   238  			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
   239  			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   240  			return
   241  		}
   242  		// Positive correction buffer value with real cost.
   243  		var replySize uint32
   244  		if reply != nil {
   245  			replySize = reply.size()
   246  		}
   247  		var realCost uint64
   248  		if h.server.costTracker.testing {
   249  			realCost = maxCost // Assign a fake cost for testing purpose
   250  		} else {
   251  			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
   252  		}
   253  		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
   254  		if amount != 0 {
   255  			// Feed cost tracker request serving statistic.
   256  			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
   257  			// Reduce priority "balance" for the specific peer.
   258  			h.server.clientPool.requestCost(p, realCost)
   259  		}
   260  		if reply != nil {
   261  			p.queueSend(func() {
   262  				if err := reply.send(bv); err != nil {
   263  					select {
   264  					case p.errCh <- err:
   265  					default:
   266  					}
   267  				}
   268  			})
   269  		}
   270  	}
   271  	switch msg.Code {
   272  	case GetBlockHeadersMsg:
   273  		p.Log().Trace("Received block header request")
   274  		if metrics.EnabledExpensive {
   275  			miscInHeaderPacketsMeter.Mark(1)
   276  			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
   277  		}
   278  		var req struct {
   279  			ReqID uint64
   280  			Query getBlockHeadersData
   281  		}
   282  		if err := msg.Decode(&req); err != nil {
   283  			clientErrorMeter.Mark(1)
   284  			return errResp(ErrDecode, "%v: %v", msg, err)
   285  		}
   286  		query := req.Query
   287  		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
   288  			wg.Add(1)
   289  			go func() {
   290  				defer wg.Done()
   291  				hashMode := query.Origin.Hash != (common.Hash{})
   292  				first := true
   293  				maxNonCanonical := uint64(100)
   294  
   295  				// Gather headers until the fetch or network limits is reached
   296  				var (
   297  					bytes   common.StorageSize
   298  					headers []*types.Header
   299  					unknown bool
   300  				)
   301  				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
   302  					if !first && !task.waitOrStop() {
   303  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   304  						return
   305  					}
   306  					// Retrieve the next header satisfying the query
   307  					var origin *types.Header
   308  					if hashMode {
   309  						if first {
   310  							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
   311  							if origin != nil {
   312  								query.Origin.Number = origin.Number.Uint64()
   313  							}
   314  						} else {
   315  							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
   316  						}
   317  					} else {
   318  						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
   319  					}
   320  					if origin == nil {
   321  						atomic.AddUint32(&p.invalidCount, 1)
   322  						break
   323  					}
   324  					headers = append(headers, origin)
   325  					bytes += estHeaderRlpSize
   326  
   327  					// Advance to the next header of the query
   328  					switch {
   329  					case hashMode && query.Reverse:
   330  						// Hash based traversal towards the genesis block
   331  						ancestor := query.Skip + 1
   332  						if ancestor == 0 {
   333  							unknown = true
   334  						} else {
   335  							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
   336  							unknown = query.Origin.Hash == common.Hash{}
   337  						}
   338  					case hashMode && !query.Reverse:
   339  						// Hash based traversal towards the leaf block
   340  						var (
   341  							current = origin.Number.Uint64()
   342  							next    = current + query.Skip + 1
   343  						)
   344  						if next <= current {
   345  							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
   346  							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
   347  							unknown = true
   348  						} else {
   349  							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
   350  								nextHash := header.Hash()
   351  								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
   352  								if expOldHash == query.Origin.Hash {
   353  									query.Origin.Hash, query.Origin.Number = nextHash, next
   354  								} else {
   355  									unknown = true
   356  								}
   357  							} else {
   358  								unknown = true
   359  							}
   360  						}
   361  					case query.Reverse:
   362  						// Number based traversal towards the genesis block
   363  						if query.Origin.Number >= query.Skip+1 {
   364  							query.Origin.Number -= query.Skip + 1
   365  						} else {
   366  							unknown = true
   367  						}
   368  
   369  					case !query.Reverse:
   370  						// Number based traversal towards the leaf block
   371  						query.Origin.Number += query.Skip + 1
   372  					}
   373  					first = false
   374  				}
   375  				reply := p.ReplyBlockHeaders(req.ReqID, headers)
   376  				sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
   377  				if metrics.EnabledExpensive {
   378  					miscOutHeaderPacketsMeter.Mark(1)
   379  					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
   380  					miscServingTimeHeaderTimer.Update(time.Duration(task.servingTime))
   381  				}
   382  			}()
   383  		}
   384  
   385  	case GetBlockBodiesMsg:
   386  		p.Log().Trace("Received block bodies request")
   387  		if metrics.EnabledExpensive {
   388  			miscInBodyPacketsMeter.Mark(1)
   389  			miscInBodyTrafficMeter.Mark(int64(msg.Size))
   390  		}
   391  		var req struct {
   392  			ReqID  uint64
   393  			Hashes []common.Hash
   394  		}
   395  		if err := msg.Decode(&req); err != nil {
   396  			clientErrorMeter.Mark(1)
   397  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   398  		}
   399  		var (
   400  			bytes  int
   401  			bodies []rlp.RawValue
   402  		)
   403  		reqCnt := len(req.Hashes)
   404  		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
   405  			wg.Add(1)
   406  			go func() {
   407  				defer wg.Done()
   408  				for i, hash := range req.Hashes {
   409  					if i != 0 && !task.waitOrStop() {
   410  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   411  						return
   412  					}
   413  					if bytes >= softResponseLimit {
   414  						break
   415  					}
   416  					body := h.blockchain.GetBodyRLP(hash)
   417  					if body == nil {
   418  						atomic.AddUint32(&p.invalidCount, 1)
   419  						continue
   420  					}
   421  					bodies = append(bodies, body)
   422  					bytes += len(body)
   423  				}
   424  				reply := p.ReplyBlockBodiesRLP(req.ReqID, bodies)
   425  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   426  				if metrics.EnabledExpensive {
   427  					miscOutBodyPacketsMeter.Mark(1)
   428  					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
   429  					miscServingTimeBodyTimer.Update(time.Duration(task.servingTime))
   430  				}
   431  			}()
   432  		}
   433  
   434  	case GetCodeMsg:
   435  		p.Log().Trace("Received code request")
   436  		if metrics.EnabledExpensive {
   437  			miscInCodePacketsMeter.Mark(1)
   438  			miscInCodeTrafficMeter.Mark(int64(msg.Size))
   439  		}
   440  		var req struct {
   441  			ReqID uint64
   442  			Reqs  []CodeReq
   443  		}
   444  		if err := msg.Decode(&req); err != nil {
   445  			clientErrorMeter.Mark(1)
   446  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   447  		}
   448  		var (
   449  			bytes int
   450  			data  [][]byte
   451  		)
   452  		reqCnt := len(req.Reqs)
   453  		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
   454  			wg.Add(1)
   455  			go func() {
   456  				defer wg.Done()
   457  				for i, request := range req.Reqs {
   458  					if i != 0 && !task.waitOrStop() {
   459  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   460  						return
   461  					}
   462  					// Look up the root hash belonging to the request
   463  					header := h.blockchain.GetHeaderByHash(request.BHash)
   464  					if header == nil {
   465  						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
   466  						atomic.AddUint32(&p.invalidCount, 1)
   467  						continue
   468  					}
   469  					// Refuse to search stale state data in the database since looking for
   470  					// a non-exist key is kind of expensive.
   471  					local := h.blockchain.CurrentHeader().Number.Uint64()
   472  					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   473  						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
   474  						atomic.AddUint32(&p.invalidCount, 1)
   475  						continue
   476  					}
   477  					triedb := h.blockchain.StateCache().TrieDB()
   478  
   479  					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
   480  					if err != nil {
   481  						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   482  						atomic.AddUint32(&p.invalidCount, 1)
   483  						continue
   484  					}
   485  					code, err := triedb.Node(common.BytesToHash(account.CodeHash))
   486  					if err != nil {
   487  						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
   488  						continue
   489  					}
   490  					// Accumulate the code and abort if enough data was retrieved
   491  					data = append(data, code)
   492  					if bytes += len(code); bytes >= softResponseLimit {
   493  						break
   494  					}
   495  				}
   496  				reply := p.ReplyCode(req.ReqID, data)
   497  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   498  				if metrics.EnabledExpensive {
   499  					miscOutCodePacketsMeter.Mark(1)
   500  					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
   501  					miscServingTimeCodeTimer.Update(time.Duration(task.servingTime))
   502  				}
   503  			}()
   504  		}
   505  
   506  	case GetReceiptsMsg:
   507  		p.Log().Trace("Received receipts request")
   508  		if metrics.EnabledExpensive {
   509  			miscInReceiptPacketsMeter.Mark(1)
   510  			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
   511  		}
   512  		var req struct {
   513  			ReqID  uint64
   514  			Hashes []common.Hash
   515  		}
   516  		if err := msg.Decode(&req); err != nil {
   517  			clientErrorMeter.Mark(1)
   518  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   519  		}
   520  		var (
   521  			bytes    int
   522  			receipts []rlp.RawValue
   523  		)
   524  		reqCnt := len(req.Hashes)
   525  		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
   526  			wg.Add(1)
   527  			go func() {
   528  				defer wg.Done()
   529  				for i, hash := range req.Hashes {
   530  					if i != 0 && !task.waitOrStop() {
   531  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   532  						return
   533  					}
   534  					if bytes >= softResponseLimit {
   535  						break
   536  					}
   537  					// Retrieve the requested block's receipts, skipping if unknown to us
   538  					results := h.blockchain.GetReceiptsByHash(hash)
   539  					if results == nil {
   540  						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
   541  							atomic.AddUint32(&p.invalidCount, 1)
   542  							continue
   543  						}
   544  					}
   545  					// If known, encode and queue for response packet
   546  					if encoded, err := rlp.EncodeToBytes(results); err != nil {
   547  						log.Error("Failed to encode receipt", "err", err)
   548  					} else {
   549  						receipts = append(receipts, encoded)
   550  						bytes += len(encoded)
   551  					}
   552  				}
   553  				reply := p.ReplyReceiptsRLP(req.ReqID, receipts)
   554  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   555  				if metrics.EnabledExpensive {
   556  					miscOutReceiptPacketsMeter.Mark(1)
   557  					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
   558  					miscServingTimeReceiptTimer.Update(time.Duration(task.servingTime))
   559  				}
   560  			}()
   561  		}
   562  
   563  	case GetProofsV2Msg:
   564  		p.Log().Trace("Received les/2 proofs request")
   565  		if metrics.EnabledExpensive {
   566  			miscInTrieProofPacketsMeter.Mark(1)
   567  			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
   568  		}
   569  		var req struct {
   570  			ReqID uint64
   571  			Reqs  []ProofReq
   572  		}
   573  		if err := msg.Decode(&req); err != nil {
   574  			clientErrorMeter.Mark(1)
   575  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   576  		}
   577  		// Gather state data until the fetch or network limits is reached
   578  		var (
   579  			lastBHash common.Hash
   580  			root      common.Hash
   581  		)
   582  		reqCnt := len(req.Reqs)
   583  		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
   584  			wg.Add(1)
   585  			go func() {
   586  				defer wg.Done()
   587  				nodes := light.NewNodeSet()
   588  
   589  				for i, request := range req.Reqs {
   590  					if i != 0 && !task.waitOrStop() {
   591  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   592  						return
   593  					}
   594  					// Look up the root hash belonging to the request
   595  					var (
   596  						header *types.Header
   597  						trie   state.Trie
   598  					)
   599  					if request.BHash != lastBHash {
   600  						root, lastBHash = common.Hash{}, request.BHash
   601  
   602  						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
   603  							p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash)
   604  							atomic.AddUint32(&p.invalidCount, 1)
   605  							continue
   606  						}
   607  						// Refuse to search stale state data in the database since looking for
   608  						// a non-exist key is kind of expensive.
   609  						local := h.blockchain.CurrentHeader().Number.Uint64()
   610  						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
   611  							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
   612  							atomic.AddUint32(&p.invalidCount, 1)
   613  							continue
   614  						}
   615  						root = header.Root
   616  					}
   617  					// If a header lookup failed (non existent), ignore subsequent requests for the same header
   618  					if root == (common.Hash{}) {
   619  						atomic.AddUint32(&p.invalidCount, 1)
   620  						continue
   621  					}
   622  					// Open the account or storage trie for the request
   623  					statedb := h.blockchain.StateCache()
   624  
   625  					switch len(request.AccKey) {
   626  					case 0:
   627  						// No account key specified, open an account trie
   628  						trie, err = statedb.OpenTrie(root)
   629  						if trie == nil || err != nil {
   630  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
   631  							continue
   632  						}
   633  					default:
   634  						// Account key specified, open a storage trie
   635  						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
   636  						if err != nil {
   637  							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
   638  							atomic.AddUint32(&p.invalidCount, 1)
   639  							continue
   640  						}
   641  						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
   642  						if trie == nil || err != nil {
   643  							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
   644  							continue
   645  						}
   646  					}
   647  					// Prove the user's request from the account or stroage trie
   648  					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
   649  						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
   650  						continue
   651  					}
   652  					if nodes.DataSize() >= softResponseLimit {
   653  						break
   654  					}
   655  				}
   656  				reply := p.ReplyProofsV2(req.ReqID, nodes.NodeList())
   657  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   658  				if metrics.EnabledExpensive {
   659  					miscOutTrieProofPacketsMeter.Mark(1)
   660  					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
   661  					miscServingTimeTrieProofTimer.Update(time.Duration(task.servingTime))
   662  				}
   663  			}()
   664  		}
   665  
   666  	case GetHelperTrieProofsMsg:
   667  		p.Log().Trace("Received helper trie proof request")
   668  		if metrics.EnabledExpensive {
   669  			miscInHelperTriePacketsMeter.Mark(1)
   670  			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
   671  		}
   672  		var req struct {
   673  			ReqID uint64
   674  			Reqs  []HelperTrieReq
   675  		}
   676  		if err := msg.Decode(&req); err != nil {
   677  			clientErrorMeter.Mark(1)
   678  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   679  		}
   680  		// Gather state data until the fetch or network limits is reached
   681  		var (
   682  			auxBytes int
   683  			auxData  [][]byte
   684  		)
   685  		reqCnt := len(req.Reqs)
   686  		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
   687  			wg.Add(1)
   688  			go func() {
   689  				defer wg.Done()
   690  				var (
   691  					lastIdx  uint64
   692  					lastType uint
   693  					root     common.Hash
   694  					auxTrie  *trie.Trie
   695  				)
   696  				nodes := light.NewNodeSet()
   697  				for i, request := range req.Reqs {
   698  					if i != 0 && !task.waitOrStop() {
   699  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   700  						return
   701  					}
   702  					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
   703  						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
   704  
   705  						var prefix string
   706  						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
   707  							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
   708  						}
   709  					}
   710  					if request.AuxReq == auxRoot {
   711  						var data []byte
   712  						if root != (common.Hash{}) {
   713  							data = root[:]
   714  						}
   715  						auxData = append(auxData, data)
   716  						auxBytes += len(data)
   717  					} else {
   718  						if auxTrie != nil {
   719  							auxTrie.Prove(request.Key, request.FromLevel, nodes)
   720  						}
   721  						if request.AuxReq != 0 {
   722  							data := h.getAuxiliaryHeaders(request)
   723  							auxData = append(auxData, data)
   724  							auxBytes += len(data)
   725  						}
   726  					}
   727  					if nodes.DataSize()+auxBytes >= softResponseLimit {
   728  						break
   729  					}
   730  				}
   731  				reply := p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
   732  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   733  				if metrics.EnabledExpensive {
   734  					miscOutHelperTriePacketsMeter.Mark(1)
   735  					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
   736  					miscServingTimeHelperTrieTimer.Update(time.Duration(task.servingTime))
   737  				}
   738  			}()
   739  		}
   740  
   741  	case SendTxV2Msg:
   742  		p.Log().Trace("Received new transactions")
   743  		if metrics.EnabledExpensive {
   744  			miscInTxsPacketsMeter.Mark(1)
   745  			miscInTxsTrafficMeter.Mark(int64(msg.Size))
   746  		}
   747  		var req struct {
   748  			ReqID uint64
   749  			Txs   []*types.Transaction
   750  		}
   751  		if err := msg.Decode(&req); err != nil {
   752  			clientErrorMeter.Mark(1)
   753  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   754  		}
   755  		reqCnt := len(req.Txs)
   756  		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
   757  			wg.Add(1)
   758  			go func() {
   759  				defer wg.Done()
   760  				stats := make([]light.TxStatus, len(req.Txs))
   761  				for i, tx := range req.Txs {
   762  					if i != 0 && !task.waitOrStop() {
   763  						return
   764  					}
   765  					hash := tx.Hash()
   766  					stats[i] = h.txStatus(hash)
   767  					if stats[i].Status == core.TxStatusUnknown {
   768  						addFn := h.txpool.AddRemotes
   769  						// Add txs synchronously for testing purpose
   770  						if h.addTxsSync {
   771  							addFn = h.txpool.AddRemotesSync
   772  						}
   773  						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
   774  							stats[i].Error = errs[0].Error()
   775  							continue
   776  						}
   777  						stats[i] = h.txStatus(hash)
   778  					}
   779  				}
   780  				reply := p.ReplyTxStatus(req.ReqID, stats)
   781  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   782  				if metrics.EnabledExpensive {
   783  					miscOutTxsPacketsMeter.Mark(1)
   784  					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
   785  					miscServingTimeTxTimer.Update(time.Duration(task.servingTime))
   786  				}
   787  			}()
   788  		}
   789  
   790  	case GetTxStatusMsg:
   791  		p.Log().Trace("Received transaction status query request")
   792  		if metrics.EnabledExpensive {
   793  			miscInTxStatusPacketsMeter.Mark(1)
   794  			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
   795  		}
   796  		var req struct {
   797  			ReqID  uint64
   798  			Hashes []common.Hash
   799  		}
   800  		if err := msg.Decode(&req); err != nil {
   801  			clientErrorMeter.Mark(1)
   802  			return errResp(ErrDecode, "msg %v: %v", msg, err)
   803  		}
   804  		reqCnt := len(req.Hashes)
   805  		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
   806  			wg.Add(1)
   807  			go func() {
   808  				defer wg.Done()
   809  				stats := make([]light.TxStatus, len(req.Hashes))
   810  				for i, hash := range req.Hashes {
   811  					if i != 0 && !task.waitOrStop() {
   812  						sendResponse(req.ReqID, 0, nil, task.servingTime)
   813  						return
   814  					}
   815  					stats[i] = h.txStatus(hash)
   816  				}
   817  				reply := p.ReplyTxStatus(req.ReqID, stats)
   818  				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
   819  				if metrics.EnabledExpensive {
   820  					miscOutTxStatusPacketsMeter.Mark(1)
   821  					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
   822  					miscServingTimeTxStatusTimer.Update(time.Duration(task.servingTime))
   823  				}
   824  			}()
   825  		}
   826  
   827  	default:
   828  		p.Log().Trace("Received invalid message", "code", msg.Code)
   829  		clientErrorMeter.Mark(1)
   830  		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
   831  	}
   832  	// If the client has made too much invalid request(e.g. request a non-exist data),
   833  	// reject them to prevent SPAM attack.
   834  	if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
   835  		clientErrorMeter.Mark(1)
   836  		return errTooManyInvalidRequest
   837  	}
   838  	return nil
   839  }
   840  
   841  // getAccount retrieves an account from the state based on root.
   842  func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
   843  	trie, err := trie.New(root, triedb)
   844  	if err != nil {
   845  		return state.Account{}, err
   846  	}
   847  	blob, err := trie.TryGet(hash[:])
   848  	if err != nil {
   849  		return state.Account{}, err
   850  	}
   851  	var account state.Account
   852  	if err = rlp.DecodeBytes(blob, &account); err != nil {
   853  		return state.Account{}, err
   854  	}
   855  	return account, nil
   856  }
   857  
   858  // getHelperTrie returns the post-processed trie root for the given trie ID and section index
   859  func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
   860  	switch typ {
   861  	case htCanonical:
   862  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
   863  		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
   864  	case htBloomBits:
   865  		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
   866  		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
   867  	}
   868  	return common.Hash{}, ""
   869  }
   870  
   871  // getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
   872  func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
   873  	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
   874  		blockNum := binary.BigEndian.Uint64(req.Key)
   875  		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
   876  		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
   877  	}
   878  	return nil
   879  }
   880  
   881  // txStatus returns the status of a specified transaction.
   882  func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
   883  	var stat light.TxStatus
   884  	// Looking the transaction in txpool first.
   885  	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
   886  
   887  	// If the transaction is unknown to the pool, try looking it up locally.
   888  	if stat.Status == core.TxStatusUnknown {
   889  		lookup := h.blockchain.GetTransactionLookup(hash)
   890  		if lookup != nil {
   891  			stat.Status = core.TxStatusIncluded
   892  			stat.Lookup = lookup
   893  		}
   894  	}
   895  	return stat
   896  }
   897  
   898  // broadcastHeaders broadcasts new block information to all connected light
   899  // clients. According to the agreement between client and server, server should
   900  // only broadcast new announcement if the total difficulty is higher than the
   901  // last one. Besides server will add the signature if client requires.
   902  func (h *serverHandler) broadcastHeaders() {
   903  	defer h.wg.Done()
   904  
   905  	headCh := make(chan core.ChainHeadEvent, 10)
   906  	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
   907  	defer headSub.Unsubscribe()
   908  
   909  	var (
   910  		lastHead *types.Header
   911  		lastTd   = common.Big0
   912  	)
   913  	for {
   914  		select {
   915  		case ev := <-headCh:
   916  			peers := h.server.peers.AllPeers()
   917  			if len(peers) == 0 {
   918  				continue
   919  			}
   920  			header := ev.Block.Header()
   921  			hash, number := header.Hash(), header.Number.Uint64()
   922  			td := h.blockchain.GetTd(hash, number)
   923  			if td == nil || td.Cmp(lastTd) <= 0 {
   924  				continue
   925  			}
   926  			var reorg uint64
   927  			if lastHead != nil {
   928  				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
   929  			}
   930  			lastHead, lastTd = header, td
   931  
   932  			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
   933  			var (
   934  				signed         bool
   935  				signedAnnounce announceData
   936  			)
   937  			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
   938  			for _, p := range peers {
   939  				p := p
   940  				switch p.announceType {
   941  				case announceTypeSimple:
   942  					p.queueSend(func() { p.SendAnnounce(announce) })
   943  				case announceTypeSigned:
   944  					if !signed {
   945  						signedAnnounce = announce
   946  						signedAnnounce.sign(h.server.privateKey)
   947  						signed = true
   948  					}
   949  					p.queueSend(func() { p.SendAnnounce(signedAnnounce) })
   950  				}
   951  			}
   952  		case <-h.closeCh:
   953  			return
   954  		}
   955  	}
   956  }