github.com/sagernet/sing-box@v1.9.0-rc.20/outbound/dns.go (about)

     1  package outbound
     2  
import (
	"context"
	"encoding/binary"
	"math"
	"net"
	"os"
	"sync"

	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-dns"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	"github.com/sagernet/sing/common/canceler"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/task"

	mDNS "github.com/miekg/dns"
)
    22  
    23  var _ adapter.Outbound = (*DNS)(nil)
    24  
    25  type DNS struct {
    26  	myOutboundAdapter
    27  }
    28  
    29  func NewDNS(router adapter.Router, tag string) *DNS {
    30  	return &DNS{
    31  		myOutboundAdapter{
    32  			protocol: C.TypeDNS,
    33  			network:  []string{N.NetworkTCP, N.NetworkUDP},
    34  			router:   router,
    35  			tag:      tag,
    36  		},
    37  	}
    38  }
    39  
    40  func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    41  	return nil, os.ErrInvalid
    42  }
    43  
    44  func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
    45  	return nil, os.ErrInvalid
    46  }
    47  
    48  func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
    49  	metadata.Destination = M.Socksaddr{}
    50  	defer conn.Close()
    51  	for {
    52  		err := d.handleConnection(ctx, conn, metadata)
    53  		if err != nil {
    54  			return err
    55  		}
    56  	}
    57  }
    58  
    59  func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
    60  	var queryLength uint16
    61  	err := binary.Read(conn, binary.BigEndian, &queryLength)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	if queryLength == 0 {
    66  		return dns.RCodeFormatError
    67  	}
    68  	buffer := buf.NewSize(int(queryLength))
    69  	defer buffer.Release()
    70  	_, err = buffer.ReadFullFrom(conn, int(queryLength))
    71  	if err != nil {
    72  		return err
    73  	}
    74  	var message mDNS.Msg
    75  	err = message.Unpack(buffer.Bytes())
    76  	if err != nil {
    77  		return err
    78  	}
    79  	metadataInQuery := metadata
    80  	go func() error {
    81  		response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
    82  		if err != nil {
    83  			return err
    84  		}
    85  		responseBuffer := buf.NewPacket()
    86  		defer responseBuffer.Release()
    87  		responseBuffer.Resize(2, 0)
    88  		n, err := response.PackBuffer(responseBuffer.FreeBytes())
    89  		if err != nil {
    90  			return err
    91  		}
    92  		responseBuffer.Truncate(len(n))
    93  		binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
    94  		_, err = conn.Write(responseBuffer.Bytes())
    95  		return err
    96  	}()
    97  	return nil
    98  }
    99  
   100  func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   101  	metadata.Destination = M.Socksaddr{}
   102  	var reader N.PacketReader = conn
   103  	var counters []N.CountFunc
   104  	var cachedPackets []*N.PacketBuffer
   105  	for {
   106  		reader, counters = N.UnwrapCountPacketReader(reader, counters)
   107  		if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
   108  			packet := cachedReader.ReadCachedPacket()
   109  			if packet != nil {
   110  				cachedPackets = append(cachedPackets, packet)
   111  				continue
   112  			}
   113  		}
   114  		if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
   115  			readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
   116  			return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
   117  		}
   118  		break
   119  	}
   120  	fastClose, cancel := common.ContextWithCancelCause(ctx)
   121  	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
   122  	var group task.Group
   123  	group.Append0(func(ctx context.Context) error {
   124  		for {
   125  			var message mDNS.Msg
   126  			var destination M.Socksaddr
   127  			var err error
   128  			if len(cachedPackets) > 0 {
   129  				packet := cachedPackets[0]
   130  				cachedPackets = cachedPackets[1:]
   131  				for _, counter := range counters {
   132  					counter(int64(packet.Buffer.Len()))
   133  				}
   134  				err = message.Unpack(packet.Buffer.Bytes())
   135  				packet.Buffer.Release()
   136  				if err != nil {
   137  					cancel(err)
   138  					return err
   139  				}
   140  				destination = packet.Destination
   141  			} else {
   142  				buffer := buf.NewPacket()
   143  				destination, err = conn.ReadPacket(buffer)
   144  				if err != nil {
   145  					buffer.Release()
   146  					cancel(err)
   147  					return err
   148  				}
   149  				for _, counter := range counters {
   150  					counter(int64(buffer.Len()))
   151  				}
   152  				err = message.Unpack(buffer.Bytes())
   153  				buffer.Release()
   154  				if err != nil {
   155  					cancel(err)
   156  					return err
   157  				}
   158  				timeout.Update()
   159  			}
   160  			metadataInQuery := metadata
   161  			go func() error {
   162  				response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
   163  				if err != nil {
   164  					cancel(err)
   165  					return err
   166  				}
   167  				timeout.Update()
   168  				responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
   169  				if err != nil {
   170  					cancel(err)
   171  					return err
   172  				}
   173  				err = conn.WritePacket(responseBuffer, destination)
   174  				if err != nil {
   175  					cancel(err)
   176  				}
   177  				return err
   178  			}()
   179  		}
   180  	})
   181  	group.Cleanup(func() {
   182  		conn.Close()
   183  	})
   184  	return group.Run(fastClose)
   185  }
   186  
   187  func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
   188  	ctx = adapter.WithContext(ctx, &metadata)
   189  	fastClose, cancel := common.ContextWithCancelCause(ctx)
   190  	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
   191  	var group task.Group
   192  	group.Append0(func(ctx context.Context) error {
   193  		for {
   194  			var (
   195  				message     mDNS.Msg
   196  				destination M.Socksaddr
   197  				err         error
   198  				buffer      *buf.Buffer
   199  			)
   200  			if len(cached) > 0 {
   201  				packet := cached[0]
   202  				cached = cached[1:]
   203  				for _, counter := range readCounters {
   204  					counter(int64(packet.Buffer.Len()))
   205  				}
   206  				err = message.Unpack(packet.Buffer.Bytes())
   207  				packet.Buffer.Release()
   208  				if err != nil {
   209  					cancel(err)
   210  					return err
   211  				}
   212  				destination = packet.Destination
   213  			} else {
   214  				buffer, destination, err = readWaiter.WaitReadPacket()
   215  				if err != nil {
   216  					cancel(err)
   217  					return err
   218  				}
   219  				for _, counter := range readCounters {
   220  					counter(int64(buffer.Len()))
   221  				}
   222  				err = message.Unpack(buffer.Bytes())
   223  				buffer.Release()
   224  				if err != nil {
   225  					cancel(err)
   226  					return err
   227  				}
   228  				timeout.Update()
   229  			}
   230  			metadataInQuery := metadata
   231  			go func() error {
   232  				response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
   233  				if err != nil {
   234  					cancel(err)
   235  					return err
   236  				}
   237  				timeout.Update()
   238  				responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
   239  				if err != nil {
   240  					cancel(err)
   241  					return err
   242  				}
   243  				err = conn.WritePacket(responseBuffer, destination)
   244  				if err != nil {
   245  					cancel(err)
   246  				}
   247  				return err
   248  			}()
   249  		}
   250  	})
   251  	group.Cleanup(func() {
   252  		conn.Close()
   253  	})
   254  	return group.Run(fastClose)
   255  }