github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/tunnel/tunnel_bridge.go (about) 1 package tunnel 2 3 import ( 4 "bytes" 5 "fmt" 6 logger "log" 7 "net" 8 "os" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/AntonOrnatskyi/goproxy/core/cs/server" 14 "github.com/AntonOrnatskyi/goproxy/services" 15 "github.com/AntonOrnatskyi/goproxy/utils" 16 "github.com/AntonOrnatskyi/goproxy/utils/mapx" 17 18 //"github.com/xtaci/smux" 19 smux "github.com/hashicorp/yamux" 20 ) 21 22 const ( 23 CONN_CLIENT_CONTROL = uint8(1) 24 CONN_SERVER = uint8(4) 25 CONN_CLIENT = uint8(5) 26 ) 27 28 type TunnelBridgeArgs struct { 29 Parent *string 30 CertFile *string 31 KeyFile *string 32 CertBytes []byte 33 KeyBytes []byte 34 Local *string 35 Timeout *int 36 } 37 type ServerConn struct { 38 //ClientLocalAddr string //tcp:2.2.22:333@ID 39 Conn *net.Conn 40 } 41 type TunnelBridge struct { 42 cfg TunnelBridgeArgs 43 serverConns mapx.ConcurrentMap 44 clientControlConns mapx.ConcurrentMap 45 isStop bool 46 log *logger.Logger 47 } 48 49 func NewTunnelBridge() services.Service { 50 return &TunnelBridge{ 51 cfg: TunnelBridgeArgs{}, 52 serverConns: mapx.NewConcurrentMap(), 53 clientControlConns: mapx.NewConcurrentMap(), 54 isStop: false, 55 } 56 } 57 58 func (s *TunnelBridge) InitService() (err error) { 59 return 60 } 61 func (s *TunnelBridge) CheckArgs() (err error) { 62 if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { 63 err = fmt.Errorf("cert and key file required") 64 return 65 } 66 s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) 67 return 68 } 69 func (s *TunnelBridge) StopService() { 70 defer func() { 71 e := recover() 72 if e != nil { 73 s.log.Printf("stop tbridge service crashed,%s", e) 74 } else { 75 s.log.Printf("service tbridge stopped") 76 } 77 s.cfg = TunnelBridgeArgs{} 78 s.clientControlConns = nil 79 s.log = nil 80 s.serverConns = nil 81 s = nil 82 }() 83 s.isStop = true 84 for _, sess := range s.clientControlConns.Items() { 85 (*sess.(*net.Conn)).Close() 86 } 87 for _, sess := range s.serverConns.Items() { 88 (*sess.(ServerConn).Conn).Close() 89 } 90 } 91 func (s *TunnelBridge) Start(args interface{}, log *logger.Logger) (err error) { 92 s.log = log 93 s.cfg = args.(TunnelBridgeArgs) 94 if err = s.CheckArgs(); err != nil { 95 return 96 } 97 if err = s.InitService(); err != nil { 98 return 99 } 100 host, port, _ := net.SplitHostPort(*s.cfg.Local) 101 p, _ := strconv.Atoi(port) 102 sc := server.NewServerChannel(host, p, s.log) 103 104 err = sc.ListenTLS(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback) 105 if err != nil { 106 return 107 } 108 s.log.Printf("proxy on tunnel bridge mode %s", (*sc.Listener).Addr()) 109 return 110 } 111 func (s *TunnelBridge) Clean() { 112 s.StopService() 113 } 114 func (s *TunnelBridge) callback(inConn net.Conn) { 115 var err error 116 //s.log.Printf("connection from %s ", inConn.RemoteAddr()) 117 sess, err := smux.Server(inConn, &smux.Config{ 118 AcceptBacklog: 256, 119 EnableKeepAlive: true, 120 KeepAliveInterval: 9 * time.Second, 121 ConnectionWriteTimeout: 3 * time.Second, 122 MaxStreamWindowSize: 512 * 1024, 123 LogOutput: os.Stderr, 124 }) 125 if err != nil { 126 s.log.Printf("new mux server conn error,ERR:%s", err) 127 return 128 } 129 inConn, err = sess.AcceptStream() 130 if err != nil { 131 s.log.Printf("mux server conn accept error,ERR:%s", err) 132 return 133 } 134 go func() { 135 defer func() { 136 _ = recover() 137 }() 138 timer := time.NewTicker(time.Second * 3) 139 for { 140 <-timer.C 141 if sess.NumStreams() == 0 { 142 sess.Close() 143 timer.Stop() 144 return 145 } 146 } 147 }() 148 var buf = make([]byte, 1024) 149 n, _ := inConn.Read(buf) 150 reader := bytes.NewReader(buf[:n]) 151 152 //reader := bufio.NewReader(inConn) 153 154 var connType uint8 155 err = utils.ReadPacket(reader, &connType) 156 if err != nil { 157 s.log.Printf("read error,ERR:%s", err) 158 return 159 } 160 switch connType { 161 case CONN_SERVER: 162 var key, ID, clientLocalAddr, serverID string 163 err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) 164 if err != nil { 165 s.log.Printf("read error,ERR:%s", err) 166 return 167 } 168 packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) 169 s.log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) 170 171 //addr := clientLocalAddr + "@" + ID 172 s.serverConns.Set(ID, ServerConn{ 173 Conn: &inConn, 174 }) 175 for { 176 if s.isStop { 177 return 178 } 179 item, ok := s.clientControlConns.Get(key) 180 if !ok { 181 s.log.Printf("client %s control conn not exists", key) 182 time.Sleep(time.Second * 3) 183 continue 184 } 185 (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) 186 _, err := (*item.(*net.Conn)).Write(packet) 187 (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) 188 if err != nil && strings.Contains(err.Error(), "stream closed") { 189 s.log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) 190 time.Sleep(time.Second * 3) 191 continue 192 } else { 193 // s.cmServer.Add(serverID, ID, &inConn) 194 break 195 } 196 } 197 case CONN_CLIENT: 198 var key, ID, serverID string 199 err = utils.ReadPacketData(reader, &key, &ID, &serverID) 200 if err != nil { 201 s.log.Printf("read error,ERR:%s", err) 202 return 203 } 204 s.log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) 205 206 serverConnItem, ok := s.serverConns.Get(ID) 207 if !ok { 208 inConn.Close() 209 s.log.Printf("server conn %s exists", ID) 210 return 211 } 212 serverConn := serverConnItem.(ServerConn).Conn 213 utils.IoBind(*serverConn, inConn, func(err interface{}) { 214 s.serverConns.Remove(ID) 215 // s.cmClient.RemoveOne(key, ID) 216 // s.cmServer.RemoveOne(serverID, ID) 217 s.log.Printf("conn %s released", ID) 218 }, s.log) 219 // s.cmClient.Add(key, ID, &inConn) 220 s.log.Printf("conn %s created", ID) 221 222 case CONN_CLIENT_CONTROL: 223 var key string 224 err = utils.ReadPacketData(reader, &key) 225 if err != nil { 226 s.log.Printf("read error,ERR:%s", err) 227 return 228 } 229 s.log.Printf("client control connection, key: %s", key) 230 if s.clientControlConns.Has(key) { 231 item, _ := s.clientControlConns.Get(key) 232 (*item.(*net.Conn)).Close() 233 } 234 s.clientControlConns.Set(key, &inConn) 235 s.log.Printf("set client %s control conn", key) 236 } 237 }