github.com/decred/dcrlnd@v0.7.6/chainscan/csdrivers/remotedcrwdriver.go (about)

     1  package csdrivers
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"time"
     9  
    10  	"decred.org/dcrwallet/v4/rpc/walletrpc"
    11  	"github.com/decred/dcrd/chaincfg/chainhash"
    12  	"github.com/decred/dcrd/gcs/v4"
    13  	"github.com/decred/dcrd/gcs/v4/blockcf2"
    14  	"github.com/decred/dcrd/wire"
    15  	"github.com/decred/dcrlnd/blockcache"
    16  	"github.com/decred/dcrlnd/chainscan"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/status"
    19  )
    20  
// RemoteWalletCSDriver is a chainscan driver that sources its chain data
// (cfilters, block headers, raw blocks and block connect/disconnect
// notifications) from a wallet reachable through the walletrpc gRPC services.
type RemoteWalletCSDriver struct {
	// wsvc is the wallet service client. It is used to fetch cfilters,
	// block info, the best block and the transaction notification stream.
	wsvc walletrpc.WalletServiceClient
	// nsvc is the network service client. It is used to fetch raw blocks.
	nsvc walletrpc.NetworkServiceClient

	// blockCache fronts raw block fetches performed by GetBlock.
	blockCache *blockcache.BlockCache

	// mtx protects ereaders, the list of readers registered through
	// ChainEvents that are signalled on every chain event.
	mtx      sync.Mutex
	ereaders []*eventReader

	// The following are members that define a cfilter cache which is
	// useful to reduce db contention by reading cfilters in batches.
	//
	// cache[i] holds the cfilter for height cacheStartHeight+i.
	// NOTE(review): these two fields are accessed without holding mtx in
	// GetCFilter; presumably GetCFilter is only called from a single
	// goroutine -- confirm against callers.
	cache            []cfilter
	cacheStartHeight int32
}
    35  
    36  // Type assertions to ensure the driver fulfills the correct interfaces.
    37  var _ chainscan.HistoricalChainSource = (*RemoteWalletCSDriver)(nil)
    38  var _ chainscan.TipChainSource = (*RemoteWalletCSDriver)(nil)
    39  
    40  func NewRemoteWalletCSDriver(wsvc walletrpc.WalletServiceClient,
    41  	nsvc walletrpc.NetworkServiceClient, bcache *blockcache.BlockCache) *RemoteWalletCSDriver {
    42  	return &RemoteWalletCSDriver{
    43  		wsvc:       wsvc,
    44  		nsvc:       nsvc,
    45  		cache:      make([]cfilter, 0, cacheCapHint),
    46  		blockCache: bcache,
    47  	}
    48  }
    49  
    50  func (d *RemoteWalletCSDriver) signalEventReaders(e chainscan.ChainEvent) {
    51  	d.mtx.Lock()
    52  	readers := d.ereaders
    53  	d.mtx.Unlock()
    54  
    55  	for _, er := range readers {
    56  		select {
    57  		case <-er.ctx.Done():
    58  		case er.c <- e:
    59  		}
    60  	}
    61  }
    62  
    63  func (d *RemoteWalletCSDriver) singleCFilter(ctx context.Context, hash *chainhash.Hash) ([16]byte, *gcs.FilterV2, error) {
    64  	req := &walletrpc.GetCFiltersRequest{
    65  		StartingBlockHash: hash[:],
    66  		EndingBlockHash:   hash[:],
    67  	}
    68  
    69  	stream, err := d.wsvc.GetCFilters(ctx, req)
    70  	if err != nil {
    71  		return [16]byte{}, nil, err
    72  	}
    73  	resp, err := stream.Recv()
    74  	if err != nil {
    75  		return [16]byte{}, nil, err
    76  	}
    77  	var key [16]byte
    78  
    79  	filter, err := gcs.FromBytesV2(blockcf2.B, blockcf2.M, resp.Filter)
    80  	if err != nil {
    81  		return key, nil, err
    82  	}
    83  
    84  	copy(key[:], resp.Key)
    85  
    86  	return key, filter, nil
    87  }
    88  
    89  func (d *RemoteWalletCSDriver) Run(ctx context.Context) error {
    90  	req := &walletrpc.TransactionNotificationsRequest{}
    91  	stream, err := d.wsvc.TransactionNotifications(ctx, req)
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	getHeaderReq := &walletrpc.BlockInfoRequest{}
    97  
    98  nextntfn:
    99  	for {
   100  		var resp *walletrpc.TransactionNotificationsResponse
   101  		resp, err = stream.Recv()
   102  		if err != nil {
   103  			break
   104  		}
   105  
   106  		for _, b := range resp.DetachedBlockHeaders {
   107  			var hash, prevHash *chainhash.Hash
   108  			hash, err = chainhash.NewHash(b.Hash)
   109  			if err != nil {
   110  				break nextntfn
   111  			}
   112  			prevHash, err = chainhash.NewHash(b.PrevBlock)
   113  			if err != nil {
   114  				break nextntfn
   115  			}
   116  
   117  			var headerRes *walletrpc.BlockInfoResponse
   118  			getHeaderReq.BlockHash = hash[:]
   119  			headerRes, err = d.wsvc.BlockInfo(ctx, getHeaderReq)
   120  			if err != nil {
   121  				break nextntfn
   122  			}
   123  			header := new(wire.BlockHeader)
   124  			err = header.FromBytes(headerRes.BlockHeader)
   125  			if err != nil {
   126  				break nextntfn
   127  			}
   128  
   129  			e := chainscan.BlockDisconnectedEvent{
   130  				Hash:     *hash,
   131  				Height:   b.Height,
   132  				PrevHash: *prevHash,
   133  				Header:   header,
   134  			}
   135  
   136  			d.signalEventReaders(e)
   137  		}
   138  
   139  		for _, bl := range resp.AttachedBlocks {
   140  			// Shouldn't happen, but play it safe.
   141  			if bl.Height <= 0 {
   142  				continue
   143  			}
   144  
   145  			var hash, prevHash *chainhash.Hash
   146  			hash, err = chainhash.NewHash(bl.Hash)
   147  			if err != nil {
   148  				break nextntfn
   149  			}
   150  			prevHash, err = chainhash.NewHash(bl.PrevBlock)
   151  			if err != nil {
   152  				break nextntfn
   153  			}
   154  
   155  			var cfKey [16]byte
   156  			var filter *gcs.FilterV2
   157  			cfKey, filter, err = d.singleCFilter(ctx, hash)
   158  			if err != nil {
   159  				break nextntfn
   160  			}
   161  
   162  			var headerRes *walletrpc.BlockInfoResponse
   163  			getHeaderReq.BlockHash = hash[:]
   164  			headerRes, err = d.wsvc.BlockInfo(ctx, getHeaderReq)
   165  			if err != nil {
   166  				break nextntfn
   167  			}
   168  			header := new(wire.BlockHeader)
   169  			err = header.FromBytes(headerRes.BlockHeader)
   170  			if err != nil {
   171  				break nextntfn
   172  			}
   173  
   174  			e := chainscan.BlockConnectedEvent{
   175  				PrevHash: *prevHash,
   176  				Hash:     *hash,
   177  				Height:   bl.Height,
   178  				CFKey:    cfKey,
   179  				Filter:   filter,
   180  				Header:   header,
   181  			}
   182  			d.signalEventReaders(e)
   183  		}
   184  	}
   185  
   186  	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
   187  		log.Tracef("RemoteWalletCSDriver run context done: %v", err)
   188  	} else if err != nil {
   189  		log.Errorf("RemoteWalletCSDriver run errored: %v", err)
   190  	}
   191  	return err
   192  }
   193  
   194  func (d *RemoteWalletCSDriver) ChainEvents(ctx context.Context) <-chan chainscan.ChainEvent {
   195  	er := &eventReader{
   196  		ctx: ctx,
   197  		c:   make(chan chainscan.ChainEvent),
   198  	}
   199  	d.mtx.Lock()
   200  	d.ereaders = append(d.ereaders, er)
   201  	d.mtx.Unlock()
   202  
   203  	return er.c
   204  }
   205  
   206  func (d *RemoteWalletCSDriver) GetBlock(ctx context.Context, bh *chainhash.Hash) (*wire.MsgBlock, error) {
   207  	return d.blockCache.GetBlock(ctx, bh, d.getBlock)
   208  }
   209  
   210  func (d *RemoteWalletCSDriver) getBlock(ctx context.Context, bh *chainhash.Hash) (*wire.MsgBlock, error) {
   211  	var (
   212  		resp *walletrpc.GetRawBlockResponse
   213  		err  error
   214  	)
   215  
   216  	req := &walletrpc.GetRawBlockRequest{
   217  		BlockHash: bh[:],
   218  	}
   219  
   220  	// If the response error code is 'Unavailable' it means the wallet
   221  	// isn't connected to any peers while in SPV mode. In that case, wait a
   222  	// bit and try again.
   223  	for stop := false; !stop; {
   224  		resp, err = d.nsvc.GetRawBlock(ctx, req)
   225  		switch {
   226  		case status.Code(err) == codes.Unavailable:
   227  			log.Warnf("Network unavailable from wallet; will try again in 5 seconds")
   228  			select {
   229  			case <-ctx.Done():
   230  				return nil, ctx.Err()
   231  			case <-time.After(5 * time.Second):
   232  			}
   233  		case err != nil:
   234  			return nil, err
   235  		default:
   236  			stop = true
   237  		}
   238  	}
   239  
   240  	bl := &wire.MsgBlock{}
   241  	err = bl.FromBytes(resp.Block)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  
   246  	return bl, nil
   247  }
   248  
   249  func (d *RemoteWalletCSDriver) CurrentTip(ctx context.Context) (*chainhash.Hash, int32, error) {
   250  	resp, err := d.wsvc.BestBlock(ctx, &walletrpc.BestBlockRequest{})
   251  	if err != nil {
   252  		return nil, 0, err
   253  	}
   254  	bh, err := chainhash.NewHash(resp.Hash)
   255  	if err != nil {
   256  		return nil, 0, err
   257  	}
   258  
   259  	return bh, int32(resp.Height), nil
   260  }
   261  
   262  // GetCFilter is part of the chainscan.HistoricalChainSource interface.
   263  //
   264  // NOTE: The returned chainhash pointer is not safe for storage as it belongs
   265  // to a cache entry. This is fine for use on a chainscan.Historical scanner
   266  // since it never stores or leaks the pointer itself.
   267  func (d *RemoteWalletCSDriver) GetCFilter(ctx context.Context, height int32) (*chainhash.Hash, [16]byte, *gcs.FilterV2, error) {
   268  	// Fast track when data is in memory.
   269  	if height >= d.cacheStartHeight && height < d.cacheStartHeight+int32(len(d.cache)) {
   270  		i := int(height - d.cacheStartHeight)
   271  		c := &d.cache[i]
   272  		return &c.hash, c.key, c.filter, nil
   273  	}
   274  
   275  	// Use a new ctx so we can canel halfway through.
   276  	ctxReq, cancel := context.WithCancel(ctx)
   277  	defer cancel()
   278  
   279  	// Read a bunch of CFilters in one go since we're likely to be
   280  	// requested the next few ones.
   281  	d.cache = d.cache[:cap(d.cache)]
   282  	i := 0
   283  
   284  tryconn:
   285  	for {
   286  		req := &walletrpc.GetCFiltersRequest{
   287  			StartingBlockHeight: height + int32(i),
   288  		}
   289  
   290  		if i != 0 {
   291  			log.Tracef("Attempting to refetch at %d",
   292  				req.StartingBlockHeight)
   293  		}
   294  
   295  		stream, err := d.wsvc.GetCFilters(ctxReq, req)
   296  		if status.Code(err) == codes.Unavailable {
   297  			// Wallet may be temporarily offline. Backoff and try
   298  			// again in a bit.
   299  			select {
   300  			case <-time.After(time.Second):
   301  				continue
   302  			case <-ctx.Done():
   303  				return nil, [16]byte{}, nil, ctx.Err()
   304  			}
   305  		}
   306  		if err != nil {
   307  			return nil, [16]byte{}, nil, err
   308  		}
   309  		for {
   310  			var key [16]byte
   311  			resp, err := stream.Recv()
   312  			if errors.Is(err, io.EOF) {
   313  				break tryconn
   314  			} else if status.Code(err) == codes.Unavailable {
   315  				log.Tracef("Broke connection at height %d",
   316  					req.StartingBlockHeight+int32(i))
   317  				continue tryconn
   318  			} else if err != nil {
   319  				return nil, key, nil, err
   320  			}
   321  
   322  			bh, err := chainhash.NewHash(resp.BlockHash)
   323  			if err != nil {
   324  				return nil, key, nil, err
   325  			}
   326  
   327  			filter, err := gcs.FromBytesV2(blockcf2.B, blockcf2.M, resp.Filter)
   328  			if err != nil {
   329  				return nil, key, nil, err
   330  			}
   331  			copy(key[:], resp.Key)
   332  
   333  			d.cache[i] = cfilter{
   334  				hash:   *bh,
   335  				height: height + int32(i),
   336  				key:    key,
   337  				filter: filter,
   338  			}
   339  
   340  			// Stop if the cache has been filled.
   341  			i++
   342  			if i >= cap(d.cache) {
   343  				break tryconn
   344  			}
   345  		}
   346  	}
   347  
   348  	// If we didn't read any filters from the db, it means we were
   349  	// requested a filter past the current mainchain tip. Inform the
   350  	// appropriate error in this case.
   351  	if i == 0 {
   352  		return nil, [16]byte{}, nil, chainscan.ErrBlockAfterTip{Height: height}
   353  	}
   354  
   355  	// Clear out unused entries so we don't keep a reference to the filters
   356  	// forever.
   357  	for j := i; j < cap(d.cache); j++ {
   358  		d.cache[j].filter = nil
   359  	}
   360  
   361  	// Keep track of correct cache start and size.
   362  	d.cache = d.cache[:i]
   363  	d.cacheStartHeight = height
   364  
   365  	// The desired filter is the first one.
   366  	c := &d.cache[0]
   367  	return &c.hash, c.key, c.filter, nil
   368  }