github.com/mit-dci/lit@v0.0.0-20221102210550-8c3d3b49f2ce/uspv/init.go (about)

     1  package uspv
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net"
     8  	"os"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/mit-dci/lit/lnutil"
    13  	"github.com/mit-dci/lit/logging"
    14  	"github.com/mit-dci/lit/wire"
    15  	"golang.org/x/net/proxy"
    16  )
    17  
    18  // IP4 ...
    19  func IP4(ipAddress string) bool {
    20  	parseIP := net.ParseIP(ipAddress)
    21  	if parseIP.To4() == nil {
    22  		return false
    23  	}
    24  	return true
    25  }
    26  
    27  func (s *SPVCon) parseRemoteNode(remoteNode string) (string, string, error) {
    28  	colonCount := strings.Count(remoteNode, ":")
    29  	var conMode string
    30  	if colonCount <= 1 {
    31  		if colonCount == 0 {
    32  			remoteNode = remoteNode + ":" + s.Param.DefaultPort
    33  		}
    34  		return remoteNode, "tcp4", nil
    35  	} else if colonCount >= 5 {
    36  		// ipv6 without remote port
    37  		// assume users don't give ports with ipv6 nodes
    38  		if !strings.Contains(remoteNode, "[") && !strings.Contains(remoteNode, "]") {
    39  			remoteNode = "[" + remoteNode + "]" + ":" + s.Param.DefaultPort
    40  		}
    41  		conMode = "tcp6"
    42  		return remoteNode, conMode, nil
    43  	} else {
    44  		return "", "", fmt.Errorf("Invalid ip")
    45  	}
    46  }
    47  
    48  // GetListOfNodes contacts all DNSSeeds for the coin specified and then contacts
    49  // each one of them in order to receive a list of ips and then returns a combined
    50  // list
    51  func (s *SPVCon) GetListOfNodes() ([]string, error) {
    52  	var listOfNodes []string // slice of IP addrs returned from the DNS seed
    53  	logging.Infof("Attempting to retrieve peers to connect to based on DNS Seed\n")
    54  
    55  	for _, seed := range s.Param.DNSSeeds {
    56  		temp, err := net.LookupHost(seed)
    57  		// need this temp in order to capture the error from net.LookupHost
    58  		// also need this to report the number of IPs we get from a seed
    59  		if err != nil {
    60  			logging.Infof("Have difficulty trying to connect to %s. Going to the next seed", seed)
    61  			continue
    62  		}
    63  		listOfNodes = append(listOfNodes, temp...)
    64  		logging.Infof("Got %d IPs from DNS seed %s\n", len(temp), seed)
    65  	}
    66  	if len(listOfNodes) == 0 {
    67  		return nil, fmt.Errorf("No peers found connected to DNS Seeds. Please provide a host to connect to.")
    68  	}
    69  	logging.Info(listOfNodes)
    70  	return listOfNodes, nil
    71  }
    72  
    73  // DialNode receives a list of node ips and then tries to connect to them one by one.
    74  func (s *SPVCon) DialNode(listOfNodes []string) error {
    75  
    76  	// now have some IPs, go through and try to connect to one.
    77  	var err error
    78  	for i, ip := range listOfNodes {
    79  		// try to connect to all nodes in this range
    80  		var conString, conMode string
    81  		// need to check whether conString is ipv4 or ipv6
    82  		conString, conMode, err = s.parseRemoteNode(ip)
    83  		if err != nil {
    84  			logging.Infof("parse error for node (skipped): %s", err)
    85  			continue
    86  		}
    87  		logging.Infof("Attempting connection to node at %s\n",
    88  			conString)
    89  
    90  		if s.ProxyURL != "" {
    91  			logging.Infof("Attempting to connect via proxy %s", s.ProxyURL)
    92  			d, err := proxy.SOCKS5("tcp", s.ProxyURL, nil, proxy.Direct)
    93  			if err != nil {
    94  				return err
    95  			}
    96  			s.con, err = d.Dial(conMode, conString)
    97  		} else {
    98  			d := net.Dialer{Timeout: 2 * time.Second}
    99  			s.con, err = d.Dial(conMode, conString)
   100  		}
   101  
   102  		if err != nil {
   103  			if i != len(listOfNodes)-1 {
   104  				logging.Warn(err.Error())
   105  				continue
   106  			} else if i == len(listOfNodes)-1 {
   107  				logging.Error(err)
   108  				// all nodes have been exhausted, we move on to the next one, if any.
   109  				return fmt.Errorf(" Tried to connect to all available node Addresses. Failed")
   110  			}
   111  		}
   112  		break
   113  	}
   114  
   115  	if s.con == nil {
   116  		return fmt.Errorf("Failed to connect to a coin daemon")
   117  	}
   118  
   119  	return nil
   120  }
   121  
   122  // Handshake ...
   123  func (s *SPVCon) Handshake(listOfNodes []string) error {
   124  	// assign version bits for local node
   125  	s.localVersion = VERSION
   126  	myMsgVer, err := wire.NewMsgVersionFromConn(s.con, 0, 0)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	err = myMsgVer.AddUserAgent("lit", "v0.1")
   131  	if err != nil {
   132  		return err
   133  	}
   134  	// must set this to enable SPV stuff
   135  	myMsgVer.AddService(wire.SFNodeBloom)
   136  	// set this to enable segWit
   137  	myMsgVer.AddService(wire.SFNodeWitness)
   138  	// this actually sends
   139  	n, err := wire.WriteMessageWithEncodingN(
   140  		s.con, myMsgVer, s.localVersion,
   141  		wire.BitcoinNet(s.Param.NetMagicBytes), wire.LatestEncoding)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	s.WBytes += uint64(n)
   146  	logging.Infof("wrote %d byte version message to %s\n",
   147  		n, s.con.RemoteAddr().String())
   148  	n, m, b, err := wire.ReadMessageWithEncodingN(
   149  		s.con, s.localVersion,
   150  		wire.BitcoinNet(s.Param.NetMagicBytes), wire.LatestEncoding)
   151  	if err != nil {
   152  		logging.Error(err)
   153  		return err
   154  	}
   155  	s.RBytes += uint64(n)
   156  	logging.Infof("got %d byte response %x\n command: %s\n", n, b, m.Command())
   157  
   158  	mv, ok := m.(*wire.MsgVersion)
   159  	if ok {
   160  		logging.Infof("connected to %s", mv.UserAgent)
   161  	} else {
   162  		return fmt.Errorf("Message wrong type and/or nil pointer received, mv: %v", mv)
   163  	}
   164  
   165  	if mv.ProtocolVersion < 70013 {
   166  		//70014 -> core v0.13.1, so we should be fine
   167  		return fmt.Errorf("Remote node version: %x too old, disconnecting.", mv.ProtocolVersion)
   168  	}
   169  
   170  	if !((strings.Contains(s.Param.Name, "lite") && strings.Contains(mv.UserAgent, "LitecoinCore")) || strings.Contains(mv.UserAgent, "Satoshi") || strings.Contains(mv.UserAgent, "btcd")) && (len(listOfNodes) != 0) {
   171  		// TODO: improve this filtering criterion
   172  		return fmt.Errorf("Couldn't connect to this node. Returning!")
   173  	}
   174  
   175  	logging.Infof("remote reports version %x (dec %d)\n",
   176  		mv.ProtocolVersion, mv.ProtocolVersion)
   177  
   178  	// set remote height
   179  	s.remoteHeight = mv.LastBlock
   180  	// set remote version
   181  	s.remoteVersion = uint32(mv.ProtocolVersion)
   182  
   183  	mva := wire.NewMsgVerAck()
   184  	n, err = wire.WriteMessageWithEncodingN(
   185  		s.con, mva, s.localVersion,
   186  		wire.BitcoinNet(s.Param.NetMagicBytes), wire.LatestEncoding)
   187  	if err != nil {
   188  		return err
   189  	}
   190  	s.WBytes += uint64(n)
   191  	return nil
   192  }
   193  
   194  // Connect dials out and connects to full nodes. Calls GetListOfNodes to get the
   195  // list of nodes if the user has specified a YupString. Else, moves on to dial
   196  // the node to see if its up and establishes a connection followed by Handshake()
   197  // which sends out wire messages, checks for version string to prevent spam, etc.
   198  func (s *SPVCon) Connect(remoteNode string) error {
   199  	var err error
   200  	var listOfNodes []string
   201  	if lnutil.YupString(remoteNode) { // TODO Make this better.  Perhaps a "connection target"?
   202  		s.randomNodesOK = true
   203  		// if remoteNode is "yes" but no IP specified, use DNS seed
   204  		listOfNodes, err = s.GetListOfNodes()
   205  		if err != nil {
   206  			logging.Error(err)
   207  			return err
   208  			// automatically quit if there are no other hosts to connect to.
   209  		}
   210  	} else { // else connect to user-specified node
   211  		listOfNodes = []string{remoteNode}
   212  	}
   213  	handShakeFailed := false // need to be in this scope to access it here
   214  	connEstablished := false
   215  	for len(listOfNodes) != 0 && !connEstablished {
   216  		err = s.DialNode(listOfNodes)
   217  		if err != nil {
   218  			logging.Error(err)
   219  			logging.Infof("Couldn't dial node %s, Moving on", listOfNodes[0])
   220  			listOfNodes = listOfNodes[1:]
   221  			continue
   222  		}
   223  		err = s.Handshake(listOfNodes)
   224  		if err != nil {
   225  			// spam node or some other problem. Delete node from list and try again
   226  			handShakeFailed = true
   227  			logging.Infof("Handshake with %s failed. Moving on. Error: %s", listOfNodes[0], err.Error())
   228  			if len(listOfNodes) == 1 { // this is the last node, error out
   229  				return fmt.Errorf("Couldn't establish connection with any remote node. Exiting.")
   230  			}
   231  			logging.Error("Couldn't establish connection with node. Proceeding to the next one")
   232  			listOfNodes = listOfNodes[1:]
   233  			connEstablished = false
   234  		} else {
   235  			connEstablished = true
   236  		}
   237  	}
   238  
   239  	if !handShakeFailed && !connEstablished {
   240  		// this case happens when user provided node fails to connect
   241  		return fmt.Errorf("Couldn't establish connection with node. Exiting.")
   242  	}
   243  	if handShakeFailed && !connEstablished {
   244  		// this case is when the last node fails and we continue, only to exit the
   245  		// loop and execute below code, which is unnecessary.
   246  		return fmt.Errorf("Couldn't establish connection with any remote node after an instance of handshake. Exiting.")
   247  	}
   248  	go s.incomingMessageHandler()
   249  	go s.outgoingMessageHandler()
   250  
   251  	if s.HardMode {
   252  		s.blockQueue = make(chan HashAndHeight, 32) // queue depth 32 for hardmode.
   253  	} else {
   254  		// for SPV, concurrent in-flight merkleblocks makes us miss txs.
   255  		// The BloomUpdateAll setting seems like it should prevent it, but it
   256  		// doesn't; occasionally it misses transactions, seems like with low
   257  		// block index.  Could be a bug somewhere.  1 at a time merkleblock
   258  		// seems OK.
   259  		s.blockQueue = make(chan HashAndHeight, 1) // queue depth 1 for spv
   260  	}
   261  	s.fPositives = make(chan int32, 4000) // a block full, approx
   262  	s.inWaitState = make(chan bool, 1)
   263  	go s.fPositiveHandler()
   264  
   265  	// if s.HardMode { // what about for non-hard?  send filter?
   266  	// 	Ignore filters now; switch to filters fed to SPVcon from TS
   267  	// 		filt, err := s.TS.GimmeFilter()
   268  	// 		if err != nil {
   269  	// 			return err
   270  	// 		}
   271  	// 		s.localFilter = filt
   272  	// 		//		s.Refilter(filt)
   273  	// }
   274  
   275  	return nil
   276  }
   277  
