github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/outbound/dns.go (about) 1 package outbound 2 3 import ( 4 "context" 5 "encoding/binary" 6 "net" 7 "os" 8 9 "github.com/inazumav/sing-box/adapter" 10 C "github.com/inazumav/sing-box/constant" 11 "github.com/sagernet/sing-dns" 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/buf" 14 "github.com/sagernet/sing/common/bufio" 15 "github.com/sagernet/sing/common/canceler" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/task" 19 20 mDNS "github.com/miekg/dns" 21 ) 22 23 var _ adapter.Outbound = (*DNS)(nil) 24 25 type DNS struct { 26 myOutboundAdapter 27 } 28 29 func NewDNS(router adapter.Router, tag string) *DNS { 30 return &DNS{ 31 myOutboundAdapter{ 32 protocol: C.TypeDNS, 33 network: []string{N.NetworkTCP, N.NetworkUDP}, 34 router: router, 35 tag: tag, 36 }, 37 } 38 } 39 40 func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 41 return nil, os.ErrInvalid 42 } 43 44 func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 45 return nil, os.ErrInvalid 46 } 47 48 func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 49 defer conn.Close() 50 ctx = adapter.WithContext(ctx, &metadata) 51 for { 52 err := d.handleConnection(ctx, conn, metadata) 53 if err != nil { 54 return err 55 } 56 } 57 } 58 59 func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 60 var queryLength uint16 61 err := binary.Read(conn, binary.BigEndian, &queryLength) 62 if err != nil { 63 return err 64 } 65 if queryLength == 0 { 66 return dns.RCodeFormatError 67 } 68 buffer := buf.NewSize(int(queryLength)) 69 defer buffer.Release() 70 _, err = buffer.ReadFullFrom(conn, int(queryLength)) 71 if err != nil { 72 return err 73 } 74 var message mDNS.Msg 75 err = message.Unpack(buffer.Bytes()) 76 if err != nil { 77 return err 78 } 79 metadataInQuery := metadata 80 go func() error { 81 response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) 82 if err != nil { 83 return err 84 } 85 responseBuffer := buf.NewPacket() 86 defer responseBuffer.Release() 87 responseBuffer.Resize(2, 0) 88 n, err := response.PackBuffer(responseBuffer.FreeBytes()) 89 if err != nil { 90 return err 91 } 92 responseBuffer.Truncate(len(n)) 93 binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n))) 94 _, err = conn.Write(responseBuffer.Bytes()) 95 return err 96 }() 97 return nil 98 } 99 100 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 101 var reader N.PacketReader = conn 102 var counters []N.CountFunc 103 var cachedPackets []*N.PacketBuffer 104 for { 105 reader, counters = N.UnwrapCountPacketReader(reader, counters) 106 if cachedReader, isCached := reader.(N.CachedPacketReader); isCached { 107 packet := cachedReader.ReadCachedPacket() 108 if packet != nil { 109 cachedPackets = append(cachedPackets, packet) 110 continue 111 } 112 } 113 if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { 114 return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) 115 } 116 break 117 } 118 ctx = adapter.WithContext(ctx, &metadata) 119 fastClose, cancel := common.ContextWithCancelCause(ctx) 120 timeout := canceler.New(fastClose, cancel, C.DNSTimeout) 121 var group task.Group 122 group.Append0(func(ctx context.Context) error { 123 for { 124 var message mDNS.Msg 125 var destination M.Socksaddr 126 var err error 127 if len(cachedPackets) > 0 { 128 packet := cachedPackets[0] 129 cachedPackets = cachedPackets[1:] 130 for _, counter := range counters { 131 counter(int64(packet.Buffer.Len())) 132 } 133 err = message.Unpack(packet.Buffer.Bytes()) 134 packet.Buffer.Release() 135 if err != nil { 136 cancel(err) 137 return err 138 } 139 destination = packet.Destination 140 } else { 141 buffer := buf.NewPacket() 142 destination, err = conn.ReadPacket(buffer) 143 if err != nil { 144 buffer.Release() 145 cancel(err) 146 return err 147 } 148 for _, counter := range counters { 149 counter(int64(buffer.Len())) 150 } 151 err = message.Unpack(buffer.Bytes()) 152 buffer.Release() 153 if err != nil { 154 cancel(err) 155 return err 156 } 157 timeout.Update() 158 } 159 metadataInQuery := metadata 160 go func() error { 161 response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) 162 if err != nil { 163 cancel(err) 164 return err 165 } 166 timeout.Update() 167 responseBuffer := buf.NewPacket() 168 n, err := response.PackBuffer(responseBuffer.FreeBytes()) 169 if err != nil { 170 cancel(err) 171 responseBuffer.Release() 172 return err 173 } 174 responseBuffer.Truncate(len(n)) 175 err = conn.WritePacket(responseBuffer, destination) 176 if err != nil { 177 cancel(err) 178 } 179 return err 180 }() 181 } 182 }) 183 group.Cleanup(func() { 184 conn.Close() 185 }) 186 return group.Run(fastClose) 187 } 188 189 func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { 190 ctx = adapter.WithContext(ctx, &metadata) 191 fastClose, cancel := common.ContextWithCancelCause(ctx) 192 timeout := canceler.New(fastClose, cancel, C.DNSTimeout) 193 var group task.Group 194 group.Append0(func(ctx context.Context) error { 195 var buffer *buf.Buffer 196 readWaiter.InitializeReadWaiter(func() *buf.Buffer { 197 buffer = buf.NewSize(dns.FixedPacketSize) 198 buffer.FullReset() 199 return buffer 200 }) 201 defer readWaiter.InitializeReadWaiter(nil) 202 for { 203 var message mDNS.Msg 204 var destination M.Socksaddr 205 var err error 206 if len(cached) > 0 { 207 packet := cached[0] 208 cached = cached[1:] 209 for _, counter := range readCounters { 210 counter(int64(packet.Buffer.Len())) 211 } 212 err = message.Unpack(packet.Buffer.Bytes()) 213 packet.Buffer.Release() 214 if err != nil { 215 cancel(err) 216 return err 217 } 218 destination = packet.Destination 219 } else { 220 destination, err = readWaiter.WaitReadPacket() 221 if err != nil { 222 buffer.Release() 223 cancel(err) 224 return err 225 } 226 for _, counter := range readCounters { 227 counter(int64(buffer.Len())) 228 } 229 err = message.Unpack(buffer.Bytes()) 230 buffer.Release() 231 if err != nil { 232 cancel(err) 233 return err 234 } 235 timeout.Update() 236 } 237 metadataInQuery := metadata 238 go func() error { 239 response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) 240 if err != nil { 241 cancel(err) 242 return err 243 } 244 timeout.Update() 245 responseBuffer := buf.NewPacket() 246 n, err := response.PackBuffer(responseBuffer.FreeBytes()) 247 if err != nil { 248 cancel(err) 249 responseBuffer.Release() 250 return err 251 } 252 responseBuffer.Truncate(len(n)) 253 err = conn.WritePacket(responseBuffer, destination) 254 if err != nil { 255 cancel(err) 256 } 257 return err 258 }() 259 } 260 }) 261 group.Cleanup(func() { 262 conn.Close() 263 }) 264 return group.Run(fastClose) 265 }