github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/tcp.go (about) 1 package dns 2 3 import ( 4 "context" 5 "crypto/tls" 6 "encoding/binary" 7 "fmt" 8 "io" 9 "net" 10 "net/netip" 11 "strings" 12 13 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 14 pdns "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns" 15 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 16 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 17 ) 18 19 func init() { 20 Register(pdns.Type_tcp, NewTCP) 21 } 22 23 func NewTCP(config Config) (netapi.Resolver, error) { 24 return newTCP(config, "53", nil) 25 } 26 27 // ParseAddr 28 // host eg: cloudflare-dns.com, https://cloudflare-dns.com, 1.1.1.1:853 29 func ParseAddr(netType statistic.Type, host, defaultPort string) (netapi.Address, error) { 30 if i := strings.Index(host, "://"); i != -1 { 31 host = host[i+3:] 32 } 33 34 if i := strings.IndexByte(host, '/'); i != -1 { 35 host = host[:i] 36 } 37 38 _, _, err := net.SplitHostPort(host) 39 if err != nil { 40 e, ok := err.(*net.AddrError) 41 if !ok || !strings.Contains(e.Err, "missing port in address") { 42 if ok && strings.Contains(e.Err, "too many colons in address") { 43 if _, er := netip.ParseAddr(host); er != nil { 44 return nil, fmt.Errorf("split host port failed: %w", err) 45 } 46 } 47 } 48 49 host = net.JoinHostPort(host, defaultPort) 50 } 51 52 addr, err := netapi.ParseAddress(netType, host) 53 if err != nil { 54 return nil, fmt.Errorf("parse address failed: %w", err) 55 } 56 57 return addr, nil 58 } 59 60 func newTCP(config Config, defaultPort string, tlsConfig *tls.Config) (*client, error) { 61 addr, err := ParseAddr(statistic.Type_tcp, config.Host, defaultPort) 62 if err != nil { 63 return nil, fmt.Errorf("parse addr failed: %w", err) 64 } 65 66 return NewClient(config, 67 func(ctx context.Context, b []byte) (*pool.Bytes, error) { 68 conn, err := config.Dialer.Conn(ctx, addr) 69 if err != nil { 70 return nil, fmt.Errorf("tcp dial failed: %w", err) 71 } 72 defer conn.Close() 73 74 if tlsConfig != nil { 75 conn = tls.Client(conn, tlsConfig) 76 } 77 78 // dns over tcp, prefix two bytes is request data's length 79 err = binary.Write(conn, binary.BigEndian, uint16(len(b))) 80 if err != nil { 81 return nil, fmt.Errorf("write data length failed: %w", err) 82 } 83 84 _, err = conn.Write(b) 85 if err != nil { 86 return nil, fmt.Errorf("write data failed: %w", err) 87 } 88 89 var length uint16 90 err = binary.Read(conn, binary.BigEndian, &length) 91 if err != nil { 92 return nil, fmt.Errorf("read data length from server failed: %w", err) 93 } 94 95 all := pool.GetBytesBuffer(int(length)) 96 _, err = io.ReadFull(conn, all.Bytes()) 97 if err != nil { 98 return nil, fmt.Errorf("read data from server failed: %w", err) 99 } 100 return all, err 101 }), nil 102 }