github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/message.go (about)

     1  package fastdns
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  )
     7  
// Message represents a DNS message, such as a request received by a server or
// a request to be sent by a client.
type Message struct {
	// Raw refers to the raw wire-format packet.
	Raw []byte

	// Domain holds the query name in dotted text form, e.g. "www.example.com".
	Domain []byte

	// Header encapsulates the header section of the DNS message.
	// It follows the conventions stated in RFC 1035 section 4.1.1.
	Header struct {
		// ID is an arbitrary 16-bit request identifier that is
		// carried back in the response so that requests and responses can be matched up.
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                      ID                       |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		ID uint16

		// Flags packs the 16 bits of QR, Opcode, AA, TC, RD, RA, Z and RCODE.
		//
		//   0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		Flags Flags

		// QDCount specifies the number of entries in the question section.
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                    QDCOUNT                    |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		QDCount uint16

		// ANCount specifies the number of resource records (RR) in the answer section.
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                    ANCOUNT                    |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		ANCount uint16

		// NSCount specifies the number of name server resource records in the authority section.
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                    NSCOUNT                    |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		NSCount uint16

		// ARCount specifies the number of resource records in the additional records section.
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                    ARCOUNT                    |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		ARCount uint16
	}

	// Question encapsulates the single question of the DNS message.
	// It follows the conventions stated in RFC 1035 section 4.1.2.
	Question struct {
		// Name is the query name in wire format: length-prefixed labels
		// including the terminating zero octet.
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                                               |
		// /                     QNAME                     /
		// /                                               /
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		Name []byte

		// Type specifies the type of the query to perform (QTYPE).
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                     QTYPE                     |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		Type Type

		// Class specifies the class of the query to perform (QCLASS).
		//
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		// |                     QCLASS                    |
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
		Class Class
	}
}
    91  
    92  var (
    93  	// ErrInvalidHeader is returned when dns message does not have the expected header size.
    94  	ErrInvalidHeader = errors.New("dns message does not have the expected header size")
    95  	// ErrInvalidQuestion is returned when dns message does not have the expected question size.
    96  	ErrInvalidQuestion = errors.New("dns message does not have the expected question size")
    97  	// ErrInvalidAnswer is returned when dns message does not have the expected answer size.
    98  	ErrInvalidAnswer = errors.New("dns message does not have the expected answer size")
    99  )
   100  
   101  // ParseMessage parses dns request from payload into dst and returns the error.
   102  func ParseMessage(dst *Message, payload []byte, copying bool) error {
   103  	if copying {
   104  		dst.Raw = append(dst.Raw[:0], payload...)
   105  		payload = dst.Raw
   106  	}
   107  
   108  	if len(payload) < 12 {
   109  		return ErrInvalidHeader
   110  	}
   111  
   112  	// hint golang compiler remove ip bounds check
   113  	_ = payload[11]
   114  
   115  	// ID
   116  	dst.Header.ID = uint16(payload[0])<<8 | uint16(payload[1])
   117  
   118  	// RD, TC, AA, Opcode, QR, RA, Z, RCODE
   119  	dst.Header.Flags = Flags(payload[2])<<8 | Flags(payload[3])
   120  
   121  	// QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT
   122  	dst.Header.QDCount = uint16(payload[4])<<8 | uint16(payload[5])
   123  	dst.Header.ANCount = uint16(payload[6])<<8 | uint16(payload[7])
   124  	dst.Header.NSCount = uint16(payload[8])<<8 | uint16(payload[9])
   125  	dst.Header.ARCount = uint16(payload[10])<<8 | uint16(payload[11])
   126  
   127  	if dst.Header.QDCount != 1 {
   128  		return ErrInvalidHeader
   129  	}
   130  
   131  	// QNAME
   132  	payload = payload[12:]
   133  	var i int
   134  	var b byte
   135  	for i, b = range payload {
   136  		if b == 0 {
   137  			break
   138  		}
   139  	}
   140  	if i == 0 || i+5 > len(payload) {
   141  		return ErrInvalidQuestion
   142  	}
   143  	dst.Question.Name = payload[:i+1]
   144  
   145  	// QTYPE, QCLASS
   146  	payload = payload[i:]
   147  	dst.Question.Class = Class(uint16(payload[4]) | uint16(payload[3])<<8)
   148  	dst.Question.Type = Type(uint16(payload[2]) | uint16(payload[1])<<8)
   149  
   150  	// Domain
   151  	i = int(dst.Question.Name[0])
   152  	payload = append(dst.Domain[:0], dst.Question.Name[1:]...)
   153  	for payload[i] != 0 {
   154  		j := int(payload[i])
   155  		payload[i] = '.'
   156  		i += j + 1
   157  	}
   158  	dst.Domain = payload[:len(payload)-1]
   159  
   160  	return nil
   161  }
   162  
   163  // DecodeName decodes dns labels to dst.
   164  func (msg *Message) DecodeName(dst []byte, name []byte) []byte {
   165  	if len(name) < 2 {
   166  		return dst
   167  	}
   168  
   169  	// fast path for domain pointer
   170  	if name[1] == 12 && name[0] == 0b11000000 {
   171  		return append(dst, msg.Domain...)
   172  	}
   173  
   174  	pos := len(dst)
   175  	var offset int
   176  	if name[len(name)-1] == 0 {
   177  		dst = append(dst, name...)
   178  	} else {
   179  		dst = append(dst, name[:len(name)-2]...)
   180  		offset = int(name[len(name)-2]&0b00111111)<<8 + int(name[len(name)-1])
   181  	}
   182  
   183  	for offset != 0 {
   184  		for i := offset; i < len(msg.Raw); {
   185  			b := int(msg.Raw[i])
   186  			if b == 0 {
   187  				offset = 0
   188  				dst = append(dst, 0)
   189  				break
   190  			} else if b&0b11000000 == 0b11000000 {
   191  				offset = int(b&0b00111111)<<8 + int(msg.Raw[i+1])
   192  				break
   193  			} else {
   194  				dst = append(dst, msg.Raw[i:i+b+1]...)
   195  				i += b + 1
   196  			}
   197  		}
   198  	}
   199  
   200  	n := pos
   201  	for dst[pos] != 0 {
   202  		i := int(dst[pos])
   203  		dst[pos] = '.'
   204  		pos += i + 1
   205  	}
   206  
   207  	if n == 0 {
   208  		dst = dst[1 : len(dst)-1]
   209  	} else {
   210  		dst = append(dst[:n], dst[n+1:len(dst)-1]...)
   211  	}
   212  
   213  	return dst
   214  }
   215  
   216  // Walk calls f for each item in the msg in the original order of the parsed RR.
   217  func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool) error {
   218  	n := msg.Header.ANCount + msg.Header.NSCount
   219  	if n == 0 {
   220  		return ErrInvalidAnswer
   221  	}
   222  
   223  	payload := msg.Raw[16+len(msg.Question.Name):]
   224  
   225  	for i := uint16(0); i < n; i++ {
   226  		var name []byte
   227  		for j, b := range payload {
   228  			if b&0b11000000 == 0b11000000 {
   229  				name = payload[:j+2]
   230  				payload = payload[j+2:]
   231  				break
   232  			} else if b == 0 {
   233  				name = payload[:j+1]
   234  				payload = payload[j+1:]
   235  				break
   236  			}
   237  		}
   238  		if name == nil {
   239  			return ErrInvalidAnswer
   240  		}
   241  		_ = payload[9] // hint compiler to remove bounds check
   242  		typ := Type(payload[0])<<8 | Type(payload[1])
   243  		class := Class(payload[2])<<8 | Class(payload[3])
   244  		ttl := uint32(payload[4])<<24 | uint32(payload[5])<<16 | uint32(payload[6])<<8 | uint32(payload[7])
   245  		length := uint16(payload[8])<<8 | uint16(payload[9])
   246  		data := payload[10 : 10+length]
   247  		payload = payload[10+length:]
   248  		ok := f(name, typ, class, ttl, data)
   249  		if !ok {
   250  			break
   251  		}
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  // WalkAdditionalRecords calls f for each item in the msg in the original order of the parsed AR.
   258  func (msg *Message) WalkAdditionalRecords(f func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool) error {
   259  	panic("not implemented")
   260  }
   261  
   262  // SetRequestQuestion set question for DNS request.
   263  func (msg *Message) SetRequestQuestion(domain string, typ Type, class Class) {
   264  	// random head id
   265  	msg.Header.ID = uint16(cheaprandn(65536))
   266  
   267  	// QR = 0, RCODE = 0, RD = 1
   268  	//
   269  	//   0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
   270  	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
   271  	// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
   272  	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
   273  	msg.Header.Flags &= 0b0111111111110000
   274  	msg.Header.Flags |= 0b0000000100000000
   275  
   276  	msg.Header.QDCount = 1
   277  	msg.Header.ANCount = 0
   278  	msg.Header.NSCount = 0
   279  	msg.Header.ARCount = 0
   280  
   281  	header := [...]byte{
   282  		// ID
   283  		byte(msg.Header.ID >> 8), byte(msg.Header.ID),
   284  		// Flags
   285  		byte(msg.Header.Flags >> 8), byte(msg.Header.Flags),
   286  		// QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT
   287  		0, 1, 0, 0, 0, 0, 0, 0,
   288  	}
   289  
   290  	msg.Raw = append(msg.Raw[:0], header[:]...)
   291  
   292  	// QNAME
   293  	msg.Raw = EncodeDomain(msg.Raw, domain)
   294  	msg.Question.Name = msg.Raw[len(header) : len(header)+len(domain)+2]
   295  	// QTYPE
   296  	msg.Raw = append(msg.Raw, byte(typ>>8), byte(typ))
   297  	msg.Question.Type = typ
   298  	// QCLASS
   299  	msg.Raw = append(msg.Raw, byte(class>>8), byte(class))
   300  	msg.Question.Class = class
   301  
   302  	// Domain
   303  	msg.Domain = append(msg.Domain[:0], domain...)
   304  }
   305  
   306  // SetResponseHeader sets QR=1, RCODE=rcode, ANCount=ancount then updates Raw.
   307  func (msg *Message) SetResponseHeader(rcode Rcode, ancount uint16) {
   308  	// QR = 1, RCODE = rcode
   309  	//
   310  	//   0  1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
   311  	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
   312  	// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
   313  	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
   314  	msg.Header.Flags &= 0b1111111111110000
   315  	msg.Header.Flags |= 0b1000000000000000 | Flags(rcode)
   316  
   317  	// Error
   318  	if rcode != RcodeNoError {
   319  		msg.Header.QDCount = 0
   320  		msg.Header.ANCount = 0
   321  		msg.Header.NSCount = 0
   322  		msg.Header.ARCount = 0
   323  
   324  		msg.Raw = msg.Raw[:12]
   325  
   326  		// Flags
   327  		msg.Raw[2] = byte(msg.Header.Flags >> 8)
   328  		msg.Raw[3] = byte(msg.Header.Flags)
   329  
   330  		// QDCount
   331  		msg.Raw[4] = 0
   332  		msg.Raw[5] = 0
   333  
   334  		// ANCOUNT
   335  		msg.Raw[6] = 0
   336  		msg.Raw[7] = 0
   337  
   338  		// NSCOUNT
   339  		msg.Raw[8] = 0
   340  		msg.Raw[9] = 0
   341  
   342  		// ARCOUNT
   343  		msg.Raw[10] = 0
   344  		msg.Raw[11] = 0
   345  
   346  		return
   347  	}
   348  
   349  	msg.Header.QDCount = 1
   350  	msg.Header.ANCount = ancount
   351  	msg.Header.NSCount = 0
   352  	msg.Header.ARCount = 0
   353  
   354  	msg.Raw = msg.Raw[:12+len(msg.Question.Name)+4]
   355  	header := msg.Raw[:12]
   356  
   357  	// Flags
   358  	header[2] = byte(msg.Header.Flags >> 8)
   359  	header[3] = byte(msg.Header.Flags)
   360  
   361  	// QDCount
   362  	header[4] = 0
   363  	header[5] = 1
   364  
   365  	// ANCOUNT
   366  	header[6] = byte(ancount >> 8)
   367  	header[7] = byte(ancount)
   368  
   369  	// NSCOUNT
   370  	header[8] = 0
   371  	header[9] = 0
   372  
   373  	// ARCOUNT
   374  	header[10] = 0
   375  	header[11] = 0
   376  }
   377  
   378  var msgPool = sync.Pool{
   379  	New: func() interface{} {
   380  		msg := new(Message)
   381  		msg.Raw = make([]byte, 0, 1024)
   382  		msg.Domain = make([]byte, 0, 256)
   383  		return msg
   384  	},
   385  }
   386  
   387  // AcquireMessage returns new dns request.
   388  func AcquireMessage() *Message {
   389  	return msgPool.Get().(*Message)
   390  }
   391  
   392  // ReleaseMessage returnes the dns request to the pool.
   393  func ReleaseMessage(msg *Message) {
   394  	msgPool.Put(msg)
   395  }