github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/udp.go (about) 1 package dns 2 3 import ( 4 "context" 5 "encoding/binary" 6 "fmt" 7 "math" 8 "math/rand/v2" 9 "net" 10 "sync" 11 12 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 13 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 14 pdns "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns" 15 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 16 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 17 "github.com/Asutorufa/yuhaiin/pkg/utils/singleflight" 18 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 19 ) 20 21 func init() { 22 Register(pdns.Type_udp, NewDoU) 23 } 24 25 type udp struct { 26 *client 27 bufChanMap syncmap.SyncMap[[2]byte, *bufChan] 28 sf singleflight.Group[uint64, net.PacketConn] 29 packetConn net.PacketConn 30 mu sync.RWMutex 31 } 32 33 func (u *udp) Close() error { 34 if u.packetConn != nil { 35 u.packetConn.Close() 36 u.packetConn = nil 37 } 38 return nil 39 } 40 41 func (u *udp) handleResponse(packet net.PacketConn) { 42 defer func() { 43 u.mu.Lock() 44 u.packetConn = nil 45 u.mu.Unlock() 46 47 packet.Close() 48 }() 49 50 for { 51 buf := pool.GetBytesBuffer(nat.MaxSegmentSize) 52 n, _, err := buf.ReadFromPacket(packet) 53 if err != nil { 54 buf.Free() 55 return 56 } 57 58 if n < 2 { 59 buf.Free() 60 continue 61 } 62 63 c, ok := u.bufChanMap.Load([2]byte(buf.Bytes()[:2])) 64 if !ok || c == nil { 65 buf.Free() 66 continue 67 } 68 69 c.Send(buf) 70 } 71 } 72 73 func (u *udp) initPacketConn(ctx context.Context) (net.PacketConn, error) { 74 if u.packetConn != nil { 75 return u.packetConn, nil 76 } 77 78 conn, err, _ := u.sf.Do(0, func() (net.PacketConn, error) { 79 if u.packetConn != nil { 80 _ = u.packetConn.Close() 81 } 82 83 addr, err := ParseAddr(statistic.Type_udp, u.config.Host, "53") 84 if err != nil { 85 return nil, fmt.Errorf("parse addr failed: %w", err) 86 } 87 88 conn, err := u.config.Dialer.PacketConn(ctx, addr) 89 if err != nil { 90 return nil, fmt.Errorf("get packetConn failed: %w", err) 91 } 92 93 u.mu.Lock() 94 u.packetConn = conn 95 u.mu.Unlock() 96 97 go u.handleResponse(conn) 98 return conn, nil 99 }) 100 101 return conn, err 102 } 103 104 type bufChan struct { 105 ctx context.Context 106 bufChan chan *pool.Bytes 107 } 108 109 func (b *bufChan) Send(buf *pool.Bytes) { 110 select { 111 case b.bufChan <- buf: 112 case <-b.ctx.Done(): 113 buf.Free() 114 } 115 } 116 117 func NewDoU(config Config) (netapi.Resolver, error) { 118 addr, err := ParseAddr(statistic.Type_udp, config.Host, "53") 119 if err != nil { 120 return nil, fmt.Errorf("parse addr failed: %w", err) 121 } 122 123 udp := &udp{} 124 125 udp.client = NewClient(config, func(ctx context.Context, req []byte) (*pool.Bytes, error) { 126 127 packetConn, err := udp.initPacketConn(ctx) 128 if err != nil { 129 return nil, err 130 } 131 id := [2]byte{req[0], req[1]} 132 133 ctx, cancel := context.WithCancel(ctx) 134 defer cancel() 135 136 bchan := &bufChan{bufChan: make(chan *pool.Bytes), ctx: ctx} 137 138 _retry: 139 _, ok := udp.bufChanMap.LoadOrStore([2]byte(req[:2]), bchan) 140 if ok { 141 binary.BigEndian.PutUint16(req[0:2], uint16(rand.UintN(math.MaxUint16))) 142 goto _retry 143 } 144 defer udp.bufChanMap.Delete([2]byte(req[:2])) 145 146 udpAddr := addr.UDPAddr(ctx) 147 if udpAddr.Err != nil { 148 return nil, udpAddr.Err 149 } 150 151 _, err = packetConn.WriteTo(req, udpAddr.V) 152 if err != nil { 153 _ = packetConn.Close() 154 return nil, err 155 } 156 157 select { 158 case <-ctx.Done(): 159 return nil, ctx.Err() 160 case data := <-bchan.bufChan: 161 data.Bytes()[0] = id[0] 162 data.Bytes()[1] = id[1] 163 return data, nil 164 } 165 }) 166 167 return udp, nil 168 }