github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/outbound/dns.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"net"
     7  	"os"
     8  
     9  	"github.com/inazumav/sing-box/adapter"
    10  	C "github.com/inazumav/sing-box/constant"
    11  	"github.com/sagernet/sing-dns"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	"github.com/sagernet/sing/common/bufio"
    15  	"github.com/sagernet/sing/common/canceler"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/task"
    19  
    20  	mDNS "github.com/miekg/dns"
    21  )
    22  
    23  var _ adapter.Outbound = (*DNS)(nil)
    24  
    25  type DNS struct {
    26  	myOutboundAdapter
    27  }
    28  
    29  func NewDNS(router adapter.Router, tag string) *DNS {
    30  	return &DNS{
    31  		myOutboundAdapter{
    32  			protocol: C.TypeDNS,
    33  			network:  []string{N.NetworkTCP, N.NetworkUDP},
    34  			router:   router,
    35  			tag:      tag,
    36  		},
    37  	}
    38  }
    39  
    40  func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    41  	return nil, os.ErrInvalid
    42  }
    43  
    44  func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
    45  	return nil, os.ErrInvalid
    46  }
    47  
    48  func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
    49  	defer conn.Close()
    50  	ctx = adapter.WithContext(ctx, &metadata)
    51  	for {
    52  		err := d.handleConnection(ctx, conn, metadata)
    53  		if err != nil {
    54  			return err
    55  		}
    56  	}
    57  }
    58  
    59  func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
    60  	var queryLength uint16
    61  	err := binary.Read(conn, binary.BigEndian, &queryLength)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	if queryLength == 0 {
    66  		return dns.RCodeFormatError
    67  	}
    68  	buffer := buf.NewSize(int(queryLength))
    69  	defer buffer.Release()
    70  	_, err = buffer.ReadFullFrom(conn, int(queryLength))
    71  	if err != nil {
    72  		return err
    73  	}
    74  	var message mDNS.Msg
    75  	err = message.Unpack(buffer.Bytes())
    76  	if err != nil {
    77  		return err
    78  	}
    79  	metadataInQuery := metadata
    80  	go func() error {
    81  		response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
    82  		if err != nil {
    83  			return err
    84  		}
    85  		responseBuffer := buf.NewPacket()
    86  		defer responseBuffer.Release()
    87  		responseBuffer.Resize(2, 0)
    88  		n, err := response.PackBuffer(responseBuffer.FreeBytes())
    89  		if err != nil {
    90  			return err
    91  		}
    92  		responseBuffer.Truncate(len(n))
    93  		binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
    94  		_, err = conn.Write(responseBuffer.Bytes())
    95  		return err
    96  	}()
    97  	return nil
    98  }
    99  
   100  func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   101  	var reader N.PacketReader = conn
   102  	var counters []N.CountFunc
   103  	var cachedPackets []*N.PacketBuffer
   104  	for {
   105  		reader, counters = N.UnwrapCountPacketReader(reader, counters)
   106  		if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
   107  			packet := cachedReader.ReadCachedPacket()
   108  			if packet != nil {
   109  				cachedPackets = append(cachedPackets, packet)
   110  				continue
   111  			}
   112  		}
   113  		if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
   114  			return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
   115  		}
   116  		break
   117  	}
   118  	ctx = adapter.WithContext(ctx, &metadata)
   119  	fastClose, cancel := common.ContextWithCancelCause(ctx)
   120  	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
   121  	var group task.Group
   122  	group.Append0(func(ctx context.Context) error {
   123  		for {
   124  			var message mDNS.Msg
   125  			var destination M.Socksaddr
   126  			var err error
   127  			if len(cachedPackets) > 0 {
   128  				packet := cachedPackets[0]
   129  				cachedPackets = cachedPackets[1:]
   130  				for _, counter := range counters {
   131  					counter(int64(packet.Buffer.Len()))
   132  				}
   133  				err = message.Unpack(packet.Buffer.Bytes())
   134  				packet.Buffer.Release()
   135  				if err != nil {
   136  					cancel(err)
   137  					return err
   138  				}
   139  				destination = packet.Destination
   140  			} else {
   141  				buffer := buf.NewPacket()
   142  				destination, err = conn.ReadPacket(buffer)
   143  				if err != nil {
   144  					buffer.Release()
   145  					cancel(err)
   146  					return err
   147  				}
   148  				for _, counter := range counters {
   149  					counter(int64(buffer.Len()))
   150  				}
   151  				err = message.Unpack(buffer.Bytes())
   152  				buffer.Release()
   153  				if err != nil {
   154  					cancel(err)
   155  					return err
   156  				}
   157  				timeout.Update()
   158  			}
   159  			metadataInQuery := metadata
   160  			go func() error {
   161  				response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
   162  				if err != nil {
   163  					cancel(err)
   164  					return err
   165  				}
   166  				timeout.Update()
   167  				responseBuffer := buf.NewPacket()
   168  				n, err := response.PackBuffer(responseBuffer.FreeBytes())
   169  				if err != nil {
   170  					cancel(err)
   171  					responseBuffer.Release()
   172  					return err
   173  				}
   174  				responseBuffer.Truncate(len(n))
   175  				err = conn.WritePacket(responseBuffer, destination)
   176  				if err != nil {
   177  					cancel(err)
   178  				}
   179  				return err
   180  			}()
   181  		}
   182  	})
   183  	group.Cleanup(func() {
   184  		conn.Close()
   185  	})
   186  	return group.Run(fastClose)
   187  }
   188  
   189  func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
   190  	ctx = adapter.WithContext(ctx, &metadata)
   191  	fastClose, cancel := common.ContextWithCancelCause(ctx)
   192  	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
   193  	var group task.Group
   194  	group.Append0(func(ctx context.Context) error {
   195  		var buffer *buf.Buffer
   196  		readWaiter.InitializeReadWaiter(func() *buf.Buffer {
   197  			buffer = buf.NewSize(dns.FixedPacketSize)
   198  			buffer.FullReset()
   199  			return buffer
   200  		})
   201  		defer readWaiter.InitializeReadWaiter(nil)
   202  		for {
   203  			var message mDNS.Msg
   204  			var destination M.Socksaddr
   205  			var err error
   206  			if len(cached) > 0 {
   207  				packet := cached[0]
   208  				cached = cached[1:]
   209  				for _, counter := range readCounters {
   210  					counter(int64(packet.Buffer.Len()))
   211  				}
   212  				err = message.Unpack(packet.Buffer.Bytes())
   213  				packet.Buffer.Release()
   214  				if err != nil {
   215  					cancel(err)
   216  					return err
   217  				}
   218  				destination = packet.Destination
   219  			} else {
   220  				destination, err = readWaiter.WaitReadPacket()
   221  				if err != nil {
   222  					buffer.Release()
   223  					cancel(err)
   224  					return err
   225  				}
   226  				for _, counter := range readCounters {
   227  					counter(int64(buffer.Len()))
   228  				}
   229  				err = message.Unpack(buffer.Bytes())
   230  				buffer.Release()
   231  				if err != nil {
   232  					cancel(err)
   233  					return err
   234  				}
   235  				timeout.Update()
   236  			}
   237  			metadataInQuery := metadata
   238  			go func() error {
   239  				response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
   240  				if err != nil {
   241  					cancel(err)
   242  					return err
   243  				}
   244  				timeout.Update()
   245  				responseBuffer := buf.NewPacket()
   246  				n, err := response.PackBuffer(responseBuffer.FreeBytes())
   247  				if err != nil {
   248  					cancel(err)
   249  					responseBuffer.Release()
   250  					return err
   251  				}
   252  				responseBuffer.Truncate(len(n))
   253  				err = conn.WritePacket(responseBuffer, destination)
   254  				if err != nil {
   255  					cancel(err)
   256  				}
   257  				return err
   258  			}()
   259  		}
   260  	})
   261  	group.Cleanup(func() {
   262  		conn.Close()
   263  	})
   264  	return group.Run(fastClose)
   265  }