github.com/sagernet/sing-box@v1.9.0-rc.20/common/sniff/dns.go (about)

     1  package sniff
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"io"
     7  	"os"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing-box/adapter"
    11  	C "github.com/sagernet/sing-box/constant"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	"github.com/sagernet/sing/common/task"
    16  
    17  	mDNS "github.com/miekg/dns"
    18  )
    19  
    20  func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
    21  	var length uint16
    22  	err := binary.Read(reader, binary.BigEndian, &length)
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  	if length == 0 {
    27  		return nil, os.ErrInvalid
    28  	}
    29  	buffer := buf.NewSize(int(length))
    30  	defer buffer.Release()
    31  
    32  	readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
    33  	var readTask task.Group
    34  	readTask.Append0(func(ctx context.Context) error {
    35  		return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
    36  	})
    37  	err = readTask.Run(readCtx)
    38  	cancel()
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	return DomainNameQuery(readCtx, buffer.Bytes())
    43  }
    44  
    45  func DomainNameQuery(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
    46  	var msg mDNS.Msg
    47  	err := msg.Unpack(packet)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  	if len(msg.Question) == 0 || msg.Question[0].Qclass != mDNS.ClassINET || !M.IsDomainName(msg.Question[0].Name) {
    52  		return nil, os.ErrInvalid
    53  	}
    54  	return &adapter.InboundContext{Protocol: C.ProtocolDNS}, nil
    55  }