github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/pseudotcp/ptcp.go (about)

     1  package pseudotcp
     2  
     3  import (
     4  	"io"
     5  	"log"
     6  	"math/rand"
     7  	"net"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/geph-official/geph2/libs/buffconn"
    12  	"github.com/xtaci/smux"
    13  	"gopkg.in/tomb.v1"
    14  )
    15  
    16  var dialArray = make([]*dialer, 32)
    17  
    18  func init() {
    19  	for i := range dialArray {
    20  		dialArray[i] = new(dialer)
    21  	}
    22  }
    23  
    24  // Dial haha
    25  func Dial(host string) (conn net.Conn, err error) {
    26  	return dialArray[rand.Int()%len(dialArray)].Dial(host)
    27  }
    28  
    29  type dialer struct {
    30  	locks  sync.Map // string => *sync.RWMutex
    31  	smuxes sync.Map // string => *smux.Session
    32  }
    33  
    34  func (dl *dialer) getLock(host string) *sync.RWMutex {
    35  	lok, _ := dl.locks.LoadOrStore(host, new(sync.RWMutex))
    36  	return lok.(*sync.RWMutex)
    37  }
    38  
    39  var smuxConf = &smux.Config{
    40  	Version:           2,
    41  	KeepAliveInterval: time.Minute * 1,
    42  	KeepAliveTimeout:  time.Minute * 2,
    43  	MaxFrameSize:      32768,
    44  	MaxReceiveBuffer:  10 * 1024 * 1024,
    45  	MaxStreamBuffer:   10 * 1024 * 1024,
    46  }
    47  
    48  // Dial dials a "pseudoTCP" connection to the given host
    49  func (dl *dialer) Dial(host string) (conn net.Conn, err error) {
    50  	dl.getLock(host).Lock()
    51  	defer dl.getLock(host).Unlock()
    52  	fixConn := func() {
    53  		conn.SetDeadline(time.Now().Add(time.Second * 10))
    54  		buf := make([]byte, 1)
    55  		conn.Write(buf)
    56  		io.ReadFull(conn, buf)
    57  		conn.SetDeadline(time.Time{})
    58  	}
    59  	if s, ok := dl.smuxes.Load(host); ok {
    60  		ssess := s.(*smux.Session)
    61  		conn, err = ssess.OpenStream()
    62  		if err != nil {
    63  			dl.smuxes.Delete(host)
    64  		} else {
    65  			fixConn()
    66  		}
    67  		return
    68  	}
    69  
    70  	rawConn, err := net.DialTimeout("tcp", host, time.Second*5)
    71  	if err != nil {
    72  		return
    73  	}
    74  	ssess, err := smux.Client(buffconn.New(rawConn), smuxConf)
    75  	if err != nil {
    76  		rawConn.Close()
    77  		return
    78  	}
    79  	dl.smuxes.Store(host, ssess)
    80  	conn, err = ssess.OpenStream()
    81  	if err == nil {
    82  		fixConn()
    83  	}
    84  	return
    85  }
    86  
    87  // Listener listens for ptcp connections
    88  type Listener struct {
    89  	death      tomb.Tomb
    90  	incoming   chan net.Conn
    91  	underlying net.Listener
    92  }
    93  
    94  // Listen opens a Listener
    95  func Listen(addr string) (listener net.Listener, err error) {
    96  	tListener, err := net.Listen("tcp", addr)
    97  	if err != nil {
    98  		return
    99  	}
   100  	toret := &Listener{incoming: make(chan net.Conn), underlying: tListener}
   101  	go func() {
   102  		defer toret.death.Kill(io.ErrClosedPipe)
   103  		for {
   104  			rawConn, err := tListener.Accept()
   105  			if err != nil {
   106  				log.Println("raw accept:", err)
   107  				break
   108  			}
   109  			go func() {
   110  				defer rawConn.Close()
   111  				srv, err := smux.Server(buffconn.New(rawConn), smuxConf)
   112  				if err != nil {
   113  					log.Println("smux create:", err)
   114  					return
   115  				}
   116  				for {
   117  					conn, err := srv.AcceptStream()
   118  					if err != nil {
   119  						log.Println("smux accept:", err)
   120  						return
   121  					}
   122  					go func() {
   123  						conn.SetDeadline(time.Now().Add(time.Second * 10))
   124  						buf := make([]byte, 1)
   125  						io.ReadFull(conn, buf)
   126  						conn.Write(buf)
   127  						conn.SetDeadline(time.Time{})
   128  						select {
   129  						case toret.incoming <- conn:
   130  						case <-toret.death.Dying():
   131  							srv.Close()
   132  							tListener.Close()
   133  							return
   134  						}
   135  					}()
   136  				}
   137  			}()
   138  		}
   139  	}()
   140  	listener = toret
   141  	return
   142  }
   143  
   144  // Accept accepts a new connection.
   145  func (l *Listener) Accept() (conn net.Conn, err error) {
   146  	select {
   147  	case conn = <-l.incoming:
   148  	case <-l.death.Dying():
   149  		err = l.death.Err()
   150  	}
   151  	return
   152  }
   153  
   154  // Addr is the address of the listener.
   155  func (l *Listener) Addr() net.Addr {
   156  	return l.underlying.Addr()
   157  }
   158  
   159  // Close closes the listener.
   160  func (l *Listener) Close() error {
   161  	l.death.Kill(io.ErrClosedPipe)
   162  	return nil
   163  }