github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/mux/mux_bridge.go (about) 1 package mux 2 3 import ( 4 "bufio" 5 "fmt" 6 "io" 7 logger "log" 8 "math/rand" 9 "net" 10 "runtime/debug" 11 "strings" 12 "sync" 13 "time" 14 15 srvtransport "github.com/AntonOrnatskyi/goproxy/core/cs/server" 16 "github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg" 17 "github.com/AntonOrnatskyi/goproxy/services" 18 "github.com/AntonOrnatskyi/goproxy/utils" 19 "github.com/AntonOrnatskyi/goproxy/utils/mapx" 20 //"github.com/xtaci/smux" 21 smux "github.com/hashicorp/yamux" 22 ) 23 24 type MuxBridgeArgs struct { 25 CertFile *string 26 KeyFile *string 27 CertBytes []byte 28 KeyBytes []byte 29 Local *string 30 LocalType *string 31 Timeout *int 32 IsCompress *bool 33 KCP kcpcfg.KCPConfigArgs 34 TCPSMethod *string 35 TCPSPassword *string 36 TOUMethod *string 37 TOUPassword *string 38 } 39 type MuxBridge struct { 40 cfg MuxBridgeArgs 41 clientControlConns mapx.ConcurrentMap 42 serverConns mapx.ConcurrentMap 43 router utils.ClientKeyRouter 44 l *sync.Mutex 45 isStop bool 46 sc *srvtransport.ServerChannel 47 log *logger.Logger 48 } 49 50 func NewMuxBridge() services.Service { 51 b := &MuxBridge{ 52 cfg: MuxBridgeArgs{}, 53 clientControlConns: mapx.NewConcurrentMap(), 54 serverConns: mapx.NewConcurrentMap(), 55 l: &sync.Mutex{}, 56 isStop: false, 57 } 58 b.router = utils.NewClientKeyRouter(&b.clientControlConns, 50000) 59 return b 60 } 61 62 func (s *MuxBridge) InitService() (err error) { 63 return 64 } 65 func (s *MuxBridge) CheckArgs() (err error) { 66 if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { 67 err = fmt.Errorf("cert and key file required") 68 return 69 } 70 if *s.cfg.LocalType == "tls" { 71 s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) 72 if err != nil { 73 return 74 } 75 } 76 return 77 } 78 func (s *MuxBridge) StopService() { 79 defer func() { 80 e := recover() 81 if e != nil { 82 s.log.Printf("stop bridge service crashed,%s", e) 83 } else { 84 s.log.Printf("service bridge stopped") 85 } 86 s.cfg = MuxBridgeArgs{} 87 s.clientControlConns = nil 88 s.l = nil 89 s.log = nil 90 s.router = utils.ClientKeyRouter{} 91 s.sc = nil 92 s.serverConns = nil 93 s = nil 94 }() 95 s.isStop = true 96 if s.sc != nil && (*s.sc).Listener != nil { 97 (*(*s.sc).Listener).Close() 98 } 99 for _, g := range s.clientControlConns.Items() { 100 for _, session := range g.(*mapx.ConcurrentMap).Items() { 101 (session.(*smux.Session)).Close() 102 } 103 } 104 for _, c := range s.serverConns.Items() { 105 (*c.(*net.Conn)).Close() 106 } 107 } 108 func (s *MuxBridge) Start(args interface{}, log *logger.Logger) (err error) { 109 s.log = log 110 s.cfg = args.(MuxBridgeArgs) 111 if err = s.CheckArgs(); err != nil { 112 return 113 } 114 if err = s.InitService(); err != nil { 115 return 116 } 117 118 sc := srvtransport.NewServerChannelHost(*s.cfg.Local, s.log) 119 if *s.cfg.LocalType == "tcp" { 120 err = sc.ListenTCP(s.handler) 121 } else if *s.cfg.LocalType == "tls" { 122 err = sc.ListenTLS(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.handler) 123 } else if *s.cfg.LocalType == "kcp" { 124 err = sc.ListenKCP(s.cfg.KCP, s.handler, s.log) 125 } else if *s.cfg.LocalType == "tcps" { 126 err = sc.ListenTCPS(*s.cfg.TCPSMethod, *s.cfg.TCPSPassword, false, s.handler) 127 } else if *s.cfg.LocalType == "tou" { 128 err = sc.ListenTOU(*s.cfg.TOUMethod, *s.cfg.TOUPassword, false, s.handler) 129 } 130 if err != nil { 131 return 132 } 133 s.sc = &sc 134 if *s.cfg.LocalType == "tou" { 135 s.log.Printf("%s bridge on %s", *s.cfg.LocalType, sc.UDPListener.LocalAddr()) 136 } else { 137 s.log.Printf("%s bridge on %s", *s.cfg.LocalType, (*sc.Listener).Addr()) 138 } 139 return 140 } 141 func (s *MuxBridge) Clean() { 142 s.StopService() 143 } 144 func (s *MuxBridge) handler(inConn net.Conn) { 145 reader := bufio.NewReader(inConn) 146 147 var err error 148 var connType uint8 149 var key string 150 inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 151 err = utils.ReadPacket(reader, &connType, &key) 152 inConn.SetDeadline(time.Time{}) 153 if err != nil { 154 s.log.Printf("read error,ERR:%s", err) 155 return 156 } 157 switch connType { 158 case CONN_SERVER: 159 var serverID string 160 inAddr := inConn.RemoteAddr().String() 161 inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 162 err = utils.ReadPacketData(reader, &serverID) 163 inConn.SetDeadline(time.Time{}) 164 if err != nil { 165 s.log.Printf("read error,ERR:%s", err) 166 return 167 } 168 s.log.Printf("server connection %s %s connected", serverID, key) 169 if c, ok := s.serverConns.Get(inAddr); ok { 170 (*c.(*net.Conn)).Close() 171 } 172 s.serverConns.Set(inAddr, &inConn) 173 session, err := smux.Server(inConn, nil) 174 if err != nil { 175 utils.CloseConn(&inConn) 176 s.log.Printf("server session error,ERR:%s", err) 177 return 178 } 179 for { 180 if s.isStop { 181 return 182 } 183 stream, err := session.AcceptStream() 184 if err != nil { 185 session.Close() 186 utils.CloseConn(&inConn) 187 s.serverConns.Remove(inAddr) 188 s.log.Printf("server connection %s %s released", serverID, key) 189 return 190 } 191 go func() { 192 defer func() { 193 if e := recover(); e != nil { 194 s.log.Printf("bridge callback crashed,err: %s", e) 195 } 196 }() 197 s.callback(stream, serverID, key) 198 }() 199 } 200 case CONN_CLIENT: 201 s.log.Printf("client connection %s connected", key) 202 session, err := smux.Client(inConn, nil) 203 if err != nil { 204 utils.CloseConn(&inConn) 205 s.log.Printf("client session error,ERR:%s", err) 206 return 207 } 208 keyInfo := strings.Split(key, "-") 209 if len(keyInfo) != 2 { 210 utils.CloseConn(&inConn) 211 s.log.Printf("client key format error,key:%s", key) 212 return 213 } 214 groupKey := keyInfo[0] 215 index := keyInfo[1] 216 s.l.Lock() 217 defer s.l.Unlock() 218 var group *mapx.ConcurrentMap 219 if !s.clientControlConns.Has(groupKey) { 220 _g := mapx.NewConcurrentMap() 221 group = &_g 222 s.clientControlConns.Set(groupKey, group) 223 //s.log.Printf("init client session group %s", groupKey) 224 } else { 225 _group, _ := s.clientControlConns.Get(groupKey) 226 group = _group.(*mapx.ConcurrentMap) 227 } 228 if v, ok := group.Get(index); ok { 229 v.(*smux.Session).Close() 230 } 231 group.Set(index, session) 232 //s.log.Printf("set client session %s to group %s,grouplen:%d", index, groupKey, group.Count()) 233 go func() { 234 defer func() { 235 if e := recover(); e != nil { 236 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 237 } 238 }() 239 for { 240 if s.isStop { 241 return 242 } 243 if session.IsClosed() { 244 s.l.Lock() 245 defer s.l.Unlock() 246 if sess, ok := group.Get(index); ok && sess.(*smux.Session).IsClosed() { 247 group.Remove(index) 248 //s.log.Printf("client session %s removed from group %s, grouplen:%d", key, groupKey, group.Count()) 249 s.log.Printf("client connection %s released", key) 250 } 251 if group.IsEmpty() { 252 s.clientControlConns.Remove(groupKey) 253 //s.log.Printf("client session group %s removed", groupKey) 254 } 255 break 256 } 257 time.Sleep(time.Second * 5) 258 } 259 }() 260 //s.log.Printf("set client session,key: %s", key) 261 } 262 263 } 264 func (s *MuxBridge) callback(inConn net.Conn, serverID, key string) { 265 try := 20 266 for { 267 if s.isStop { 268 return 269 } 270 try-- 271 if try == 0 { 272 break 273 } 274 if key == "*" { 275 key = s.router.GetKey() 276 } 277 //s.log.Printf("server get client session %s", key) 278 _group, ok := s.clientControlConns.Get(key) 279 if !ok { 280 s.log.Printf("client %s session not exists for server stream %s, retrying...", key, serverID) 281 time.Sleep(time.Second * 3) 282 continue 283 } 284 group := _group.(*mapx.ConcurrentMap) 285 keys := []string{} 286 group.IterCb(func(key string, v interface{}) { 287 keys = append(keys, key) 288 }) 289 keysLen := len(keys) 290 //s.log.Printf("client session %s , len:%d , keysLen: %d", key, group.Count(), keysLen) 291 i := 0 292 if keysLen > 0 { 293 i = rand.Intn(keysLen) 294 } else { 295 s.log.Printf("client %s session empty for server stream %s, retrying...", key, serverID) 296 time.Sleep(time.Second * 3) 297 continue 298 } 299 index := keys[i] 300 s.log.Printf("select client : %s-%s", key, index) 301 session, _ := group.Get(index) 302 //session.(*smux.Session).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 303 stream, err := session.(*smux.Session).OpenStream() 304 //session.(*smux.Session).SetDeadline(time.Time{}) 305 if err != nil { 306 s.log.Printf("%s client session open stream %s fail, err: %s, retrying...", key, serverID, err) 307 time.Sleep(time.Second * 3) 308 continue 309 } else { 310 s.log.Printf("stream %s -> %s created", serverID, key) 311 die1 := make(chan bool, 1) 312 die2 := make(chan bool, 1) 313 go func() { 314 defer func() { 315 if e := recover(); e != nil { 316 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 317 } 318 }() 319 io.Copy(stream, inConn) 320 die1 <- true 321 }() 322 go func() { 323 defer func() { 324 if e := recover(); e != nil { 325 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 326 } 327 }() 328 io.Copy(inConn, stream) 329 die2 <- true 330 }() 331 select { 332 case <-die1: 333 case <-die2: 334 } 335 stream.Close() 336 inConn.Close() 337 s.log.Printf("%s server %s stream released", key, serverID) 338 break 339 } 340 } 341 342 }