github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/message_test.go (about)

     1  package fastdns
     2  
     3  import (
     4  	"encoding/hex"
     5  	"reflect"
     6  	"testing"
     7  )
     8  
     9  func TestParseMessageOK(t *testing.T) {
    10  	var cases = [2]struct {
    11  		Raw     []byte
    12  		Message *Message
    13  	}{}
    14  
    15  	/*
    16  		Domain Name System (query)
    17  		    Transaction ID: 0x0001
    18  		    Flags: 0x0100 Standard query
    19  		        0... .... .... .... = Response: Message is a query
    20  		        .000 0... .... .... = Opcode: Standard query (0)
    21  		        .... ..0. .... .... = Truncated: Message is not truncated
    22  		        .... ...1 .... .... = Recursion desired: Do query recursively
    23  		        .... .... .0.. .... = Z: reserved (0)
    24  		        .... .... ...0 .... = Non-authenticated data: Unacceptable
    25  		    Questions: 1
    26  		    Answer RRs: 0
    27  		    Authority RRs: 0
    28  		    Additional RRs: 0
    29  		    Queries
    30  		        1.50.168.192.in-addr.arpa: type PTR, class IN
    31  		            Name: 1.50.168.192.in-addr.arpa
    32  		            [Name Length: 25]
    33  		            [Label Count: 6]
    34  		            Type: PTR (domain name PoinTeR) (12)
    35  		            Class: IN (0x0001)
    36  	*/
    37  	cases[0].Raw, _ = hex.DecodeString("0001010000010000000000000131023530033136380331393207696e2d61646472046172706100000c0001")
    38  	cases[0].Message = AcquireMessage()
    39  	cases[0].Message.Raw = cases[0].Raw
    40  	cases[0].Message.Domain = []byte("1.50.168.192.in-addr.arpa")
    41  	cases[0].Message.Header.ID = 0x0001
    42  	cases[0].Message.Header.Flags = 0b0000000100000000
    43  	cases[0].Message.Header.QDCount = 0x01
    44  	cases[0].Message.Header.ANCount = 0x00
    45  	cases[0].Message.Header.NSCount = 0x00
    46  	cases[0].Message.Header.ARCount = 0x00
    47  	cases[0].Message.Question.Name = []byte("\x011\x0250\x03168\x03192\x07in-addr\x04arpa\x00")
    48  	cases[0].Message.Question.Type = TypePTR
    49  	cases[0].Message.Question.Class = ClassINET
    50  
    51  	/*
    52  		Domain Name System (query)
    53  		    Transaction ID: 0x0002
    54  		    Flags: 0x0100 Standard query
    55  		        0... .... .... .... = Response: Message is a query
    56  		        .000 0... .... .... = Opcode: Standard query (0)
    57  		        .... ..0. .... .... = Truncated: Message is not truncated
    58  		        .... ...1 .... .... = Recursion desired: Do query recursively
    59  		        .... .... .0.. .... = Z: reserved (0)
    60  		        .... .... ...0 .... = Non-authenticated data: Unacceptable
    61  		    Questions: 1
    62  		    Answer RRs: 0
    63  		    Authority RRs: 0
    64  		    Additional RRs: 0
    65  		    Queries
    66  		        hk.phus.lu: type A, class IN
    67  		            Name: hk.phus.lu
    68  		            [Name Length: 10]
    69  		            [Label Count: 3]
    70  		            Type: A (Host Address) (1)
    71  		            Class: IN (0x0001)
    72  	*/
    73  	cases[1].Raw, _ = hex.DecodeString("00020100000100000000000002686b0470687573026c750000010001")
    74  	cases[1].Message = AcquireMessage()
    75  	cases[1].Message.Raw = cases[1].Raw
    76  	cases[1].Message.Domain = []byte("hk.phus.lu")
    77  	cases[1].Message.Header.ID = 0x0002
    78  	cases[1].Message.Header.Flags = 0b0000000100000000
    79  	cases[1].Message.Header.QDCount = 0x01
    80  	cases[1].Message.Header.ANCount = 0x00
    81  	cases[1].Message.Header.NSCount = 0x00
    82  	cases[1].Message.Header.ARCount = 0x00
    83  	cases[1].Message.Question.Name = []byte("\x02hk\x04phus\x02lu\x00")
    84  	cases[1].Message.Question.Type = TypeA
    85  	cases[1].Message.Question.Class = ClassINET
    86  
    87  	for _, c := range cases {
    88  		msg := AcquireMessage()
    89  		err := ParseMessage(msg, c.Raw, true)
    90  		if err != nil {
    91  			t.Errorf("ParseMessage(%x) error: %+v", c.Message.Raw, err)
    92  		}
    93  		if got, want := msg, c.Message; !reflect.DeepEqual(got, want) {
    94  			t.Errorf("ParseMessage(%x) error got=%#v want=%#v", c.Message.Raw, got, want)
    95  		}
    96  		ReleaseMessage(msg)
    97  	}
    98  }
    99  
   100  func TestParseMessageError(t *testing.T) {
   101  	var cases = []struct {
   102  		Hex   string
   103  		Error error
   104  	}{
   105  		{
   106  			"0001010000010000000000",
   107  			ErrInvalidHeader,
   108  		},
   109  		{
   110  			"00020100000000000000000002686b0470687573026c7500000100",
   111  			ErrInvalidHeader,
   112  		},
   113  		{
   114  			"00020100000100000000000002686b0470687573026c7500000100",
   115  			ErrInvalidQuestion,
   116  		},
   117  	}
   118  
   119  	for _, c := range cases {
   120  		payload, err := hex.DecodeString(c.Hex)
   121  		if err != nil {
   122  			t.Errorf("hex.DecodeString(%v) error: %+v", c.Hex, err)
   123  		}
   124  		var msg Message
   125  		err = ParseMessage(&msg, payload, true)
   126  		if err != c.Error {
   127  			t.Errorf("ParseMessage(%x) should error: %+v", payload, c.Error)
   128  		}
   129  	}
   130  }
   131  
   132  func TestSetQuestion(t *testing.T) {
   133  	req := AcquireMessage()
   134  	defer ReleaseMessage(req)
   135  
   136  	req.SetRequestQuestion("mail.google.com", TypeA, ClassINET)
   137  
   138  	if req.Header.ID == 0 {
   139  		t.Errorf("req.Header.ID should not empty after SetQuestion")
   140  	}
   141  
   142  	if got, want := req.Header.Flags, Flags(0b0000000100000000); got != want {
   143  		t.Errorf("req.Header.Flags got=%x want=%x", got, want)
   144  	}
   145  
   146  	if got, want := req.Header.QDCount, uint16(1); got != want {
   147  		t.Errorf("req.Header.QDCount got=%d want=%d", got, want)
   148  	}
   149  
   150  	if got, want := req.Header.ANCount, uint16(0); got != want {
   151  		t.Errorf("req.Header.ANCount got=%d want=%d", got, want)
   152  	}
   153  
   154  	if got, want := req.Header.NSCount, uint16(0); got != want {
   155  		t.Errorf("req.Header.NSCount got=%d want=%d", got, want)
   156  	}
   157  
   158  	if got, want := req.Header.ARCount, uint16(0); got != want {
   159  		t.Errorf("req.Header.ARCount got=%d want=%d", got, want)
   160  	}
   161  
   162  	if got, want := string(req.Question.Name), "\x04mail\x06google\x03com\x00"; got != want {
   163  		t.Errorf("req.Question.Name got=%s want=%s", got, want)
   164  	}
   165  
   166  	if got, want := req.Question.Type, TypeA; got != want {
   167  		t.Errorf("req.Question.Type got=%s want=%s", got, want)
   168  	}
   169  
   170  	if got, want := req.Question.Class, ClassINET; got != want {
   171  		t.Errorf("req.Question.Class got=%s want=%s", got, want)
   172  	}
   173  
   174  	if got, want := string(req.Domain), "mail.google.com"; got != want {
   175  		t.Errorf("req.Question.Class got=%s want=%s", got, want)
   176  	}
   177  }
   178  
   179  func TestDecodeName(t *testing.T) {
   180  	payload, _ := hex.DecodeString("8e5281800001000200000000047632657803636f6d0000020001c00c000200010000545f0014036b696d026e730a636c6f7564666c617265c011c00c000200010000545f000704746f6464c02a")
   181  
   182  	resp := AcquireMessage()
   183  	defer ReleaseMessage(resp)
   184  
   185  	err := ParseMessage(resp, payload, true)
   186  	if err != nil {
   187  		t.Errorf("ParseMessage(%+v) error: %+v", payload, err)
   188  	}
   189  
   190  	if got, want := string(resp.DecodeName(nil, []byte("\x04todd\xc0\x2a"))), "todd.ns.cloudflare.com"; got != want {
   191  		t.Errorf("DecodeName(0xc02a) got=%s want=%s", got, want)
   192  	}
   193  }
   194  
   195  func BenchmarkParseMessage(b *testing.B) {
   196  	payload, _ := hex.DecodeString("00020100000100000000000002686b0470687573026c750000010001")
   197  	var msg Message
   198  
   199  	for i := 0; i < b.N; i++ {
   200  		if err := ParseMessage(&msg, payload, false); err != nil {
   201  			b.Errorf("ParseMessage(%+v) error: %+v", payload, err)
   202  		}
   203  	}
   204  }
   205  
   206  func BenchmarkSetQuestion(b *testing.B) {
   207  	req := AcquireMessage()
   208  	defer ReleaseMessage(req)
   209  
   210  	for i := 0; i < b.N; i++ {
   211  		req.SetRequestQuestion("mail.google.com", TypeA, ClassINET)
   212  	}
   213  }
   214  
   215  func BenchmarkSetResponseHeader(b *testing.B) {
   216  	req := AcquireMessage()
   217  	defer ReleaseMessage(req)
   218  
   219  	req.SetRequestQuestion("mail.google.com", TypeA, ClassINET)
   220  
   221  	for i := 0; i < b.N; i++ {
   222  		req.SetResponseHeader(RcodeNoError, 4)
   223  	}
   224  }
   225  
   226  func BenchmarkDecodeName(b *testing.B) {
   227  	payload, _ := hex.DecodeString("8e5281800001000200000000047632657803636f6d0000020001c00c000200010000545f0014036b696d026e730a636c6f7564666c617265c011c00c000200010000545f000704746f6464c02a")
   228  
   229  	resp := AcquireMessage()
   230  	defer ReleaseMessage(resp)
   231  
   232  	err := ParseMessage(resp, payload, true)
   233  	if err != nil {
   234  		b.Errorf("ParseMessage(%+v) error: %+v", payload, err)
   235  	}
   236  
   237  	var dst [256]byte
   238  	name := []byte("\x04todd\xc0\x2a")
   239  	for i := 0; i < b.N; i++ {
   240  		resp.DecodeName(dst[:0], name)
   241  	}
   242  }