github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/cmd/geph-exit/handle.go (about)

     1  package main
     2  
     3  import (
     4  	"crypto/ed25519"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net"
    11  	"net/http"
    12  	"strings"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  
    17  	log "github.com/sirupsen/logrus"
    18  
    19  	"github.com/ethereum/go-ethereum/rlp"
    20  	"github.com/geph-official/geph2/libs/backedtcp"
    21  	"github.com/geph-official/geph2/libs/cwl"
    22  	"github.com/geph-official/geph2/libs/tinyss"
    23  	"github.com/hashicorp/yamux"
    24  	"github.com/xtaci/smux"
    25  	"golang.org/x/time/rate"
    26  )
    27  
    28  // blacklist of local networks
    29  var cidrBlacklist []*net.IPNet
    30  
    31  func init() {
    32  	for _, s := range []string{
    33  		"127.0.0.1/8",
    34  		"10.0.0.0/8",
    35  		"172.16.0.0/12",
    36  		"192.168.0.0/16",
    37  		"::1/128",
    38  	} {
    39  		_, n, _ := net.ParseCIDR(s)
    40  		cidrBlacklist = append(cidrBlacklist, n)
    41  	}
    42  }
    43  
    44  func isBlack(addr *net.TCPAddr) bool {
    45  	for _, n := range cidrBlacklist {
    46  		if n.Contains(addr.IP) {
    47  			return true
    48  		}
    49  	}
    50  	return false
    51  }
    52  
    53  var tunnCount uint64
    54  
    55  func init() {
    56  	go func() {
    57  		for {
    58  			time.Sleep(time.Second * 10)
    59  			if statClient != nil {
    60  				statClient.Send(map[string]string{
    61  					hostname + ".sessionCount": fmt.Sprintf("%v|g",
    62  						freeSessCounter.ItemCount()+paidSessCounter.ItemCount()),
    63  				}, 1)
    64  				statClient.Send(map[string]string{
    65  					hostname + ".tunnelCount": fmt.Sprintf("%v|g", atomic.LoadUint64(&tunnCount)),
    66  				}, 1)
    67  				statClient.Send(map[string]string{
    68  					hostname + ".freeSessionCount": fmt.Sprintf("%v|g",
    69  						freeSessCounter.ItemCount()),
    70  				}, 1)
    71  				statClient.Send(map[string]string{
    72  					hostname + ".paidSessionCount": fmt.Sprintf("%v|g",
    73  						paidSessCounter.ItemCount()),
    74  				}, 1)
    75  			}
    76  		}
    77  	}()
    78  }
    79  
    80  func handle(rawClient net.Conn) {
    81  	log.Println("handle called with", rawClient.RemoteAddr())
    82  	rawClient.SetDeadline(time.Now().Add(time.Second * 30))
    83  	tssClient, err := tinyss.Handshake(rawClient, 0)
    84  	if err != nil {
    85  		rawClient.Close()
    86  		return
    87  	}
    88  	log.Println("tssClient with prot", tssClient.NextProt())
    89  	// HACK: it's bridged if the remote address has a dot in it
    90  	//isBridged := strings.Contains(rawClient.RemoteAddr().String(), ".")
    91  	// sign the shared secret
    92  	ssSignature := ed25519.Sign(seckey, tssClient.SharedSec())
    93  	rlp.Encode(tssClient, &ssSignature)
    94  	var limiter *rate.Limiter
    95  	limiter = infiniteLimit
    96  	slowLimit := false
    97  	// "generic" stuff
    98  	var acceptStream func() (net.Conn, error)
    99  	if singleHop == "" {
   100  		// authenticate the client
   101  		var greeting [2][]byte
   102  		err = rlp.Decode(tssClient, &greeting)
   103  		if err != nil {
   104  			log.Println("Error decoding greeting from", rawClient.RemoteAddr(), err)
   105  			tssClient.Close()
   106  			return
   107  		}
   108  		err = bclient.RedeemTicket("paid", greeting[0], greeting[1])
   109  		if err != nil {
   110  			if onlyPaid {
   111  				log.Printf("%v isn't paid and we only accept paid %v. Failing!", rawClient.RemoteAddr(), err)
   112  				rlp.Encode(tssClient, "FAIL")
   113  				tssClient.Close()
   114  				return
   115  			}
   116  			err = bclient.RedeemTicket("free", greeting[0], greeting[1])
   117  			if err != nil {
   118  				log.Printf("%v isn't free either %v. fail", rawClient.RemoteAddr(), err)
   119  				rlp.Encode(tssClient, "FAIL")
   120  				tssClient.Close()
   121  				return
   122  			}
   123  			slowLimit = true
   124  		}
   125  		// IGNORE FOR NOW
   126  		rlp.Encode(tssClient, "OK")
   127  	}
   128  	rawClient.SetDeadline(time.Now().Add(time.Hour * 24))
   129  	sessid := fmt.Sprintf("%v", strings.Split(tssClient.RemoteAddr().String(), ":")[0])
   130  	switch tssClient.NextProt() {
   131  	case 0:
   132  		defer tssClient.Close()
   133  		// create smux context
   134  		muxSrv, err := smux.Server(tssClient, &smux.Config{
   135  			Version:           1,
   136  			KeepAliveInterval: time.Minute * 10,
   137  			KeepAliveTimeout:  time.Minute * 40,
   138  			MaxFrameSize:      8192,
   139  			MaxReceiveBuffer:  100 * 1024,
   140  			MaxStreamBuffer:   100 * 1024,
   141  		})
   142  		if err != nil {
   143  			log.Println("Error negotiating smux from", rawClient.RemoteAddr(), err)
   144  			return
   145  		}
   146  		acceptStream = func() (n net.Conn, e error) {
   147  			n, e = muxSrv.AcceptStream()
   148  			return
   149  		}
   150  	case 2:
   151  		defer tssClient.Close()
   152  		// create smux context
   153  		muxSrv, err := smux.Server(tssClient, &smux.Config{
   154  			Version:           2,
   155  			KeepAliveInterval: time.Minute * 2,
   156  			KeepAliveTimeout:  time.Minute * 20,
   157  			MaxFrameSize:      32768,
   158  			MaxReceiveBuffer:  100 * 1024 * 1024,
   159  			MaxStreamBuffer:   100 * 1024 * 1024,
   160  		})
   161  		if err != nil {
   162  			log.Println("Error negotiating smux from", rawClient.RemoteAddr(), err)
   163  			return
   164  		}
   165  		acceptStream = func() (n net.Conn, e error) {
   166  			n, e = muxSrv.AcceptStream()
   167  			return
   168  		}
   169  	case 'S':
   170  		defer tssClient.Close()
   171  		// create smux context
   172  		muxSrv, err := yamux.Server(tssClient, &yamux.Config{
   173  			AcceptBacklog:          1000,
   174  			EnableKeepAlive:        false,
   175  			KeepAliveInterval:      time.Hour,
   176  			ConnectionWriteTimeout: time.Minute * 30,
   177  			MaxStreamWindowSize:    100 * 1024 * 1024,
   178  			LogOutput:              ioutil.Discard,
   179  		})
   180  		if err != nil {
   181  			log.Println("Error negotiating yamux from", rawClient.RemoteAddr(), err)
   182  			return
   183  		}
   184  		acceptStream = func() (n net.Conn, e error) {
   185  			n, e = muxSrv.AcceptStream()
   186  			return
   187  		}
   188  	case 'N':
   189  		defer tssClient.Close()
   190  		buf := make([]byte, 32)
   191  		io.ReadFull(tssClient, buf)
   192  		sessid = fmt.Sprintf("%x", buf)
   193  		// create smux context
   194  		muxSrv, err := smux.Server(tssClient, &smux.Config{
   195  			Version:           2,
   196  			KeepAliveInterval: time.Minute * 10,
   197  			KeepAliveTimeout:  time.Minute * 40,
   198  			MaxFrameSize:      32768,
   199  			MaxReceiveBuffer:  100 * 1024,
   200  			MaxStreamBuffer:   100 * 1024,
   201  		})
   202  		if err != nil {
   203  			log.Println("Error negotiating smux from", rawClient.RemoteAddr(), err)
   204  			return
   205  		}
   206  		acceptStream = func() (n net.Conn, e error) {
   207  			n, e = muxSrv.AcceptStream()
   208  			return
   209  		}
   210  	case 'R':
   211  		err = handleResumable(slowLimit, tssClient)
   212  		log.Println("handleResumable returned with", err)
   213  		if err != nil {
   214  			tssClient.Close()
   215  		}
   216  		return
   217  	}
   218  	if slowLimit {
   219  		limiter = slowLimitFactory.getLimiter(sessid)
   220  	}
   221  	smuxLoop(sessid, limiter, acceptStream)
   222  }
   223  
   224  type scEntry struct {
   225  	newConns chan net.Conn
   226  	currConn net.Conn
   227  	handle   *backedtcp.Socket
   228  }
   229  
   230  var sessionCache = make(map[[32]byte]*scEntry)
   231  var sessionCacheLock sync.Mutex
   232  
   233  func handleResumable(slowLimit bool, tssClient net.Conn) (err error) {
   234  	log.Println("handling resumable from", tssClient.RemoteAddr())
   235  	tssClient.SetDeadline(time.Now().Add(time.Second * 10))
   236  	var clientHello struct {
   237  		MetaSess [32]byte
   238  		SessID   [32]byte
   239  	}
   240  	err = binary.Read(tssClient, binary.BigEndian, &clientHello)
   241  	if err != nil {
   242  		return
   243  	}
   244  	log.Printf("[%v] M=%x, S=%x", tssClient.RemoteAddr(), clientHello.MetaSess, clientHello.SessID)
   245  	sessionCacheLock.Lock()
   246  	defer sessionCacheLock.Unlock()
   247  	if bt, ok := sessionCache[clientHello.SessID]; ok {
   248  		log.Printf("[%v] found session", tssClient.RemoteAddr())
   249  		bt.currConn.Close()
   250  		bt.currConn = tssClient
   251  		select {
   252  		case bt.newConns <- tssClient:
   253  			tssClient.Write([]byte{1})
   254  		case <-time.After(time.Millisecond * 100):
   255  			log.Printf("******** somehow stuck **********")
   256  		}
   257  		return
   258  	}
   259  	log.Printf("[%v] creating session", tssClient.RemoteAddr())
   260  	tssClient.Write([]byte{0})
   261  	ch := make(chan net.Conn, 1)
   262  	ch <- tssClient
   263  	btcp := backedtcp.NewSocket(func() (net.Conn, error) {
   264  		select {
   265  		case c := <-ch:
   266  			return c, nil
   267  		case <-time.After(time.Minute * 30):
   268  			return nil, errors.New("timeout")
   269  		}
   270  	})
   271  	sessionCache[clientHello.SessID] = &scEntry{
   272  		newConns: ch,
   273  		handle:   btcp,
   274  		currConn: tssClient,
   275  	}
   276  	go func() {
   277  		defer func() {
   278  			sessionCacheLock.Lock()
   279  			defer sessionCacheLock.Unlock()
   280  			log.Printf("deleting sessid %v", clientHello.SessID)
   281  			delete(sessionCache, clientHello.SessID)
   282  		}()
   283  		defer btcp.Close()
   284  		muxSrv, err := smux.Server(btcp, &smux.Config{
   285  			Version:           2,
   286  			KeepAliveInterval: time.Minute * 20,
   287  			KeepAliveTimeout:  time.Minute * 40,
   288  			MaxFrameSize:      32768,
   289  			MaxReceiveBuffer:  1 * 1024 * 1024,
   290  			MaxStreamBuffer:   256 * 1024,
   291  		})
   292  		if err != nil {
   293  			return
   294  		}
   295  		acceptStream := func() (n net.Conn, e error) {
   296  			n, e = muxSrv.AcceptStream()
   297  			return
   298  		}
   299  		var limiter *rate.Limiter
   300  		if slowLimit {
   301  			limiter = slowLimitFactory.getLimiter(fmt.Sprintf("%x", clientHello.MetaSess))
   302  		} else {
   303  			limiter = infiniteLimit
   304  		}
   305  		smuxLoop(fmt.Sprintf("%x", clientHello.MetaSess), limiter, acceptStream)
   306  	}()
   307  	return
   308  }
   309  
   310  func smuxLoop(sessid string, limiter *rate.Limiter, acceptStream func() (n net.Conn, e error)) {
   311  	// copy the streams while
   312  	var counter uint64
   313  	for {
   314  		soxclient, err := acceptStream()
   315  		if err != nil {
   316  			log.Println("failed accept stream", err)
   317  			return
   318  		}
   319  		if limiter == infiniteLimit {
   320  			paidSessCounter.SetDefault(sessid, true)
   321  		} else {
   322  			freeSessCounter.SetDefault(sessid, true)
   323  		}
   324  		go func() {
   325  			defer soxclient.Close()
   326  			soxclient.SetDeadline(time.Now().Add(time.Minute))
   327  			var command []string
   328  			err = rlp.Decode(&io.LimitedReader{R: soxclient, N: 1000}, &command)
   329  			if err != nil {
   330  				return
   331  			}
   332  			if len(command) == 0 {
   333  				return
   334  			}
   335  			soxclient.SetDeadline(time.Time{})
   336  			atomic.LoadUint64(&tunnCount)
   337  			timeout := time.Minute * 30
   338  			log.Debugf("[%v] cmd %v", timeout, command)
   339  			// match command
   340  			switch command[0] {
   341  			case "proxy":
   342  				if len(command) < 1 {
   343  					return
   344  				}
   345  				rlp.Encode(soxclient, true)
   346  				dialStart := time.Now()
   347  				host := command[1]
   348  				var remote net.Conn
   349  				for _, ntype := range []string{"tcp6", "tcp4"} {
   350  					tcpAddr, err := net.ResolveTCPAddr(ntype, host)
   351  					if err != nil || isBlack(tcpAddr) {
   352  						continue
   353  					}
   354  					remote, err = net.DialTimeout(ntype, tcpAddr.String(), time.Second*30)
   355  					if err != nil {
   356  						continue
   357  					}
   358  					break
   359  				}
   360  				if remote == nil {
   361  					return
   362  				}
   363  				atomic.AddUint64(&tunnCount, 1)
   364  				defer atomic.AddUint64(&tunnCount, ^uint64(0))
   365  				// measure dial latency
   366  				dialLatency := time.Since(dialStart)
   367  				if statClient != nil && singleHop == "" && reportRL.Allow() {
   368  					statClient.Timing(hostname+".dialLatency", dialLatency.Milliseconds())
   369  				}
   370  				defer remote.Close()
   371  				onPacket := func(l int) {
   372  					if statClient != nil && singleHop == "" {
   373  						before := atomic.LoadUint64(&counter)
   374  						atomic.AddUint64(&counter, uint64(l))
   375  						after := atomic.LoadUint64(&counter)
   376  						if before/1000000 != after/1000000 {
   377  							statClient.Increment(hostname + ".transferMB")
   378  						}
   379  					}
   380  				}
   381  				go func() {
   382  					defer remote.Close()
   383  					defer soxclient.Close()
   384  					cwl.CopyWithLimit(remote, soxclient, limiter, onPacket, timeout)
   385  				}()
   386  				cwl.CopyWithLimit(soxclient, remote, limiter, onPacket, timeout)
   387  			case "ip":
   388  				var ip string
   389  				if ipi, ok := ipcache.Get("ip"); ok {
   390  					ip = ipi.(string)
   391  				} else {
   392  					addr := "http://checkip.amazonaws.com"
   393  					resp, err := http.Get(addr)
   394  					if err != nil {
   395  						return
   396  					}
   397  					defer resp.Body.Close()
   398  					ipb, err := ioutil.ReadAll(resp.Body)
   399  					if err != nil {
   400  						return
   401  					}
   402  					ip = string(ipb)
   403  					ipcache.SetDefault("ip", ip)
   404  				}
   405  				rlp.Encode(soxclient, true)
   406  				rlp.Encode(soxclient, ip)
   407  				time.Sleep(time.Second)
   408  			}
   409  		}()
   410  	}
   411  }