github.com/sagernet/sing-box@v1.2.7/common/sniff/dns.go (about)

     1  package sniff
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"io"
     7  	"os"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing-box/adapter"
    11  	C "github.com/sagernet/sing-box/constant"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	"github.com/sagernet/sing/common/task"
    16  
    17  	mDNS "github.com/miekg/dns"
    18  )
    19  
    20  func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
    21  	var length uint16
    22  	err := binary.Read(reader, binary.BigEndian, &length)
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  	if length == 0 {
    27  		return nil, os.ErrInvalid
    28  	}
    29  	_buffer := buf.StackNewSize(int(length))
    30  	defer common.KeepAlive(_buffer)
    31  	buffer := common.Dup(_buffer)
    32  	defer buffer.Release()
    33  
    34  	readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
    35  	var readTask task.Group
    36  	readTask.Append0(func(ctx context.Context) error {
    37  		return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
    38  	})
    39  	err = readTask.Run(readCtx)
    40  	cancel()
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	return DomainNameQuery(readCtx, buffer.Bytes())
    45  }
    46  
    47  func DomainNameQuery(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
    48  	var msg mDNS.Msg
    49  	err := msg.Unpack(packet)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	if len(msg.Question) == 0 || msg.Question[0].Qclass != mDNS.ClassINET || !M.IsDomainName(msg.Question[0].Name) {
    54  		return nil, os.ErrInvalid
    55  	}
    56  	return &adapter.InboundContext{Protocol: C.ProtocolDNS}, nil
    57  }