github.com/sagernet/sing-box@v1.9.0-rc.20/outbound/dns.go (about) 1 package outbound 2 3 import ( 4 "context" 5 "encoding/binary" 6 "net" 7 "os" 8 9 "github.com/sagernet/sing-box/adapter" 10 C "github.com/sagernet/sing-box/constant" 11 "github.com/sagernet/sing-dns" 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/buf" 14 "github.com/sagernet/sing/common/bufio" 15 "github.com/sagernet/sing/common/canceler" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/task" 19 20 mDNS "github.com/miekg/dns" 21 ) 22 23 var _ adapter.Outbound = (*DNS)(nil) 24 25 type DNS struct { 26 myOutboundAdapter 27 } 28 29 func NewDNS(router adapter.Router, tag string) *DNS { 30 return &DNS{ 31 myOutboundAdapter{ 32 protocol: C.TypeDNS, 33 network: []string{N.NetworkTCP, N.NetworkUDP}, 34 router: router, 35 tag: tag, 36 }, 37 } 38 } 39 40 func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 41 return nil, os.ErrInvalid 42 } 43 44 func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 45 return nil, os.ErrInvalid 46 } 47 48 func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 49 metadata.Destination = M.Socksaddr{} 50 defer conn.Close() 51 for { 52 err := d.handleConnection(ctx, conn, metadata) 53 if err != nil { 54 return err 55 } 56 } 57 } 58 59 func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 60 var queryLength uint16 61 err := binary.Read(conn, binary.BigEndian, &queryLength) 62 if err != nil { 63 return err 64 } 65 if queryLength == 0 { 66 return dns.RCodeFormatError 67 } 68 buffer := buf.NewSize(int(queryLength)) 69 defer buffer.Release() 70 _, err = buffer.ReadFullFrom(conn, int(queryLength)) 71 if err != nil { 72 return err 73 } 74 var message mDNS.Msg 75 err = message.Unpack(buffer.Bytes()) 76 if err != nil { 77 return err 78 } 79 metadataInQuery := metadata 80 go func() error { 81 response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) 82 if err != nil { 83 return err 84 } 85 responseBuffer := buf.NewPacket() 86 defer responseBuffer.Release() 87 responseBuffer.Resize(2, 0) 88 n, err := response.PackBuffer(responseBuffer.FreeBytes()) 89 if err != nil { 90 return err 91 } 92 responseBuffer.Truncate(len(n)) 93 binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n))) 94 _, err = conn.Write(responseBuffer.Bytes()) 95 return err 96 }() 97 return nil 98 } 99 100 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 101 metadata.Destination = M.Socksaddr{} 102 var reader N.PacketReader = conn 103 var counters []N.CountFunc 104 var cachedPackets []*N.PacketBuffer 105 for { 106 reader, counters = N.UnwrapCountPacketReader(reader, counters) 107 if cachedReader, isCached := reader.(N.CachedPacketReader); isCached { 108 packet := cachedReader.ReadCachedPacket() 109 if packet != nil { 110 cachedPackets = append(cachedPackets, packet) 111 continue 112 } 113 } 114 if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { 115 readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) 116 return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) 117 } 118 break 119 } 120 fastClose, cancel := common.ContextWithCancelCause(ctx) 121 timeout := canceler.New(fastClose, cancel, C.DNSTimeout) 122 var group task.Group 123 group.Append0(func(ctx context.Context) error { 124 for { 125 var message mDNS.Msg 126 var destination M.Socksaddr 127 var err error 128 if len(cachedPackets) > 0 { 129 packet := cachedPackets[0] 130 cachedPackets = cachedPackets[1:] 131 for _, counter := range counters { 132 counter(int64(packet.Buffer.Len())) 133 } 134 err = message.Unpack(packet.Buffer.Bytes()) 135 packet.Buffer.Release() 136 if err != nil { 137 cancel(err) 138 return err 139 } 140 destination = packet.Destination 141 } else { 142 buffer := buf.NewPacket() 143 destination, err = conn.ReadPacket(buffer) 144 if err != nil { 145 buffer.Release() 146 cancel(err) 147 return err 148 } 149 for _, counter := range counters { 150 counter(int64(buffer.Len())) 151 } 152 err = message.Unpack(buffer.Bytes()) 153 buffer.Release() 154 if err != nil { 155 cancel(err) 156 return err 157 } 158 timeout.Update() 159 } 160 metadataInQuery := metadata 161 go func() error { 162 response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) 163 if err != nil { 164 cancel(err) 165 return err 166 } 167 timeout.Update() 168 responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024) 169 if err != nil { 170 cancel(err) 171 return err 172 } 173 err = conn.WritePacket(responseBuffer, destination) 174 if err != nil { 175 cancel(err) 176 } 177 return err 178 }() 179 } 180 }) 181 group.Cleanup(func() { 182 conn.Close() 183 }) 184 return group.Run(fastClose) 185 } 186 187 func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { 188 ctx = adapter.WithContext(ctx, &metadata) 189 fastClose, cancel := common.ContextWithCancelCause(ctx) 190 timeout := canceler.New(fastClose, cancel, C.DNSTimeout) 191 var group task.Group 192 group.Append0(func(ctx context.Context) error { 193 for { 194 var ( 195 message mDNS.Msg 196 destination M.Socksaddr 197 err error 198 buffer *buf.Buffer 199 ) 200 if len(cached) > 0 { 201 packet := cached[0] 202 cached = cached[1:] 203 for _, counter := range readCounters { 204 counter(int64(packet.Buffer.Len())) 205 } 206 err = message.Unpack(packet.Buffer.Bytes()) 207 packet.Buffer.Release() 208 if err != nil { 209 cancel(err) 210 return err 211 } 212 destination = packet.Destination 213 } else { 214 buffer, destination, err = readWaiter.WaitReadPacket() 215 if err != nil { 216 cancel(err) 217 return err 218 } 219 for _, counter := range readCounters { 220 counter(int64(buffer.Len())) 221 } 222 err = message.Unpack(buffer.Bytes()) 223 buffer.Release() 224 if err != nil { 225 cancel(err) 226 return err 227 } 228 timeout.Update() 229 } 230 metadataInQuery := metadata 231 go func() error { 232 response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) 233 if err != nil { 234 cancel(err) 235 return err 236 } 237 timeout.Update() 238 responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024) 239 if err != nil { 240 cancel(err) 241 return err 242 } 243 err = conn.WritePacket(responseBuffer, destination) 244 if err != nil { 245 cancel(err) 246 } 247 return err 248 }() 249 } 250 }) 251 group.Cleanup(func() { 252 conn.Close() 253 }) 254 return group.Run(fastClose) 255 }