decred.org/dcrwallet/v3@v3.1.0/cmd/repaircfilters/main.go (about)

     1  // Copyright (c) 2020 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"encoding/binary"
    12  	"encoding/hex"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"net"
    17  	"os"
    18  	"path/filepath"
    19  
    20  	"github.com/decred/dcrd/chaincfg/chainhash"
    21  	"github.com/decred/dcrd/chaincfg/v3"
    22  	"github.com/decred/dcrd/crypto/blake256"
    23  	"github.com/decred/dcrd/dcrutil/v4"
    24  	"github.com/decred/dcrd/wire"
    25  	"github.com/jessevdk/go-flags"
    26  	"github.com/jrick/wsrpc/v2"
    27  	"golang.org/x/term"
    28  )
    29  
    30  var (
    31  	walletDataDirectory = dcrutil.AppDataDir("dcrwallet", false)
    32  	newlineBytes        = []byte{'\n'}
    33  )
    34  
    35  var opts = struct {
    36  	TestNet            bool   `long:"testnet" description:"Use the test decred network"`
    37  	RPCConnect         string `short:"c" long:"connect" description:"Hostname[:port] of wallet RPC server"`
    38  	RPCUsername        string `short:"u" long:"rpcuser" description:"Wallet RPC username"`
    39  	RPCPassword        string `short:"P" long:"rpcpass" description:"Wallet RPC password"`
    40  	RPCCertificateFile string `long:"cafile" description:"Wallet RPC TLS certificate"`
    41  	CFiltersFile       string `long:"cfiltersfile" description:"Binary file with pre-dcp0005 filter data"`
    42  }{
    43  	TestNet:            false,
    44  	RPCConnect:         "localhost",
    45  	RPCUsername:        "",
    46  	RPCPassword:        "",
    47  	RPCCertificateFile: filepath.Join(walletDataDirectory, "rpc.cert"),
    48  }
    49  
    50  func fatalf(format string, args ...interface{}) {
    51  	fmt.Fprintf(os.Stderr, format, args...)
    52  	os.Stderr.Write(newlineBytes)
    53  	os.Exit(1)
    54  }
    55  
    56  func errContext(err error, context string) error {
    57  	return fmt.Errorf("%s: %v", context, err)
    58  }
    59  
    60  // normalizeAddress returns the normalized form of the address, adding a
    61  // default port if necessary.  An error is returned if the address, even
    62  // without a port, is not valid.
    63  func normalizeAddress(addr string, defaultPort string) (hostport string, err error) {
    64  	// If the first SplitHostPort errors because of a missing port and not
    65  	// for an invalid host, add the port.  If the second SplitHostPort
    66  	// fails, then a port is not missing and the original error should be
    67  	// returned.
    68  	host, port, origErr := net.SplitHostPort(addr)
    69  	if origErr == nil {
    70  		return net.JoinHostPort(host, port), nil
    71  	}
    72  	addr = net.JoinHostPort(addr, defaultPort)
    73  	_, _, err = net.SplitHostPort(addr)
    74  	if err != nil {
    75  		return "", origErr
    76  	}
    77  	return addr, nil
    78  }
    79  
    80  func walletPort(net *chaincfg.Params) string {
    81  	switch net.Net {
    82  	case wire.MainNet:
    83  		return "9110"
    84  	case wire.TestNet3:
    85  		return "19110"
    86  	case wire.SimNet:
    87  		return "19557"
    88  	default:
    89  		return ""
    90  	}
    91  }
    92  
    93  // Parse and validate flags.
    94  func init() {
    95  	// Unset localhost defaults if certificate file can not be found.
    96  	_, err := os.Stat(opts.RPCCertificateFile)
    97  	if err != nil {
    98  		opts.RPCConnect = ""
    99  		opts.RPCCertificateFile = ""
   100  	}
   101  
   102  	_, err = flags.Parse(&opts)
   103  	if err != nil {
   104  		os.Exit(1)
   105  	}
   106  
   107  	var activeNet = chaincfg.MainNetParams()
   108  	if opts.TestNet {
   109  		activeNet = chaincfg.TestNet3Params()
   110  	}
   111  
   112  	if opts.RPCConnect == "" {
   113  		fatalf("RPC hostname[:port] is required")
   114  	}
   115  	rpcConnect, err := normalizeAddress(opts.RPCConnect, walletPort(activeNet))
   116  	if err != nil {
   117  		fatalf("Invalid RPC network address `%v`: %v", opts.RPCConnect, err)
   118  	}
   119  	opts.RPCConnect = rpcConnect
   120  
   121  	if opts.RPCUsername == "" {
   122  		fatalf("RPC username is required")
   123  	}
   124  
   125  	_, err = os.Stat(opts.RPCCertificateFile)
   126  	if err != nil {
   127  		fatalf("RPC certificate file `%s` not found", opts.RPCCertificateFile)
   128  	}
   129  
   130  	if opts.CFiltersFile == "" {
   131  		fatalf("CFilters file is required")
   132  	}
   133  }
   134  
   135  func promptSecret(what string) (string, error) {
   136  	fmt.Printf("%s: ", what)
   137  	fd := int(os.Stdin.Fd())
   138  	input, err := term.ReadPassword(fd)
   139  	fmt.Println()
   140  	if err != nil {
   141  		return "", err
   142  	}
   143  	return string(input), nil
   144  }
   145  
   146  func repair() error {
   147  	rpcPassword := opts.RPCPassword
   148  
   149  	if rpcPassword == "" {
   150  		secret, err := promptSecret("Wallet RPC password")
   151  		if err != nil {
   152  			return errContext(err, "failed to read RPC password")
   153  		}
   154  
   155  		rpcPassword = secret
   156  	}
   157  
   158  	cffile, err := os.Open(opts.CFiltersFile)
   159  	if err != nil {
   160  		return errContext(err, "failed to open cfilters file")
   161  	}
   162  	defer cffile.Close()
   163  
   164  	ctx := context.Background()
   165  	rpcCertificate, err := os.ReadFile(opts.RPCCertificateFile)
   166  	if err != nil {
   167  		return errContext(err, "failed to read RPC certificate")
   168  	}
   169  	rpcopts := make([]wsrpc.Option, 0, 5)
   170  	rpcopts = append(rpcopts, wsrpc.WithBasicAuth(opts.RPCUsername, rpcPassword))
   171  	rpcopts = append(rpcopts, wsrpc.WithoutPongDeadline())
   172  	pool := x509.NewCertPool()
   173  	pool.AppendCertsFromPEM(rpcCertificate)
   174  	tc := &tls.Config{
   175  		RootCAs: pool,
   176  	}
   177  	addr := "wss://" + opts.RPCConnect + "/ws"
   178  	rpcopts = append(rpcopts, wsrpc.WithTLSConfig(tc))
   179  	client, err := wsrpc.Dial(ctx, addr, rpcopts...)
   180  	if err != nil {
   181  		return errContext(err, "failed to connect to the wallet")
   182  	}
   183  	defer client.Close()
   184  
   185  	hasher := blake256.New()
   186  
   187  	// Read in 64KiB steps. No individual cfilter will be larger than this.
   188  	cfbuf := make([]byte, 65536)
   189  	readOffset := 0
   190  	height := int32(0)
   191  	for {
   192  		n, ioErr := cffile.Read(cfbuf[readOffset:])
   193  		if err != nil && !errors.Is(ioErr, io.EOF) {
   194  			return errContext(err, "cfiltersfile read error")
   195  		}
   196  
   197  		var filters []string
   198  
   199  		// Split the buffer into as many filters as needed for the next
   200  		// cmd. Each "record" in the file is 2 bytes for an uint16
   201  		// (size of cfilter) + n bytes for the cfilter data.
   202  		nextcf := cfbuf[:readOffset+n]
   203  		for len(nextcf) > 1 {
   204  			cflen := binary.BigEndian.Uint16(nextcf)
   205  			if int(cflen) > len(nextcf)-2 {
   206  				// Reached the end of this block of cfilters.
   207  				break
   208  			}
   209  
   210  			var cf []byte
   211  			cf, nextcf = nextcf[2:cflen+2], nextcf[2+cflen:]
   212  			hasher.Write(cf)
   213  			cfhex := hex.EncodeToString(cf)
   214  			filters = append(filters, cfhex)
   215  		}
   216  
   217  		// Import this batch of cfilters.
   218  		if len(filters) > 0 {
   219  			err = client.Call(ctx, "importcfiltersv2", nil, height, filters)
   220  			if err != nil {
   221  				return errContext(err, "failed to import cfilters")
   222  			}
   223  		}
   224  
   225  		// Advance to next batch.
   226  		height += int32(len(filters))
   227  		copy(cfbuf[0:], nextcf)
   228  		readOffset = len(nextcf)
   229  
   230  		// Finish only after processing any data that might have been
   231  		// returned at the last buffered read.
   232  		if errors.Is(ioErr, io.EOF) {
   233  			break
   234  		}
   235  	}
   236  
   237  	var cfsetHash chainhash.Hash
   238  	cfsetHash.SetBytes(hasher.Sum(nil))
   239  	fmt.Printf("Hash of cf data sent: %s\n", cfsetHash)
   240  	fmt.Printf("Height: %d\n", height-1)
   241  
   242  	return nil
   243  }
   244  
   245  func main() {
   246  	err := repair()
   247  	if err != nil {
   248  		fatalf("%v", err)
   249  	}
   250  }