github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/udp.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"math"
     8  	"math/rand/v2"
     9  	"net"
    10  	"sync"
    11  
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	pdns "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/singleflight"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    19  )
    20  
    21  func init() {
    22  	Register(pdns.Type_udp, NewDoU)
    23  }
    24  
    25  type udp struct {
    26  	*client
    27  	bufChanMap syncmap.SyncMap[[2]byte, *bufChan]
    28  	sf         singleflight.Group[uint64, net.PacketConn]
    29  	packetConn net.PacketConn
    30  	mu         sync.RWMutex
    31  }
    32  
    33  func (u *udp) Close() error {
    34  	if u.packetConn != nil {
    35  		u.packetConn.Close()
    36  		u.packetConn = nil
    37  	}
    38  	return nil
    39  }
    40  
    41  func (u *udp) handleResponse(packet net.PacketConn) {
    42  	defer func() {
    43  		u.mu.Lock()
    44  		u.packetConn = nil
    45  		u.mu.Unlock()
    46  
    47  		packet.Close()
    48  	}()
    49  
    50  	for {
    51  		buf := pool.GetBytesBuffer(nat.MaxSegmentSize)
    52  		n, _, err := buf.ReadFromPacket(packet)
    53  		if err != nil {
    54  			buf.Free()
    55  			return
    56  		}
    57  
    58  		if n < 2 {
    59  			buf.Free()
    60  			continue
    61  		}
    62  
    63  		c, ok := u.bufChanMap.Load([2]byte(buf.Bytes()[:2]))
    64  		if !ok || c == nil {
    65  			buf.Free()
    66  			continue
    67  		}
    68  
    69  		c.Send(buf)
    70  	}
    71  }
    72  
    73  func (u *udp) initPacketConn(ctx context.Context) (net.PacketConn, error) {
    74  	if u.packetConn != nil {
    75  		return u.packetConn, nil
    76  	}
    77  
    78  	conn, err, _ := u.sf.Do(0, func() (net.PacketConn, error) {
    79  		if u.packetConn != nil {
    80  			_ = u.packetConn.Close()
    81  		}
    82  
    83  		addr, err := ParseAddr(statistic.Type_udp, u.config.Host, "53")
    84  		if err != nil {
    85  			return nil, fmt.Errorf("parse addr failed: %w", err)
    86  		}
    87  
    88  		conn, err := u.config.Dialer.PacketConn(ctx, addr)
    89  		if err != nil {
    90  			return nil, fmt.Errorf("get packetConn failed: %w", err)
    91  		}
    92  
    93  		u.mu.Lock()
    94  		u.packetConn = conn
    95  		u.mu.Unlock()
    96  
    97  		go u.handleResponse(conn)
    98  		return conn, nil
    99  	})
   100  
   101  	return conn, err
   102  }
   103  
   104  type bufChan struct {
   105  	ctx     context.Context
   106  	bufChan chan *pool.Bytes
   107  }
   108  
   109  func (b *bufChan) Send(buf *pool.Bytes) {
   110  	select {
   111  	case b.bufChan <- buf:
   112  	case <-b.ctx.Done():
   113  		buf.Free()
   114  	}
   115  }
   116  
   117  func NewDoU(config Config) (netapi.Resolver, error) {
   118  	addr, err := ParseAddr(statistic.Type_udp, config.Host, "53")
   119  	if err != nil {
   120  		return nil, fmt.Errorf("parse addr failed: %w", err)
   121  	}
   122  
   123  	udp := &udp{}
   124  
   125  	udp.client = NewClient(config, func(ctx context.Context, req []byte) (*pool.Bytes, error) {
   126  
   127  		packetConn, err := udp.initPacketConn(ctx)
   128  		if err != nil {
   129  			return nil, err
   130  		}
   131  		id := [2]byte{req[0], req[1]}
   132  
   133  		ctx, cancel := context.WithCancel(ctx)
   134  		defer cancel()
   135  
   136  		bchan := &bufChan{bufChan: make(chan *pool.Bytes), ctx: ctx}
   137  
   138  	_retry:
   139  		_, ok := udp.bufChanMap.LoadOrStore([2]byte(req[:2]), bchan)
   140  		if ok {
   141  			binary.BigEndian.PutUint16(req[0:2], uint16(rand.UintN(math.MaxUint16)))
   142  			goto _retry
   143  		}
   144  		defer udp.bufChanMap.Delete([2]byte(req[:2]))
   145  
   146  		udpAddr := addr.UDPAddr(ctx)
   147  		if udpAddr.Err != nil {
   148  			return nil, udpAddr.Err
   149  		}
   150  
   151  		_, err = packetConn.WriteTo(req, udpAddr.V)
   152  		if err != nil {
   153  			_ = packetConn.Close()
   154  			return nil, err
   155  		}
   156  
   157  		select {
   158  		case <-ctx.Done():
   159  			return nil, ctx.Err()
   160  		case data := <-bchan.bufChan:
   161  			data.Bytes()[0] = id[0]
   162  			data.Bytes()[1] = id[1]
   163  			return data, nil
   164  		}
   165  	})
   166  
   167  	return udp, nil
   168  }