github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/tcp.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/netip"
    11  	"strings"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	pdns "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    17  )
    18  
    19  func init() {
    20  	Register(pdns.Type_tcp, NewTCP)
    21  }
    22  
    23  func NewTCP(config Config) (netapi.Resolver, error) {
    24  	return newTCP(config, "53", nil)
    25  }
    26  
    27  // ParseAddr
    28  // host eg: cloudflare-dns.com, https://cloudflare-dns.com, 1.1.1.1:853
    29  func ParseAddr(netType statistic.Type, host, defaultPort string) (netapi.Address, error) {
    30  	if i := strings.Index(host, "://"); i != -1 {
    31  		host = host[i+3:]
    32  	}
    33  
    34  	if i := strings.IndexByte(host, '/'); i != -1 {
    35  		host = host[:i]
    36  	}
    37  
    38  	_, _, err := net.SplitHostPort(host)
    39  	if err != nil {
    40  		e, ok := err.(*net.AddrError)
    41  		if !ok || !strings.Contains(e.Err, "missing port in address") {
    42  			if ok && strings.Contains(e.Err, "too many colons in address") {
    43  				if _, er := netip.ParseAddr(host); er != nil {
    44  					return nil, fmt.Errorf("split host port failed: %w", err)
    45  				}
    46  			}
    47  		}
    48  
    49  		host = net.JoinHostPort(host, defaultPort)
    50  	}
    51  
    52  	addr, err := netapi.ParseAddress(netType, host)
    53  	if err != nil {
    54  		return nil, fmt.Errorf("parse address failed: %w", err)
    55  	}
    56  
    57  	return addr, nil
    58  }
    59  
    60  func newTCP(config Config, defaultPort string, tlsConfig *tls.Config) (*client, error) {
    61  	addr, err := ParseAddr(statistic.Type_tcp, config.Host, defaultPort)
    62  	if err != nil {
    63  		return nil, fmt.Errorf("parse addr failed: %w", err)
    64  	}
    65  
    66  	return NewClient(config,
    67  		func(ctx context.Context, b []byte) (*pool.Bytes, error) {
    68  			conn, err := config.Dialer.Conn(ctx, addr)
    69  			if err != nil {
    70  				return nil, fmt.Errorf("tcp dial failed: %w", err)
    71  			}
    72  			defer conn.Close()
    73  
    74  			if tlsConfig != nil {
    75  				conn = tls.Client(conn, tlsConfig)
    76  			}
    77  
    78  			// dns over tcp, prefix two bytes is request data's length
    79  			err = binary.Write(conn, binary.BigEndian, uint16(len(b)))
    80  			if err != nil {
    81  				return nil, fmt.Errorf("write data length failed: %w", err)
    82  			}
    83  
    84  			_, err = conn.Write(b)
    85  			if err != nil {
    86  				return nil, fmt.Errorf("write data failed: %w", err)
    87  			}
    88  
    89  			var length uint16
    90  			err = binary.Read(conn, binary.BigEndian, &length)
    91  			if err != nil {
    92  				return nil, fmt.Errorf("read data length from server failed: %w", err)
    93  			}
    94  
    95  			all := pool.GetBytesBuffer(int(length))
    96  			_, err = io.ReadFull(conn, all.Bytes())
    97  			if err != nil {
    98  				return nil, fmt.Errorf("read data from server failed: %w", err)
    99  			}
   100  			return all, err
   101  		}), nil
   102  }