/*
Truncated header files

A truncated header file looks like a regular header file, except that its
first 80 bytes are mostly empty. The very first 4 bytes (big endian) say
which height the empty 80 bytes replace. The next header, starting at
offset 80, must be valid.
*/
   284  func (s *SPVCon) openHeaderFile(hfn string) error {
   285  	_, err := os.Stat(hfn)
   286  	if err != nil {
   287  		if os.IsNotExist(err) {
   288  			var b bytes.Buffer
   289  			// if StartHeader is defined, start with hardcoded height
   290  			if s.Param.StartHeight != 0 {
   291  				hdr := s.Param.StartHeader
   292  				_, err := b.Write(hdr[:])
   293  				if err != nil {
   294  					return err
   295  				}
   296  			} else {
   297  				err = s.Param.GenesisBlock.Header.Serialize(&b)
   298  				if err != nil {
   299  					return err
   300  				}
   301  			}
   302  			err = ioutil.WriteFile(hfn, b.Bytes(), 0600)
   303  			if err != nil {
   304  				return err
   305  			}
   306  			logging.Infof("made genesis header %x\n", b.Bytes())
   307  			logging.Infof("made genesis hash %s\n", s.Param.GenesisHash.String())
   308  			logging.Infof("created hardcoded genesis header at %s\n", hfn)
   309  		}
   310  	}
   311  
   312  	if s.Param.StartHeight != 0 {
   313  		s.headerStartHeight = s.Param.StartHeight
   314  	}
   315  
   316  	s.headerFile, err = os.OpenFile(hfn, os.O_RDWR, 0600)
   317  	if err != nil {
   318  		return err
   319  	}
   320  	logging.Infof("opened header file %s\n", s.headerFile.Name())
   321  	return nil
   322  }