github.com/sagernet/sing-box@v1.2.7/outbound/dns.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"net"
     7  	"os"
     8  
     9  	"github.com/sagernet/sing-box/adapter"
    10  	C "github.com/sagernet/sing-box/constant"
    11  	"github.com/sagernet/sing-dns"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	"github.com/sagernet/sing/common/canceler"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/sing/common/task"
    18  
    19  	mDNS "github.com/miekg/dns"
    20  )
    21  
    22  var _ adapter.Outbound = (*DNS)(nil)
    23  
    24  type DNS struct {
    25  	myOutboundAdapter
    26  }
    27  
    28  func NewDNS(router adapter.Router, tag string) *DNS {
    29  	return &DNS{
    30  		myOutboundAdapter{
    31  			protocol: C.TypeDNS,
    32  			network:  []string{N.NetworkTCP, N.NetworkUDP},
    33  			router:   router,
    34  			tag:      tag,
    35  		},
    36  	}
    37  }
    38  
    39  func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    40  	return nil, os.ErrInvalid
    41  }
    42  
    43  func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
    44  	return nil, os.ErrInvalid
    45  }
    46  
    47  func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
    48  	defer conn.Close()
    49  	ctx = adapter.WithContext(ctx, &metadata)
    50  	for {
    51  		err := d.handleConnection(ctx, conn, metadata)
    52  		if err != nil {
    53  			return err
    54  		}
    55  	}
    56  }
    57  
    58  func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
    59  	var queryLength uint16
    60  	err := binary.Read(conn, binary.BigEndian, &queryLength)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	if queryLength == 0 {
    65  		return dns.RCodeFormatError
    66  	}
    67  	_buffer := buf.StackNewSize(int(queryLength))
    68  	defer common.KeepAlive(_buffer)
    69  	buffer := common.Dup(_buffer)
    70  	defer buffer.Release()
    71  	_, err = buffer.ReadFullFrom(conn, int(queryLength))
    72  	if err != nil {
    73  		return err
    74  	}
    75  	var message mDNS.Msg
    76  	err = message.Unpack(buffer.Bytes())
    77  	if err != nil {
    78  		return err
    79  	}
    80  	metadataInQuery := metadata
    81  	go func() error {
    82  		response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		_responseBuffer := buf.StackNewPacket()
    87  		defer common.KeepAlive(_responseBuffer)
    88  		responseBuffer := common.Dup(_responseBuffer)
    89  		defer responseBuffer.Release()
    90  		responseBuffer.Resize(2, 0)
    91  		n, err := response.PackBuffer(responseBuffer.FreeBytes())
    92  		if err != nil {
    93  			return err
    94  		}
    95  		responseBuffer.Truncate(len(n))
    96  		binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
    97  		_, err = conn.Write(responseBuffer.Bytes())
    98  		return err
    99  	}()
   100  	return nil
   101  }
   102  
   103  func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   104  	ctx = adapter.WithContext(ctx, &metadata)
   105  	fastClose, cancel := common.ContextWithCancelCause(ctx)
   106  	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
   107  	var group task.Group
   108  	group.Append0(func(ctx context.Context) error {
   109  		_buffer := buf.StackNewSize(dns.FixedPacketSize)
   110  		defer common.KeepAlive(_buffer)
   111  		buffer := common.Dup(_buffer)
   112  		defer buffer.Release()
   113  		for {
   114  			buffer.FullReset()
   115  			destination, err := conn.ReadPacket(buffer)
   116  			if err != nil {
   117  				cancel(err)
   118  				return err
   119  			}
   120  			var message mDNS.Msg
   121  			err = message.Unpack(buffer.Bytes())
   122  			if err != nil {
   123  				cancel(err)
   124  				return err
   125  			}
   126  			timeout.Update()
   127  			metadataInQuery := metadata
   128  			go func() error {
   129  				response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
   130  				if err != nil {
   131  					cancel(err)
   132  					return err
   133  				}
   134  				timeout.Update()
   135  				responseBuffer := buf.NewPacket()
   136  				n, err := response.PackBuffer(responseBuffer.FreeBytes())
   137  				if err != nil {
   138  					cancel(err)
   139  					responseBuffer.Release()
   140  					return err
   141  				}
   142  				responseBuffer.Truncate(len(n))
   143  				err = conn.WritePacket(responseBuffer, destination)
   144  				if err != nil {
   145  					cancel(err)
   146  				}
   147  				return err
   148  			}()
   149  		}
   150  	})
   151  	group.Cleanup(func() {
   152  		conn.Close()
   153  	})
   154  	return group.Run(fastClose)
   155  }