github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/mux/mux_client.go (about) 1 package mux 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "io" 7 logger "log" 8 "net" 9 "runtime/debug" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/golang/snappy" 15 clienttransport "github.com/AntonOrnatskyi/goproxy/core/cs/client" 16 "github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg" 17 encryptconn "github.com/AntonOrnatskyi/goproxy/core/lib/transport/encrypt" 18 "github.com/AntonOrnatskyi/goproxy/services" 19 "github.com/AntonOrnatskyi/goproxy/utils" 20 "github.com/AntonOrnatskyi/goproxy/utils/jumper" 21 "github.com/AntonOrnatskyi/goproxy/utils/mapx" 22 //"github.com/xtaci/smux" 23 smux "github.com/hashicorp/yamux" 24 ) 25 26 type MuxClientArgs struct { 27 Parent *string 28 ParentType *string 29 CertFile *string 30 KeyFile *string 31 CertBytes []byte 32 KeyBytes []byte 33 Key *string 34 Timeout *int 35 IsCompress *bool 36 SessionCount *int 37 KCP kcpcfg.KCPConfigArgs 38 Jumper *string 39 TCPSMethod *string 40 TCPSPassword *string 41 TOUMethod *string 42 TOUPassword *string 43 } 44 type ClientUDPConnItem struct { 45 conn *smux.Stream 46 isActive bool 47 touchtime int64 48 srcAddr *net.UDPAddr 49 localAddr *net.UDPAddr 50 udpConn *net.UDPConn 51 connid string 52 } 53 type MuxClient struct { 54 cfg MuxClientArgs 55 isStop bool 56 sessions mapx.ConcurrentMap 57 log *logger.Logger 58 jumper *jumper.Jumper 59 udpConns mapx.ConcurrentMap 60 } 61 62 func NewMuxClient() services.Service { 63 return &MuxClient{ 64 cfg: MuxClientArgs{}, 65 isStop: false, 66 sessions: mapx.NewConcurrentMap(), 67 udpConns: mapx.NewConcurrentMap(), 68 } 69 } 70 71 func (s *MuxClient) InitService() (err error) { 72 s.UDPGCDeamon() 73 return 74 } 75 76 func (s *MuxClient) CheckArgs() (err error) { 77 if *s.cfg.Parent != "" { 78 s.log.Printf("use tls parent %s", *s.cfg.Parent) 79 } else { 80 err = fmt.Errorf("parent required") 81 return 82 } 83 if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { 84 err = fmt.Errorf("cert and key file required") 85 return 86 } 87 if *s.cfg.ParentType == "tls" { 88 s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) 89 if err != nil { 90 return 91 } 92 } 93 if *s.cfg.Jumper != "" { 94 if *s.cfg.ParentType != "tls" && *s.cfg.ParentType != "tcp" { 95 err = fmt.Errorf("jumper only worked of -T is tls or tcp") 96 return 97 } 98 var j jumper.Jumper 99 j, err = jumper.New(*s.cfg.Jumper, time.Millisecond*time.Duration(*s.cfg.Timeout)) 100 if err != nil { 101 err = fmt.Errorf("parse jumper fail, err %s", err) 102 return 103 } 104 s.jumper = &j 105 } 106 return 107 } 108 func (s *MuxClient) StopService() { 109 defer func() { 110 e := recover() 111 if e != nil { 112 s.log.Printf("stop client service crashed,%s", e) 113 } else { 114 s.log.Printf("service client stopped") 115 } 116 s.cfg = MuxClientArgs{} 117 s.jumper = nil 118 s.log = nil 119 s.sessions = nil 120 s.udpConns = nil 121 s = nil 122 }() 123 s.isStop = true 124 for _, sess := range s.sessions.Items() { 125 sess.(*smux.Session).Close() 126 } 127 } 128 func (s *MuxClient) Start(args interface{}, log *logger.Logger) (err error) { 129 s.log = log 130 s.cfg = args.(MuxClientArgs) 131 if err = s.CheckArgs(); err != nil { 132 return 133 } 134 if err = s.InitService(); err != nil { 135 return 136 } 137 s.log.Printf("client started") 138 count := 1 139 if *s.cfg.SessionCount > 0 { 140 count = *s.cfg.SessionCount 141 } 142 for i := 1; i <= count; i++ { 143 key := fmt.Sprintf("worker[%d]", i) 144 s.log.Printf("session %s started", key) 145 go func(i int) { 146 defer func() { 147 e := recover() 148 if e != nil { 149 s.log.Printf("session worker crashed: %s\nstack:%s", e, string(debug.Stack())) 150 } 151 }() 152 for { 153 if s.isStop { 154 return 155 } 156 conn, err := s.getParentConn() 157 if err != nil { 158 s.log.Printf("connection err: %s, retrying...", err) 159 time.Sleep(time.Second * 3) 160 continue 161 } 162 conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 163 g := sync.WaitGroup{} 164 g.Add(1) 165 go func() { 166 defer func() { 167 _ = recover() 168 g.Done() 169 }() 170 _, err = conn.Write(utils.BuildPacket(CONN_CLIENT, fmt.Sprintf("%s-%d", *s.cfg.Key, i))) 171 }() 172 g.Wait() 173 conn.SetDeadline(time.Time{}) 174 if err != nil { 175 conn.Close() 176 s.log.Printf("connection err: %s, retrying...", err) 177 time.Sleep(time.Second * 3) 178 continue 179 } 180 session, err := smux.Server(conn, nil) 181 if err != nil { 182 s.log.Printf("session err: %s, retrying...", err) 183 conn.Close() 184 time.Sleep(time.Second * 3) 185 continue 186 } 187 if _sess, ok := s.sessions.Get(key); ok { 188 _sess.(*smux.Session).Close() 189 } 190 s.sessions.Set(key, session) 191 for { 192 if s.isStop { 193 return 194 } 195 stream, err := session.AcceptStream() 196 if err != nil { 197 s.log.Printf("accept stream err: %s, retrying...", err) 198 session.Close() 199 time.Sleep(time.Second * 3) 200 break 201 } 202 go func() { 203 defer func() { 204 e := recover() 205 if e != nil { 206 s.log.Printf("stream handler crashed: %s", e) 207 } 208 }() 209 var ID, clientLocalAddr, serverID string 210 stream.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 211 err = utils.ReadPacketData(stream, &ID, &clientLocalAddr, &serverID) 212 stream.SetDeadline(time.Time{}) 213 if err != nil { 214 s.log.Printf("read stream signal err: %s", err) 215 stream.Close() 216 return 217 } 218 //s.log.Printf("worker[%d] signal revecived,server %s stream %s %s", i, serverID, ID, clientLocalAddr) 219 protocol := clientLocalAddr[:3] 220 localAddr := clientLocalAddr[4:] 221 if protocol == "udp" { 222 s.ServeUDP(stream, localAddr, ID) 223 } else { 224 s.ServeConn(stream, localAddr, ID) 225 } 226 }() 227 } 228 } 229 }(i) 230 } 231 return 232 } 233 func (s *MuxClient) Clean() { 234 s.StopService() 235 } 236 func (s *MuxClient) getParentConn() (conn net.Conn, err error) { 237 if *s.cfg.ParentType == "tls" { 238 if s.jumper == nil { 239 var _conn tls.Conn 240 _conn, err = clienttransport.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) 241 if err == nil { 242 conn = net.Conn(&_conn) 243 } 244 } else { 245 conf, e := utils.TlsConfig(s.cfg.CertBytes, s.cfg.KeyBytes, nil) 246 if e != nil { 247 return nil, e 248 } 249 var _c net.Conn 250 _c, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout)) 251 if err == nil { 252 conn = net.Conn(tls.Client(_c, conf)) 253 } 254 } 255 256 } else if *s.cfg.ParentType == "kcp" { 257 conn, err = clienttransport.KCPConnectHost(*s.cfg.Parent, s.cfg.KCP) 258 } else if *s.cfg.ParentType == "tcps" { 259 if s.jumper == nil { 260 conn, err = clienttransport.TCPSConnectHost(*s.cfg.Parent, *s.cfg.TCPSMethod, *s.cfg.TCPSPassword, false, *s.cfg.Timeout) 261 } else { 262 conn, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout)) 263 if err == nil { 264 conn, err = encryptconn.NewConn(conn, *s.cfg.TCPSMethod, *s.cfg.TCPSPassword) 265 } 266 } 267 268 } else if *s.cfg.ParentType == "tou" { 269 conn, err = clienttransport.TOUConnectHost(*s.cfg.Parent, *s.cfg.TCPSMethod, *s.cfg.TCPSPassword, false, *s.cfg.Timeout) 270 } else { 271 if s.jumper == nil { 272 conn, err = clienttransport.TCPConnectHost(*s.cfg.Parent, *s.cfg.Timeout) 273 } else { 274 conn, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout)) 275 } 276 } 277 return 278 } 279 func (s *MuxClient) ServeUDP(inConn *smux.Stream, localAddr, ID string) { 280 var item *ClientUDPConnItem 281 var body []byte 282 var err error 283 srcAddr := "" 284 defer func() { 285 if item != nil { 286 (*item).conn.Close() 287 (*item).udpConn.Close() 288 s.udpConns.Remove(srcAddr) 289 inConn.Close() 290 } 291 }() 292 for { 293 if s.isStop { 294 return 295 } 296 srcAddr, body, err = utils.ReadUDPPacket(inConn) 297 if err != nil { 298 if strings.Contains(err.Error(), "n != int(") { 299 continue 300 } 301 if !utils.IsNetDeadlineErr(err) && err != io.EOF { 302 s.log.Printf("udp packet revecived from bridge fail, err: %s", err) 303 } 304 return 305 } 306 if v, ok := s.udpConns.Get(srcAddr); !ok { 307 _srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr) 308 zeroAddr, _ := net.ResolveUDPAddr("udp", ":") 309 _localAddr, _ := net.ResolveUDPAddr("udp", localAddr) 310 c, err := net.DialUDP("udp", zeroAddr, _localAddr) 311 if err != nil { 312 s.log.Printf("create local udp conn fail, err : %s", err) 313 inConn.Close() 314 return 315 } 316 item = &ClientUDPConnItem{ 317 conn: inConn, 318 srcAddr: _srcAddr, 319 localAddr: _localAddr, 320 udpConn: c, 321 connid: ID, 322 } 323 s.udpConns.Set(srcAddr, item) 324 s.UDPRevecive(srcAddr, ID) 325 } else { 326 item = v.(*ClientUDPConnItem) 327 } 328 (*item).touchtime = time.Now().Unix() 329 go func() { 330 defer func() { _ = recover() }() 331 (*item).udpConn.Write(body) 332 }() 333 } 334 } 335 func (s *MuxClient) UDPRevecive(key, ID string) { 336 go func() { 337 defer func() { 338 if e := recover(); e != nil { 339 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 340 } 341 }() 342 s.log.Printf("udp conn %s connected", ID) 343 v, ok := s.udpConns.Get(key) 344 if !ok { 345 s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID) 346 return 347 } 348 cui := v.(*ClientUDPConnItem) 349 buf := utils.LeakyBuffer.Get() 350 defer func() { 351 utils.LeakyBuffer.Put(buf) 352 cui.conn.Close() 353 cui.udpConn.Close() 354 s.udpConns.Remove(key) 355 s.log.Printf("udp conn %s released", ID) 356 }() 357 for { 358 n, err := cui.udpConn.Read(buf) 359 if err != nil { 360 if !utils.IsNetClosedErr(err) { 361 s.log.Printf("udp conn read udp packet fail , err: %s ", err) 362 } 363 return 364 } 365 cui.touchtime = time.Now().Unix() 366 go func() { 367 defer func() { 368 if e := recover(); e != nil { 369 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 370 } 371 }() 372 cui.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 373 _, err = cui.conn.Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n])) 374 cui.conn.SetWriteDeadline(time.Time{}) 375 if err != nil { 376 cui.udpConn.Close() 377 return 378 } 379 }() 380 } 381 }() 382 } 383 func (s *MuxClient) UDPGCDeamon() { 384 gctime := int64(30) 385 go func() { 386 defer func() { 387 if e := recover(); e != nil { 388 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 389 } 390 }() 391 if s.isStop { 392 return 393 } 394 timer := time.NewTicker(time.Second) 395 for { 396 <-timer.C 397 gcKeys := []string{} 398 s.udpConns.IterCb(func(key string, v interface{}) { 399 if time.Now().Unix()-v.(*ClientUDPConnItem).touchtime > gctime { 400 (*(v.(*ClientUDPConnItem).conn)).Close() 401 (v.(*ClientUDPConnItem).udpConn).Close() 402 gcKeys = append(gcKeys, key) 403 s.log.Printf("gc udp conn %s", v.(*ClientUDPConnItem).connid) 404 } 405 }) 406 for _, k := range gcKeys { 407 s.udpConns.Remove(k) 408 } 409 gcKeys = nil 410 } 411 }() 412 } 413 func (s *MuxClient) ServeConn(inConn *smux.Stream, localAddr, ID string) { 414 var err error 415 var outConn net.Conn 416 i := 0 417 for { 418 if s.isStop { 419 return 420 } 421 i++ 422 outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout) 423 if err == nil || i == 3 { 424 break 425 } else { 426 if i == 3 { 427 s.log.Printf("connect to %s err: %s, retrying...", localAddr, err) 428 time.Sleep(2 * time.Second) 429 continue 430 } 431 } 432 } 433 if err != nil { 434 inConn.Close() 435 utils.CloseConn(&outConn) 436 s.log.Printf("build connection error, err: %s", err) 437 return 438 } 439 440 s.log.Printf("stream %s created", ID) 441 if *s.cfg.IsCompress { 442 die1 := make(chan bool, 1) 443 die2 := make(chan bool, 1) 444 go func() { 445 defer func() { 446 if e := recover(); e != nil { 447 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 448 } 449 }() 450 io.Copy(outConn, snappy.NewReader(inConn)) 451 die1 <- true 452 }() 453 go func() { 454 defer func() { 455 if e := recover(); e != nil { 456 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 457 } 458 }() 459 io.Copy(snappy.NewWriter(inConn), outConn) 460 die2 <- true 461 }() 462 select { 463 case <-die1: 464 case <-die2: 465 } 466 outConn.Close() 467 inConn.Close() 468 s.log.Printf("%s stream %s released", *s.cfg.Key, ID) 469 } else { 470 utils.IoBind(inConn, outConn, func(err interface{}) { 471 s.log.Printf("stream %s released", ID) 472 }, s.log) 473 } 474 }