github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/pseudotcp/ptcp.go (about) 1 package pseudotcp 2 3 import ( 4 "io" 5 "log" 6 "math/rand" 7 "net" 8 "sync" 9 "time" 10 11 "github.com/geph-official/geph2/libs/buffconn" 12 "github.com/xtaci/smux" 13 "gopkg.in/tomb.v1" 14 ) 15 16 var dialArray = make([]*dialer, 32) 17 18 func init() { 19 for i := range dialArray { 20 dialArray[i] = new(dialer) 21 } 22 } 23 24 // Dial haha 25 func Dial(host string) (conn net.Conn, err error) { 26 return dialArray[rand.Int()%len(dialArray)].Dial(host) 27 } 28 29 type dialer struct { 30 locks sync.Map // string => *sync.RWMutex 31 smuxes sync.Map // string => *smux.Session 32 } 33 34 func (dl *dialer) getLock(host string) *sync.RWMutex { 35 lok, _ := dl.locks.LoadOrStore(host, new(sync.RWMutex)) 36 return lok.(*sync.RWMutex) 37 } 38 39 var smuxConf = &smux.Config{ 40 Version: 2, 41 KeepAliveInterval: time.Minute * 1, 42 KeepAliveTimeout: time.Minute * 2, 43 MaxFrameSize: 32768, 44 MaxReceiveBuffer: 10 * 1024 * 1024, 45 MaxStreamBuffer: 10 * 1024 * 1024, 46 } 47 48 // Dial dials a "pseudoTCP" connection to the given host 49 func (dl *dialer) Dial(host string) (conn net.Conn, err error) { 50 dl.getLock(host).Lock() 51 defer dl.getLock(host).Unlock() 52 fixConn := func() { 53 conn.SetDeadline(time.Now().Add(time.Second * 10)) 54 buf := make([]byte, 1) 55 conn.Write(buf) 56 io.ReadFull(conn, buf) 57 conn.SetDeadline(time.Time{}) 58 } 59 if s, ok := dl.smuxes.Load(host); ok { 60 ssess := s.(*smux.Session) 61 conn, err = ssess.OpenStream() 62 if err != nil { 63 dl.smuxes.Delete(host) 64 } else { 65 fixConn() 66 } 67 return 68 } 69 70 rawConn, err := net.DialTimeout("tcp", host, time.Second*5) 71 if err != nil { 72 return 73 } 74 ssess, err := smux.Client(buffconn.New(rawConn), smuxConf) 75 if err != nil { 76 rawConn.Close() 77 return 78 } 79 dl.smuxes.Store(host, ssess) 80 conn, err = ssess.OpenStream() 81 if err == nil { 82 fixConn() 83 } 84 return 85 } 86 87 // Listener listens for ptcp connections 88 type Listener struct { 89 death tomb.Tomb 90 incoming chan net.Conn 91 underlying net.Listener 92 } 93 94 // Listen opens a Listener 95 func Listen(addr string) (listener net.Listener, err error) { 96 tListener, err := net.Listen("tcp", addr) 97 if err != nil { 98 return 99 } 100 toret := &Listener{incoming: make(chan net.Conn), underlying: tListener} 101 go func() { 102 defer toret.death.Kill(io.ErrClosedPipe) 103 for { 104 rawConn, err := tListener.Accept() 105 if err != nil { 106 log.Println("raw accept:", err) 107 break 108 } 109 go func() { 110 defer rawConn.Close() 111 srv, err := smux.Server(buffconn.New(rawConn), smuxConf) 112 if err != nil { 113 log.Println("smux create:", err) 114 return 115 } 116 for { 117 conn, err := srv.AcceptStream() 118 if err != nil { 119 log.Println("smux accept:", err) 120 return 121 } 122 go func() { 123 conn.SetDeadline(time.Now().Add(time.Second * 10)) 124 buf := make([]byte, 1) 125 io.ReadFull(conn, buf) 126 conn.Write(buf) 127 conn.SetDeadline(time.Time{}) 128 select { 129 case toret.incoming <- conn: 130 case <-toret.death.Dying(): 131 srv.Close() 132 tListener.Close() 133 return 134 } 135 }() 136 } 137 }() 138 } 139 }() 140 listener = toret 141 return 142 } 143 144 // Accept accepts a new connection. 145 func (l *Listener) Accept() (conn net.Conn, err error) { 146 select { 147 case conn = <-l.incoming: 148 case <-l.death.Dying(): 149 err = l.death.Err() 150 } 151 return 152 } 153 154 // Addr is the address of the listener. 155 func (l *Listener) Addr() net.Addr { 156 return l.underlying.Addr() 157 } 158 159 // Close closes the listener. 160 func (l *Listener) Close() error { 161 l.death.Kill(io.ErrClosedPipe) 162 return nil 163 }