github.com/sandwich-go/boost@v1.3.29/httputil/dns/dns_test.go (about) 1 package dns 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "github.com/sandwich-go/boost/z" 8 . "github.com/smartystreets/goconvey/convey" 9 "net" 10 "strconv" 11 "strings" 12 "testing" 13 "time" 14 ) 15 16 var ( 17 mockPort = strconv.FormatInt(int64(z.FastRand()), 10) 18 mockIPAddrs = []net.IPAddr{{IP: []byte("a")}, {IP: []byte("b")}} 19 mockConn = &net.TCPConn{} 20 errLookupFail = errors.New("lookup fail") 21 errDailFail = errors.New("dail fail") 22 errTimeout = errors.New("timeout") 23 ) 24 25 type mockDialer struct{} 26 27 func (mockDialer) DialContext(_ context.Context, _, address string) (net.Conn, error) { 28 for _, ipAddr := range mockIPAddrs { 29 if strings.HasPrefix(address, ipAddr.String()) && strings.HasSuffix(address, mockPort) { 30 return mockConn, nil 31 } 32 } 33 return nil, errDailFail 34 } 35 36 type mockResolver struct{} 37 38 func (mockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { 39 if len(host) == 0 { 40 return nil, errLookupFail 41 } 42 host = strings.TrimSpace(host) 43 if len(host) == 0 { 44 return nil, nil 45 } 46 if ss, _ := strconv.ParseInt(host, 10, 64); ss > 0 { 47 select { 48 case <-ctx.Done(): 49 return nil, errTimeout 50 } 51 } 52 return mockIPAddrs, nil 53 } 54 55 func TestDNS(t *testing.T) { 56 Convey("dns", t, func() { 57 var lookupSuccess bool 58 d := New( 59 WithLookupTimeout(10*time.Millisecond), 60 WithDialer(mockDialer{}), 61 WithResolver(mockResolver{}), 62 WithOnLookup(func(ctx context.Context, host string, cost time.Duration, ipAddrs []net.IPAddr) { 63 lookupSuccess = true 64 })) 65 for _, test := range []struct { 66 host string 67 err error 68 }{ 69 {host: "", err: errLookupFail}, 70 {host: " ", err: ErrNotFound}, 71 {host: "1", err: errTimeout}, 72 {host: "0"}, 73 } { 74 ipAddrs, err := d.LookupIPAddr(context.Background(), test.host) 75 if test.err != nil { 76 So(err, ShouldNotBeNil) 77 So(test.err, ShouldEqual, err) 78 So(lookupSuccess, ShouldBeFalse) 79 } else { 80 So(err, ShouldBeNil) 81 So(len(ipAddrs), ShouldEqual, len(mockIPAddrs)) 82 So(ipAddrs, ShouldResemble, mockIPAddrs) 83 So(lookupSuccess, ShouldBeTrue) 84 } 85 } 86 87 lookupSuccess = false 88 89 dail := d.GetDialContext() 90 for _, test := range []struct { 91 host string 92 err error 93 }{ 94 {host: fmt.Sprintf(":%s", mockPort), err: errLookupFail}, 95 {host: fmt.Sprintf(" :%s", mockPort), err: ErrNotFound}, 96 {host: fmt.Sprintf("1:%s", mockPort), err: errTimeout}, 97 {host: fmt.Sprintf("0:%s", mockPort)}, 98 } { 99 conn, err := dail(context.Background(), "mock", test.host) 100 if test.err != nil { 101 So(err, ShouldNotBeNil) 102 So(conn, ShouldBeNil) 103 So(test.err, ShouldEqual, err) 104 So(lookupSuccess, ShouldBeFalse) 105 } else { 106 So(err, ShouldBeNil) 107 So(conn, ShouldNotBeNil) 108 So(conn, ShouldEqual, mockConn) 109 So(lookupSuccess, ShouldBeTrue) 110 } 111 } 112 }) 113 } 114 115 func TestCacheDNS(t *testing.T) { 116 Convey("cache dns", t, func() { 117 var lookupSuccess bool 118 d := NewCache( 119 WithLookupTimeout(10*time.Millisecond), 120 WithDialer(mockDialer{}), 121 WithResolver(mockResolver{}), 122 WithOnLookup(func(ctx context.Context, host string, cost time.Duration, ipAddrs []net.IPAddr) { 123 lookupSuccess = true 124 })) 125 for _, test := range []struct { 126 host string 127 err error 128 }{ 129 {host: "", err: errLookupFail}, 130 {host: " ", err: ErrNotFound}, 131 {host: "1", err: errTimeout}, 132 {host: "0"}, 133 } { 134 ipAddrs, err := d.LookupIPAddr(context.Background(), test.host) 135 if test.err != nil { 136 So(err, ShouldNotBeNil) 137 So(test.err, ShouldEqual, err) 138 So(lookupSuccess, ShouldBeFalse) 139 } else { 140 So(err, ShouldBeNil) 141 So(len(ipAddrs), ShouldEqual, len(mockIPAddrs)) 142 So(ipAddrs, ShouldResemble, mockIPAddrs) 143 So(lookupSuccess, ShouldBeTrue) 144 145 _, ok := d.Get(test.host) 146 So(ok, ShouldBeTrue) 147 } 148 } 149 150 lookupSuccess = false 151 152 dail := d.GetDialContext() 153 for _, test := range []struct { 154 host string 155 err error 156 }{ 157 {host: fmt.Sprintf(":%s", mockPort), err: errLookupFail}, 158 {host: fmt.Sprintf(" :%s", mockPort), err: ErrNotFound}, 159 {host: fmt.Sprintf("1:%s", mockPort), err: errTimeout}, 160 {host: fmt.Sprintf("0:%s", mockPort)}, 161 } { 162 conn, err := dail(context.Background(), "mock", test.host) 163 if test.err != nil { 164 So(err, ShouldNotBeNil) 165 So(conn, ShouldBeNil) 166 So(test.err, ShouldEqual, err) 167 So(lookupSuccess, ShouldBeFalse) 168 } else { 169 So(err, ShouldBeNil) 170 So(conn, ShouldNotBeNil) 171 So(conn, ShouldEqual, mockConn) 172 So(lookupSuccess, ShouldBeFalse) // 说明没有真正的去 lookup,而是走的缓存 173 } 174 } 175 }) 176 }