github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/udp/udp.go (about) 1 package udp 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "io" 7 logger "log" 8 "net" 9 "runtime/debug" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/AntonOrnatskyi/goproxy/core/cs/server" 15 "github.com/AntonOrnatskyi/goproxy/services" 16 "github.com/AntonOrnatskyi/goproxy/utils" 17 "github.com/AntonOrnatskyi/goproxy/utils/mapx" 18 ) 19 20 type UDPArgs struct { 21 Parent *string 22 CertFile *string 23 KeyFile *string 24 CertBytes []byte 25 KeyBytes []byte 26 Local *string 27 ParentType *string 28 Timeout *int 29 CheckParentInterval *int 30 } 31 type UDP struct { 32 p mapx.ConcurrentMap 33 cfg UDPArgs 34 sc *server.ServerChannel 35 isStop bool 36 log *logger.Logger 37 outUDPConnCtxMap mapx.ConcurrentMap 38 udpConns mapx.ConcurrentMap 39 dstAddr *net.UDPAddr 40 } 41 type UDPConnItem struct { 42 conn *net.Conn 43 touchtime int64 44 srcAddr *net.UDPAddr 45 localAddr *net.UDPAddr 46 connid string 47 } 48 type outUDPConnCtx struct { 49 localAddr *net.UDPAddr 50 srcAddr *net.UDPAddr 51 udpconn *net.UDPConn 52 touchtime int64 53 } 54 55 func NewUDP() services.Service { 56 return &UDP{ 57 p: mapx.NewConcurrentMap(), 58 isStop: false, 59 outUDPConnCtxMap: mapx.NewConcurrentMap(), 60 udpConns: mapx.NewConcurrentMap(), 61 } 62 } 63 func (s *UDP) CheckArgs() (err error) { 64 if len(*s.cfg.Parent) == 0 { 65 err = fmt.Errorf("parent required for udp %s", *s.cfg.Local) 66 return 67 } 68 if *s.cfg.ParentType == "" { 69 err = fmt.Errorf("parent type unkown,use -T <udp|tls|tcp>") 70 return 71 } 72 if *s.cfg.ParentType == "tls" { 73 s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) 74 if err != nil { 75 return 76 } 77 } 78 79 s.dstAddr, err = net.ResolveUDPAddr("udp", *s.cfg.Parent) 80 if err != nil { 81 s.log.Printf("resolve udp addr %s fail fail,ERR:%s", *s.cfg.Parent, err) 82 return 83 } 84 return 85 } 86 func (s *UDP) InitService() (err error) { 87 s.OutToUDPGCDeamon() 88 s.UDPGCDeamon() 89 return 90 } 91 func (s *UDP) StopService() { 92 defer func() { 93 e := recover() 94 if e != nil { 95 s.log.Printf("stop udp service crashed,%s", e) 96 } else { 97 s.log.Printf("service udp stopped") 98 } 99 s.cfg = UDPArgs{} 100 s.log = nil 101 s.p = nil 102 s.sc = nil 103 s = nil 104 }() 105 s.isStop = true 106 if s.sc.Listener != nil && *s.sc.Listener != nil { 107 (*s.sc.Listener).Close() 108 } 109 if s.sc.UDPListener != nil { 110 (*s.sc.UDPListener).Close() 111 } 112 } 113 func (s *UDP) Start(args interface{}, log *logger.Logger) (err error) { 114 s.log = log 115 s.cfg = args.(UDPArgs) 116 if err = s.CheckArgs(); err != nil { 117 return 118 } 119 s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent) 120 if err = s.InitService(); err != nil { 121 return 122 } 123 host, port, _ := net.SplitHostPort(*s.cfg.Local) 124 p, _ := strconv.Atoi(port) 125 sc := server.NewServerChannel(host, p, s.log) 126 s.sc = &sc 127 err = sc.ListenUDP(s.callback) 128 if err != nil { 129 return 130 } 131 s.log.Printf("udp proxy on %s", (*sc.UDPListener).LocalAddr()) 132 return 133 } 134 135 func (s *UDP) Clean() { 136 s.StopService() 137 } 138 func (s *UDP) callback(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) { 139 defer func() { 140 if err := recover(); err != nil { 141 s.log.Printf("udp conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) 142 } 143 }() 144 switch *s.cfg.ParentType { 145 case "tcp", "tls": 146 s.OutToTCP(packet, localAddr, srcAddr) 147 case "udp": 148 s.OutToUDP(packet, localAddr, srcAddr) 149 default: 150 s.log.Printf("unkown parent type %s", *s.cfg.ParentType) 151 } 152 } 153 func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) { 154 isNew = !s.p.Has(connKey) 155 var _conn interface{} 156 if isNew { 157 _conn, err = s.GetParentConn() 158 if err != nil { 159 return nil, false, err 160 } 161 s.p.Set(connKey, _conn) 162 } else { 163 _conn, _ = s.p.Get(connKey) 164 } 165 conn = _conn.(net.Conn) 166 return 167 } 168 func (s *UDP) OutToTCP(data []byte, localAddr, srcAddr *net.UDPAddr) (err error) { 169 s.UDPSend(data, localAddr, srcAddr) 170 return 171 } 172 func (s *UDP) OutToUDPGCDeamon() { 173 gctime := int64(30) 174 go func() { 175 defer func() { 176 if e := recover(); e != nil { 177 fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack())) 178 } 179 }() 180 if s.isStop { 181 return 182 } 183 timer := time.NewTicker(time.Second) 184 for { 185 <-timer.C 186 gcKeys := []string{} 187 s.outUDPConnCtxMap.IterCb(func(key string, v interface{}) { 188 if time.Now().Unix()-v.(*outUDPConnCtx).touchtime > gctime { 189 (*(v.(*outUDPConnCtx).udpconn)).Close() 190 gcKeys = append(gcKeys, key) 191 s.log.Printf("gc udp conn %s <--> %s", (*v.(*outUDPConnCtx)).srcAddr, (*v.(*outUDPConnCtx)).localAddr) 192 } 193 }) 194 for _, k := range gcKeys { 195 s.outUDPConnCtxMap.Remove(k) 196 } 197 gcKeys = nil 198 } 199 }() 200 } 201 func (s *UDP) OutToUDP(packet []byte, localAddr, srcAddr *net.UDPAddr) { 202 var ouc *outUDPConnCtx 203 if v, ok := s.outUDPConnCtxMap.Get(srcAddr.String()); !ok { 204 clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} 205 conn, err := net.DialUDP("udp", clientSrcAddr, s.dstAddr) 206 if err != nil { 207 s.log.Printf("connect to udp %s fail,ERR:%s", s.dstAddr.String(), err) 208 209 } 210 ouc = &outUDPConnCtx{ 211 localAddr: localAddr, 212 srcAddr: srcAddr, 213 udpconn: conn, 214 } 215 s.outUDPConnCtxMap.Set(srcAddr.String(), ouc) 216 go func() { 217 defer func() { 218 if e := recover(); e != nil { 219 fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack())) 220 } 221 }() 222 s.log.Printf("udp conn %s <--> %s connected", srcAddr.String(), localAddr.String()) 223 buf := utils.LeakyBuffer.Get() 224 defer func() { 225 utils.LeakyBuffer.Put(buf) 226 s.outUDPConnCtxMap.Remove(srcAddr.String()) 227 s.log.Printf("udp conn %s <--> %s released", srcAddr.String(), localAddr.String()) 228 }() 229 for { 230 n, err := ouc.udpconn.Read(buf) 231 if err != nil { 232 if !utils.IsNetClosedErr(err) { 233 s.log.Printf("udp conn read udp packet fail , err: %s ", err) 234 } 235 return 236 } 237 ouc.touchtime = time.Now().Unix() 238 go func() { 239 defer func() { 240 if e := recover(); e != nil { 241 fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack())) 242 } 243 }() 244 (*(s.sc).UDPListener).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 245 _, err = (*(s.sc).UDPListener).WriteTo(buf[:n], srcAddr) 246 (*(s.sc).UDPListener).SetWriteDeadline(time.Time{}) 247 }() 248 } 249 }() 250 } else { 251 ouc = v.(*outUDPConnCtx) 252 } 253 go func() { 254 ouc.touchtime = time.Now().Unix() 255 ouc.udpconn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 256 ouc.udpconn.Write(packet) 257 ouc.udpconn.SetWriteDeadline(time.Time{}) 258 }() 259 return 260 } 261 func (s *UDP) GetParentConn() (conn net.Conn, err error) { 262 if *s.cfg.ParentType == "tls" { 263 var _conn tls.Conn 264 _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) 265 if err == nil { 266 conn = net.Conn(&_conn) 267 } 268 } else { 269 conn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout) 270 } 271 return 272 } 273 func (s *UDP) UDPGCDeamon() { 274 gctime := int64(30) 275 go func() { 276 defer func() { 277 if e := recover(); e != nil { 278 fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack())) 279 } 280 }() 281 if s.isStop { 282 return 283 } 284 timer := time.NewTicker(time.Second) 285 for { 286 <-timer.C 287 gcKeys := []string{} 288 s.udpConns.IterCb(func(key string, v interface{}) { 289 if time.Now().Unix()-v.(*UDPConnItem).touchtime > gctime { 290 (*(v.(*UDPConnItem).conn)).Close() 291 gcKeys = append(gcKeys, key) 292 s.log.Printf("gc udp conn %s", v.(*UDPConnItem).connid) 293 } 294 }) 295 for _, k := range gcKeys { 296 s.udpConns.Remove(k) 297 } 298 gcKeys = nil 299 } 300 }() 301 } 302 func (s *UDP) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) { 303 var ( 304 uc *UDPConnItem 305 key = srcAddr.String() 306 err error 307 outconn net.Conn 308 ) 309 v, ok := s.udpConns.Get(key) 310 if !ok { 311 for { 312 outconn, err = s.GetParentConn() 313 if err != nil && strings.Contains(err.Error(), "can not connect at same time") { 314 time.Sleep(time.Millisecond * 500) 315 continue 316 } else { 317 break 318 } 319 } 320 if err != nil { 321 s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err) 322 return 323 } 324 uc = &UDPConnItem{ 325 conn: &outconn, 326 srcAddr: srcAddr, 327 localAddr: localAddr, 328 } 329 s.udpConns.Set(key, uc) 330 s.UDPRevecive(key) 331 } else { 332 uc = v.(*UDPConnItem) 333 } 334 go func() { 335 defer func() { 336 if e := recover(); e != nil { 337 (*uc.conn).Close() 338 s.udpConns.Remove(key) 339 s.log.Printf("udp sender crashed with error : %s", e) 340 } 341 }() 342 uc.touchtime = time.Now().Unix() 343 (*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) 344 _, err = (*uc.conn).Write(utils.UDPPacket(fmt.Sprintf("%s", srcAddr.String()), data)) 345 (*uc.conn).SetWriteDeadline(time.Time{}) 346 if err != nil { 347 s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err) 348 } 349 }() 350 } 351 func (s *UDP) UDPRevecive(key string) { 352 go func() { 353 defer func() { 354 if e := recover(); e != nil { 355 fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack())) 356 } 357 }() 358 s.log.Printf("udp conn %s connected", key) 359 var uc *UDPConnItem 360 defer func() { 361 if uc != nil { 362 (*uc.conn).Close() 363 } 364 s.udpConns.Remove(key) 365 s.log.Printf("udp conn %s released", key) 366 }() 367 v, ok := s.udpConns.Get(key) 368 if !ok { 369 s.log.Printf("[warn] udp conn not exists for %s", key) 370 return 371 } 372 uc = v.(*UDPConnItem) 373 for { 374 _, body, err := utils.ReadUDPPacket(*uc.conn) 375 if err != nil { 376 if strings.Contains(err.Error(), "n != int(") { 377 continue 378 } 379 if err != io.EOF && !utils.IsNetClosedErr(err) { 380 s.log.Printf("udp conn read udp packet fail , err: %s ", err) 381 } 382 return 383 } 384 uc.touchtime = time.Now().Unix() 385 go func() { 386 defer func() { 387 if e := recover(); e != nil { 388 fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack())) 389 } 390 }() 391 s.sc.UDPListener.WriteToUDP(body, uc.srcAddr) 392 }() 393 } 394 }() 395 }