github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/functions.go (about) 1 package utils 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "crypto/sha1" 8 "crypto/tls" 9 "crypto/x509" 10 "encoding/base64" 11 "encoding/binary" 12 "encoding/hex" 13 "encoding/pem" 14 "errors" 15 "fmt" 16 "io" 17 "io/ioutil" 18 logger "log" 19 "math/rand" 20 "net" 21 "net/http" 22 "os" 23 "strings" 24 25 "github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg" 26 "github.com/AntonOrnatskyi/goproxy/utils/lb" 27 28 "golang.org/x/crypto/pbkdf2" 29 30 "strconv" 31 32 "time" 33 34 "github.com/AntonOrnatskyi/goproxy/utils/id" 35 36 kcp "github.com/xtaci/kcp-go" 37 ) 38 39 func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) { 40 ioBind(dst, src, fn, log, true) 41 } 42 func IoBindNoClose(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) { 43 ioBind(dst, src, fn, log, false) 44 } 45 func ioBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger, close bool) { 46 go func() { 47 defer func() { 48 if err := recover(); err != nil { 49 log.Printf("bind crashed %s", err) 50 } 51 }() 52 e1 := make(chan interface{}, 1) 53 e2 := make(chan interface{}, 1) 54 go func() { 55 defer func() { 56 if err := recover(); err != nil { 57 log.Printf("bind crashed %s", err) 58 } 59 }() 60 //_, err := io.Copy(dst, src) 61 err := ioCopy(dst, src) 62 e1 <- err 63 }() 64 go func() { 65 defer func() { 66 if err := recover(); err != nil { 67 log.Printf("bind crashed %s", err) 68 } 69 }() 70 //_, err := io.Copy(src, dst) 71 err := ioCopy(src, dst) 72 e2 <- err 73 }() 74 var err interface{} 75 select { 76 case err = <-e1: 77 //log.Printf("e1") 78 case err = <-e2: 79 //log.Printf("e2") 80 } 81 func() { 82 defer func() { 83 _ = recover() 84 }() 85 if close { 86 src.Close() 87 } 88 }() 89 func() { 90 defer func() { 91 _ = recover() 92 }() 93 if close { 94 dst.Close() 95 } 96 }() 97 if fn != nil { 98 fn(err) 99 } 100 }() 101 } 102 func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) { 103 defer func() { 104 if e := recover(); e != nil { 105 } 106 }() 107 buf := LeakyBuffer.Get() 108 defer LeakyBuffer.Put(buf) 109 n := 0 110 for { 111 n, err = src.Read(buf) 112 if n > 0 { 113 if n > len(buf) { 114 n = len(buf) 115 } 116 if _, e := dst.Write(buf[0:n]); e != nil { 117 return e 118 } 119 } 120 if err != nil { 121 return 122 } 123 } 124 } 125 func SingleTlsConnectHost(host string, timeout int, caCertBytes []byte) (conn tls.Conn, err error) { 126 h := strings.Split(host, ":") 127 port, _ := strconv.Atoi(h[1]) 128 return SingleTlsConnect(h[0], port, timeout, caCertBytes) 129 } 130 func SingleTlsConnect(host string, port, timeout int, caCertBytes []byte) (conn tls.Conn, err error) { 131 conf, err := getRequestSingleTlsConfig(caCertBytes) 132 if err != nil { 133 return 134 } 135 _conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond) 136 if err != nil { 137 return 138 } 139 return *tls.Client(_conn, conf), err 140 } 141 func SingleTlsConfig(caCertBytes []byte) (conf *tls.Config, err error) { 142 return getRequestSingleTlsConfig(caCertBytes) 143 } 144 func getRequestSingleTlsConfig(caCertBytes []byte) (conf *tls.Config, err error) { 145 conf = &tls.Config{InsecureSkipVerify: true} 146 serverCertPool := x509.NewCertPool() 147 if caCertBytes != nil { 148 ok := serverCertPool.AppendCertsFromPEM(caCertBytes) 149 if !ok { 150 err = errors.New("failed to parse root certificate") 151 } 152 conf.RootCAs = serverCertPool 153 conf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 154 opts := x509.VerifyOptions{ 155 Roots: serverCertPool, 156 } 157 for _, rawCert := range rawCerts { 158 cert, _ := x509.ParseCertificate(rawCert) 159 _, err := cert.Verify(opts) 160 if err != nil { 161 return err 162 } 163 } 164 return nil 165 } 166 } 167 return 168 } 169 func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) { 170 h := strings.Split(host, ":") 171 port, _ := strconv.Atoi(h[1]) 172 return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes) 173 } 174 func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) { 175 conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes) 176 if err != nil { 177 return 178 } 179 _conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond) 180 if err != nil { 181 return 182 } 183 return *tls.Client(_conn, conf), err 184 } 185 func TlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) { 186 return getRequestTlsConfig(certBytes, keyBytes, caCertBytes) 187 } 188 func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) { 189 190 var cert tls.Certificate 191 cert, err = tls.X509KeyPair(certBytes, keyBytes) 192 if err != nil { 193 return 194 } 195 serverCertPool := x509.NewCertPool() 196 caBytes := certBytes 197 if caCertBytes != nil { 198 caBytes = caCertBytes 199 200 } 201 ok := serverCertPool.AppendCertsFromPEM(caBytes) 202 if !ok { 203 err = errors.New("failed to parse root certificate") 204 } 205 block, _ := pem.Decode(caBytes) 206 if block == nil { 207 panic("failed to parse certificate PEM") 208 } 209 x509Cert, _ := x509.ParseCertificate(block.Bytes) 210 if x509Cert == nil { 211 panic("failed to parse block") 212 } 213 conf = &tls.Config{ 214 RootCAs: serverCertPool, 215 Certificates: []tls.Certificate{cert}, 216 InsecureSkipVerify: true, 217 ServerName: x509Cert.Subject.CommonName, 218 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 219 opts := x509.VerifyOptions{ 220 Roots: serverCertPool, 221 } 222 for _, rawCert := range rawCerts { 223 cert, _ := x509.ParseCertificate(rawCert) 224 _, err := cert.Verify(opts) 225 if err != nil { 226 return err 227 } 228 } 229 return nil 230 }, 231 } 232 return 233 } 234 235 func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) { 236 conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond) 237 return 238 } 239 func ConnectKCPHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.Conn, err error) { 240 kcpconn, err := kcp.DialWithOptions(hostAndPort, config.Block, *config.DataShard, *config.ParityShard) 241 if err != nil { 242 return 243 } 244 kcpconn.SetStreamMode(true) 245 kcpconn.SetWriteDelay(true) 246 kcpconn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion) 247 kcpconn.SetMtu(*config.MTU) 248 kcpconn.SetWindowSize(*config.SndWnd, *config.RcvWnd) 249 kcpconn.SetACKNoDelay(*config.AckNodelay) 250 if *config.NoComp { 251 return kcpconn, err 252 } 253 return NewCompStream(kcpconn), err 254 } 255 256 func PathExists(_path string) bool { 257 _, err := os.Stat(_path) 258 if err != nil && os.IsNotExist(err) { 259 return false 260 } 261 return true 262 } 263 func HTTPGet(URL string, timeout int) (err error) { 264 tr := &http.Transport{} 265 var resp *http.Response 266 var client *http.Client 267 defer func() { 268 if resp != nil && resp.Body != nil { 269 resp.Body.Close() 270 } 271 tr.CloseIdleConnections() 272 }() 273 client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr} 274 resp, err = client.Get(URL) 275 if err != nil { 276 return 277 } 278 return 279 } 280 281 func CloseConn(conn *net.Conn) { 282 defer func() { 283 _ = recover() 284 }() 285 if conn != nil && *conn != nil { 286 (*conn).SetDeadline(time.Now().Add(time.Millisecond)) 287 (*conn).Close() 288 } 289 } 290 291 var allInterfaceAddrCache []net.IP 292 293 func GetAllInterfaceAddr() ([]net.IP, error) { 294 if allInterfaceAddrCache != nil { 295 return allInterfaceAddrCache, nil 296 } 297 ifaces, err := net.Interfaces() 298 if err != nil { 299 return nil, err 300 } 301 addresses := []net.IP{} 302 for _, iface := range ifaces { 303 304 if iface.Flags&net.FlagUp == 0 { 305 continue // interface down 306 } 307 // if iface.Flags&net.FlagLoopback != 0 { 308 // continue // loopback interface 309 // } 310 addrs, err := iface.Addrs() 311 if err != nil { 312 continue 313 } 314 315 for _, addr := range addrs { 316 var ip net.IP 317 switch v := addr.(type) { 318 case *net.IPNet: 319 ip = v.IP 320 case *net.IPAddr: 321 ip = v.IP 322 } 323 // if ip == nil || ip.IsLoopback() { 324 // continue 325 // } 326 ip = ip.To4() 327 if ip == nil { 328 continue // not an ipv4 address 329 } 330 addresses = append(addresses, ip) 331 } 332 } 333 if len(addresses) == 0 { 334 return nil, fmt.Errorf("no address Found, net.InterfaceAddrs: %v", addresses) 335 } 336 //only need first 337 allInterfaceAddrCache = addresses 338 return addresses, nil 339 } 340 func UDPPacket(srcAddr string, packet []byte) []byte { 341 addrBytes := []byte(srcAddr) 342 addrLength := uint16(len(addrBytes)) 343 bodyLength := uint16(len(packet)) 344 //log.Printf("build packet : addr len %d, body len %d", addrLength, bodyLength) 345 pkg := new(bytes.Buffer) 346 binary.Write(pkg, binary.LittleEndian, addrLength) 347 binary.Write(pkg, binary.LittleEndian, addrBytes) 348 binary.Write(pkg, binary.LittleEndian, bodyLength) 349 binary.Write(pkg, binary.LittleEndian, packet) 350 return pkg.Bytes() 351 } 352 func ReadUDPPacket(_reader io.Reader) (srcAddr string, packet []byte, err error) { 353 reader := bufio.NewReader(_reader) 354 var addrLength uint16 355 var bodyLength uint16 356 err = binary.Read(reader, binary.LittleEndian, &addrLength) 357 if err != nil { 358 return 359 } 360 _srcAddr := make([]byte, addrLength) 361 n, err := reader.Read(_srcAddr) 362 if err != nil { 363 return 364 } 365 if n != int(addrLength) { 366 err = fmt.Errorf("n != int(addrLength), %d,%d", n, addrLength) 367 return 368 } 369 srcAddr = string(_srcAddr) 370 371 err = binary.Read(reader, binary.LittleEndian, &bodyLength) 372 if err != nil { 373 374 return 375 } 376 packet = make([]byte, bodyLength) 377 n, err = reader.Read(packet) 378 if err != nil { 379 return 380 } 381 if n != int(bodyLength) { 382 err = fmt.Errorf("n != int(bodyLength), %d,%d", n, bodyLength) 383 return 384 } 385 return 386 } 387 func Uniqueid() string { 388 str := fmt.Sprintf("%d%s", time.Now().UnixNano(), xid.New().String()) 389 hash := sha1.New() 390 hash.Write([]byte(str)) 391 return hex.EncodeToString(hash.Sum(nil)) 392 } 393 func RandString(strlen int) string { 394 codes := "QWERTYUIOPLKJHGFDSAZXCVBNMabcdefghijklmnopqrstuvwxyz0123456789" 395 codeLen := len(codes) 396 data := make([]byte, strlen) 397 rand.Seed(time.Now().UnixNano() + rand.Int63() + rand.Int63() + rand.Int63() + rand.Int63()) 398 for i := 0; i < strlen; i++ { 399 idx := rand.Intn(codeLen) 400 data[i] = byte(codes[idx]) 401 } 402 return string(data) 403 } 404 func RandInt(strLen int) int64 { 405 codes := "123456789" 406 codeLen := len(codes) 407 data := make([]byte, strLen) 408 rand.Seed(time.Now().UnixNano() + rand.Int63() + rand.Int63() + rand.Int63() + rand.Int63()) 409 for i := 0; i < strLen; i++ { 410 idx := rand.Intn(codeLen) 411 data[i] = byte(codes[idx]) 412 } 413 i, _ := strconv.ParseInt(string(data), 10, 64) 414 return i 415 } 416 func ReadBytes(r io.Reader) (data []byte, err error) { 417 defer func() { 418 if e := recover(); e != nil { 419 err = fmt.Errorf("read bytes fail ,err : %s", e) 420 } 421 }() 422 var len uint64 423 err = binary.Read(r, binary.LittleEndian, &len) 424 if err != nil { 425 return 426 } 427 if len == 0 || len > ^uint64(0) { 428 err = fmt.Errorf("data len out of range, %d", len) 429 return 430 } 431 var n int 432 data = make([]byte, len) 433 n, err = r.Read(data) 434 if err != nil { 435 return 436 } 437 if n != int(len) { 438 err = fmt.Errorf("error data len") 439 return 440 } 441 return 442 } 443 func ReadData(r io.Reader) (data string, err error) { 444 _data, err := ReadBytes(r) 445 if err != nil { 446 return 447 } 448 data = string(_data) 449 return 450 } 451 452 //non typed packet with Bytes 453 func ReadPacketBytes(r io.Reader, data ...*[]byte) (err error) { 454 for _, d := range data { 455 *d, err = ReadBytes(r) 456 if err != nil { 457 return 458 } 459 } 460 return 461 } 462 func BuildPacketBytes(data ...[]byte) []byte { 463 pkg := new(bytes.Buffer) 464 for _, d := range data { 465 binary.Write(pkg, binary.LittleEndian, uint64(len(d))) 466 binary.Write(pkg, binary.LittleEndian, d) 467 } 468 return pkg.Bytes() 469 } 470 471 //non typed packet with string 472 func ReadPacketData(r io.Reader, data ...*string) (err error) { 473 for _, d := range data { 474 *d, err = ReadData(r) 475 if err != nil { 476 return 477 } 478 } 479 return 480 } 481 func BuildPacketData(data ...string) []byte { 482 pkg := new(bytes.Buffer) 483 for _, d := range data { 484 bytes := []byte(d) 485 binary.Write(pkg, binary.LittleEndian, uint64(len(bytes))) 486 binary.Write(pkg, binary.LittleEndian, bytes) 487 } 488 return pkg.Bytes() 489 } 490 491 //typed packet with bytes 492 func ReadBytesPacket(r io.Reader, packetType *uint8, data ...*[]byte) (err error) { 493 var connType uint8 494 err = binary.Read(r, binary.LittleEndian, &connType) 495 if err != nil { 496 return 497 } 498 *packetType = connType 499 for _, d := range data { 500 *d, err = ReadBytes(r) 501 if err != nil { 502 return 503 } 504 } 505 return 506 } 507 func BuildBytesPacket(packetType uint8, data ...[]byte) []byte { 508 pkg := new(bytes.Buffer) 509 binary.Write(pkg, binary.LittleEndian, packetType) 510 for _, d := range data { 511 binary.Write(pkg, binary.LittleEndian, uint64(len(d))) 512 binary.Write(pkg, binary.LittleEndian, d) 513 } 514 return pkg.Bytes() 515 } 516 517 //typed packet with string 518 func ReadPacket(r io.Reader, packetType *uint8, data ...*string) (err error) { 519 var connType uint8 520 err = binary.Read(r, binary.LittleEndian, &connType) 521 if err != nil { 522 return 523 } 524 *packetType = connType 525 for _, d := range data { 526 *d, err = ReadData(r) 527 if err != nil { 528 return 529 } 530 } 531 return 532 } 533 534 func BuildPacket(packetType uint8, data ...string) []byte { 535 pkg := new(bytes.Buffer) 536 binary.Write(pkg, binary.LittleEndian, packetType) 537 for _, d := range data { 538 bytes := []byte(d) 539 binary.Write(pkg, binary.LittleEndian, uint64(len(bytes))) 540 binary.Write(pkg, binary.LittleEndian, bytes) 541 } 542 return pkg.Bytes() 543 } 544 545 func SubStr(str string, start, end int) string { 546 if len(str) == 0 { 547 return "" 548 } 549 if end >= len(str) { 550 end = len(str) - 1 551 } 552 return str[start:end] 553 } 554 func SubBytes(bytes []byte, start, end int) []byte { 555 if len(bytes) == 0 { 556 return []byte{} 557 } 558 if end >= len(bytes) { 559 end = len(bytes) - 1 560 } 561 return bytes[start:end] 562 } 563 func TlsBytes(cert, key string) (certBytes, keyBytes []byte, err error) { 564 base64Prefix := "base64://" 565 if strings.HasPrefix(cert, base64Prefix) { 566 certBytes, err = base64.StdEncoding.DecodeString(cert[len(base64Prefix):]) 567 } else { 568 certBytes, err = ioutil.ReadFile(cert) 569 } 570 if err != nil { 571 err = fmt.Errorf("err : %s", err) 572 return 573 } 574 if strings.HasPrefix(key, base64Prefix) { 575 keyBytes, err = base64.StdEncoding.DecodeString(key[len(base64Prefix):]) 576 } else { 577 keyBytes, err = ioutil.ReadFile(key) 578 } 579 if err != nil { 580 err = fmt.Errorf("err : %s", err) 581 return 582 } 583 return 584 } 585 func GetKCPBlock(method, key string) (block kcp.BlockCrypt) { 586 pass := pbkdf2.Key([]byte(key), []byte(key), 4096, 32, sha1.New) 587 switch method { 588 case "sm4": 589 block, _ = kcp.NewSM4BlockCrypt(pass[:16]) 590 case "tea": 591 block, _ = kcp.NewTEABlockCrypt(pass[:16]) 592 case "xor": 593 block, _ = kcp.NewSimpleXORBlockCrypt(pass) 594 case "none": 595 block, _ = kcp.NewNoneBlockCrypt(pass) 596 case "aes-128": 597 block, _ = kcp.NewAESBlockCrypt(pass[:16]) 598 case "aes-192": 599 block, _ = kcp.NewAESBlockCrypt(pass[:24]) 600 case "blowfish": 601 block, _ = kcp.NewBlowfishBlockCrypt(pass) 602 case "twofish": 603 block, _ = kcp.NewTwofishBlockCrypt(pass) 604 case "cast5": 605 block, _ = kcp.NewCast5BlockCrypt(pass[:16]) 606 case "3des": 607 block, _ = kcp.NewTripleDESBlockCrypt(pass[:24]) 608 case "xtea": 609 block, _ = kcp.NewXTEABlockCrypt(pass[:16]) 610 case "salsa20": 611 block, _ = kcp.NewSalsa20BlockCrypt(pass) 612 default: 613 block, _ = kcp.NewAESBlockCrypt(pass) 614 } 615 return 616 } 617 func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, err error) { 618 var tr *http.Transport 619 var client *http.Client 620 conf := &tls.Config{ 621 InsecureSkipVerify: true, 622 } 623 if strings.Contains(URL, "https://") { 624 tr = &http.Transport{TLSClientConfig: conf} 625 client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr} 626 } else { 627 tr = &http.Transport{} 628 client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr} 629 } 630 defer tr.CloseIdleConnections() 631 632 //resp, err := client.Get(URL) 633 req, err := http.NewRequest("GET", URL, nil) 634 if err != nil { 635 return 636 } 637 if len(host) == 1 && host[0] != "" { 638 req.Host = host[0] 639 } 640 resp, err := client.Do(req) 641 if err != nil { 642 return 643 } 644 defer resp.Body.Close() 645 code = resp.StatusCode 646 body, err = ioutil.ReadAll(resp.Body) 647 return 648 } 649 func IsInternalIP(domainOrIP string, always bool) bool { 650 var outIPs []net.IP 651 var err error 652 var isDomain bool 653 if net.ParseIP(domainOrIP) == nil { 654 isDomain = true 655 } 656 if always && isDomain { 657 return false 658 } 659 660 if isDomain { 661 outIPs, err = LookupIP(domainOrIP) 662 } else { 663 outIPs = []net.IP{net.ParseIP(domainOrIP)} 664 } 665 666 if err != nil { 667 return false 668 } 669 670 for _, ip := range outIPs { 671 if ip.IsLoopback() { 672 return true 673 } 674 if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "10.0.0.0" { 675 return true 676 } 677 if ip.To4().Mask(net.IPv4Mask(255, 255, 0, 0)).String() == "192.168.0.0" { 678 return true 679 } 680 if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "172.0.0.0" { 681 i, _ := strconv.Atoi(strings.Split(ip.To4().String(), ".")[1]) 682 return i >= 16 && i <= 31 683 } 684 } 685 return false 686 } 687 func IsHTTP(head []byte) bool { 688 keys := []string{"GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"} 689 for _, key := range keys { 690 if bytes.HasPrefix(head, []byte(key)) || bytes.HasPrefix(head, []byte(strings.ToLower(key))) { 691 return true 692 } 693 } 694 return false 695 } 696 func IsSocks5(head []byte) bool { 697 if len(head) < 3 { 698 return false 699 } 700 if head[0] == uint8(0x05) && 0 < int(head[1]) && int(head[1]) < 255 { 701 if len(head) == 2+int(head[1]) { 702 return true 703 } 704 } 705 return false 706 } 707 func RemoveProxyHeaders(head []byte) []byte { 708 newLines := [][]byte{} 709 var keys = map[string]bool{} 710 lines := bytes.Split(head, []byte("\r\n")) 711 IsBody := false 712 i := -1 713 for _, line := range lines { 714 i++ 715 if len(line) == 0 || IsBody { 716 newLines = append(newLines, line) 717 IsBody = true 718 } else { 719 hline := bytes.SplitN(line, []byte(":"), 2) 720 if i == 0 && IsHTTP(head) { 721 newLines = append(newLines, line) 722 continue 723 } 724 if len(hline) != 2 { 725 continue 726 } 727 k := strings.ToUpper(string(hline[0])) 728 if _, ok := keys[k]; ok || strings.HasPrefix(k, "PROXY-") { 729 continue 730 } 731 keys[k] = true 732 newLines = append(newLines, line) 733 } 734 } 735 return bytes.Join(newLines, []byte("\r\n")) 736 } 737 func InsertProxyHeaders(head []byte, headers string) []byte { 738 return bytes.Replace(head, []byte("\r\n"), []byte("\r\n"+headers), 1) 739 } 740 func LBMethod(key string) int { 741 typs := map[string]int{"weight": lb.SELECT_WEITHT, "leasttime": lb.SELECT_LEASTTIME, "leastconn": lb.SELECT_LEASTCONN, "hash": lb.SELECT_HASH, "roundrobin": lb.SELECT_ROUNDROBIN} 742 return typs[key] 743 } 744 func UDPCopy(dst, src *net.UDPConn, dstAddr net.Addr, readTimeout time.Duration, beforeWriteFn func(data []byte) []byte, deferFn func(e interface{})) { 745 go func() { 746 defer func() { 747 deferFn(recover()) 748 }() 749 buf := LeakyBuffer.Get() 750 defer LeakyBuffer.Put(buf) 751 for { 752 if readTimeout > 0 { 753 src.SetReadDeadline(time.Now().Add(readTimeout)) 754 } 755 n, err := src.Read(buf) 756 if readTimeout > 0 { 757 src.SetReadDeadline(time.Time{}) 758 } 759 if err != nil { 760 if IsNetClosedErr(err) || IsNetTimeoutErr(err) || IsNetRefusedErr(err) { 761 return 762 } 763 continue 764 } 765 _, err = dst.WriteTo(beforeWriteFn(buf[:n]), dstAddr) 766 if err != nil { 767 if IsNetClosedErr(err) { 768 return 769 } 770 continue 771 } 772 } 773 }() 774 } 775 func IsNetClosedErr(err error) bool { 776 return err != nil && strings.Contains(err.Error(), "use of closed network connection") 777 } 778 func IsNetTimeoutErr(err error) bool { 779 if err == nil { 780 return false 781 } 782 e, ok := err.(net.Error) 783 return ok && e.Timeout() 784 } 785 func IsNetRefusedErr(err error) bool { 786 return err != nil && strings.Contains(err.Error(), "connection refused") 787 } 788 func IsNetDeadlineErr(err error) bool { 789 return err != nil && strings.Contains(err.Error(), "i/o deadline reached") 790 } 791 func IsNetSocketNotConnectedErr(err error) bool { 792 return err != nil && strings.Contains(err.Error(), "socket is not connected") 793 } 794 func NewDefaultLogger() *logger.Logger { 795 return logger.New(os.Stderr, "", logger.LstdFlags) 796 } 797 798 // type sockaddr struct { 799 // family uint16 800 // data [14]byte 801 // } 802 803 // const SO_ORIGINAL_DST = 80 804 805 // realServerAddress returns an intercepted connection's original destination. 806 // func realServerAddress(conn *net.Conn) (string, error) { 807 // tcpConn, ok := (*conn).(*net.TCPConn) 808 // if !ok { 809 // return "", errors.New("not a TCPConn") 810 // } 811 812 // file, err := tcpConn.File() 813 // if err != nil { 814 // return "", err 815 // } 816 817 // // To avoid potential problems from making the socket non-blocking. 818 // tcpConn.Close() 819 // *conn, err = net.FileConn(file) 820 // if err != nil { 821 // return "", err 822 // } 823 824 // defer file.Close() 825 // fd := file.Fd() 826 827 // var addr sockaddr 828 // size := uint32(unsafe.Sizeof(addr)) 829 // err = getsockopt(int(fd), syscall.SOL_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&addr)), &size) 830 // if err != nil { 831 // return "", err 832 // } 833 834 // var ip net.IP 835 // switch addr.family { 836 // case syscall.AF_INET: 837 // ip = addr.data[2:6] 838 // default: 839 // return "", errors.New("unrecognized address family") 840 // } 841 842 // port := int(addr.data[0])<<8 + int(addr.data[1]) 843 844 // return net.JoinHostPort(ip.String(), strconv.Itoa(port)), nil 845 // } 846 847 // func getsockopt(s int, level int, name int, val uintptr, vallen *uint32) (err error) { 848 // _, _, e1 := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0) 849 // if e1 != 0 { 850 // err = e1 851 // } 852 // return 853 // } 854 855 /* 856 net.LookupIP may cause deadlock in windows 857 https://github.com/golang/go/issues/24178 858 */ 859 860 func LookupIP(host string) ([]net.IP, error) { 861 862 ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(3)) 863 defer func() { 864 cancel() 865 //ctx.Done() 866 }() 867 addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) 868 if err != nil { 869 return nil, err 870 } 871 ips := make([]net.IP, len(addrs)) 872 for i, ia := range addrs { 873 ips[i] = ia.IP 874 } 875 return ips, nil 876 }