github.com/sandwich-go/boost@v1.3.29/httputil/dns/dns_test.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/sandwich-go/boost/z"
     8  	. "github.com/smartystreets/goconvey/convey"
     9  	"net"
    10  	"strconv"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  var (
    17  	mockPort      = strconv.FormatInt(int64(z.FastRand()), 10)
    18  	mockIPAddrs   = []net.IPAddr{{IP: []byte("a")}, {IP: []byte("b")}}
    19  	mockConn      = &net.TCPConn{}
    20  	errLookupFail = errors.New("lookup fail")
    21  	errDailFail   = errors.New("dail fail")
    22  	errTimeout    = errors.New("timeout")
    23  )
    24  
    25  type mockDialer struct{}
    26  
    27  func (mockDialer) DialContext(_ context.Context, _, address string) (net.Conn, error) {
    28  	for _, ipAddr := range mockIPAddrs {
    29  		if strings.HasPrefix(address, ipAddr.String()) && strings.HasSuffix(address, mockPort) {
    30  			return mockConn, nil
    31  		}
    32  	}
    33  	return nil, errDailFail
    34  }
    35  
    36  type mockResolver struct{}
    37  
    38  func (mockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) {
    39  	if len(host) == 0 {
    40  		return nil, errLookupFail
    41  	}
    42  	host = strings.TrimSpace(host)
    43  	if len(host) == 0 {
    44  		return nil, nil
    45  	}
    46  	if ss, _ := strconv.ParseInt(host, 10, 64); ss > 0 {
    47  		select {
    48  		case <-ctx.Done():
    49  			return nil, errTimeout
    50  		}
    51  	}
    52  	return mockIPAddrs, nil
    53  }
    54  
    55  func TestDNS(t *testing.T) {
    56  	Convey("dns", t, func() {
    57  		var lookupSuccess bool
    58  		d := New(
    59  			WithLookupTimeout(10*time.Millisecond),
    60  			WithDialer(mockDialer{}),
    61  			WithResolver(mockResolver{}),
    62  			WithOnLookup(func(ctx context.Context, host string, cost time.Duration, ipAddrs []net.IPAddr) {
    63  				lookupSuccess = true
    64  			}))
    65  		for _, test := range []struct {
    66  			host string
    67  			err  error
    68  		}{
    69  			{host: "", err: errLookupFail},
    70  			{host: " ", err: ErrNotFound},
    71  			{host: "1", err: errTimeout},
    72  			{host: "0"},
    73  		} {
    74  			ipAddrs, err := d.LookupIPAddr(context.Background(), test.host)
    75  			if test.err != nil {
    76  				So(err, ShouldNotBeNil)
    77  				So(test.err, ShouldEqual, err)
    78  				So(lookupSuccess, ShouldBeFalse)
    79  			} else {
    80  				So(err, ShouldBeNil)
    81  				So(len(ipAddrs), ShouldEqual, len(mockIPAddrs))
    82  				So(ipAddrs, ShouldResemble, mockIPAddrs)
    83  				So(lookupSuccess, ShouldBeTrue)
    84  			}
    85  		}
    86  
    87  		lookupSuccess = false
    88  
    89  		dail := d.GetDialContext()
    90  		for _, test := range []struct {
    91  			host string
    92  			err  error
    93  		}{
    94  			{host: fmt.Sprintf(":%s", mockPort), err: errLookupFail},
    95  			{host: fmt.Sprintf(" :%s", mockPort), err: ErrNotFound},
    96  			{host: fmt.Sprintf("1:%s", mockPort), err: errTimeout},
    97  			{host: fmt.Sprintf("0:%s", mockPort)},
    98  		} {
    99  			conn, err := dail(context.Background(), "mock", test.host)
   100  			if test.err != nil {
   101  				So(err, ShouldNotBeNil)
   102  				So(conn, ShouldBeNil)
   103  				So(test.err, ShouldEqual, err)
   104  				So(lookupSuccess, ShouldBeFalse)
   105  			} else {
   106  				So(err, ShouldBeNil)
   107  				So(conn, ShouldNotBeNil)
   108  				So(conn, ShouldEqual, mockConn)
   109  				So(lookupSuccess, ShouldBeTrue)
   110  			}
   111  		}
   112  	})
   113  }
   114  
   115  func TestCacheDNS(t *testing.T) {
   116  	Convey("cache dns", t, func() {
   117  		var lookupSuccess bool
   118  		d := NewCache(
   119  			WithLookupTimeout(10*time.Millisecond),
   120  			WithDialer(mockDialer{}),
   121  			WithResolver(mockResolver{}),
   122  			WithOnLookup(func(ctx context.Context, host string, cost time.Duration, ipAddrs []net.IPAddr) {
   123  				lookupSuccess = true
   124  			}))
   125  		for _, test := range []struct {
   126  			host string
   127  			err  error
   128  		}{
   129  			{host: "", err: errLookupFail},
   130  			{host: " ", err: ErrNotFound},
   131  			{host: "1", err: errTimeout},
   132  			{host: "0"},
   133  		} {
   134  			ipAddrs, err := d.LookupIPAddr(context.Background(), test.host)
   135  			if test.err != nil {
   136  				So(err, ShouldNotBeNil)
   137  				So(test.err, ShouldEqual, err)
   138  				So(lookupSuccess, ShouldBeFalse)
   139  			} else {
   140  				So(err, ShouldBeNil)
   141  				So(len(ipAddrs), ShouldEqual, len(mockIPAddrs))
   142  				So(ipAddrs, ShouldResemble, mockIPAddrs)
   143  				So(lookupSuccess, ShouldBeTrue)
   144  
   145  				_, ok := d.Get(test.host)
   146  				So(ok, ShouldBeTrue)
   147  			}
   148  		}
   149  
   150  		lookupSuccess = false
   151  
   152  		dail := d.GetDialContext()
   153  		for _, test := range []struct {
   154  			host string
   155  			err  error
   156  		}{
   157  			{host: fmt.Sprintf(":%s", mockPort), err: errLookupFail},
   158  			{host: fmt.Sprintf(" :%s", mockPort), err: ErrNotFound},
   159  			{host: fmt.Sprintf("1:%s", mockPort), err: errTimeout},
   160  			{host: fmt.Sprintf("0:%s", mockPort)},
   161  		} {
   162  			conn, err := dail(context.Background(), "mock", test.host)
   163  			if test.err != nil {
   164  				So(err, ShouldNotBeNil)
   165  				So(conn, ShouldBeNil)
   166  				So(test.err, ShouldEqual, err)
   167  				So(lookupSuccess, ShouldBeFalse)
   168  			} else {
   169  				So(err, ShouldBeNil)
   170  				So(conn, ShouldNotBeNil)
   171  				So(conn, ShouldEqual, mockConn)
   172  				So(lookupSuccess, ShouldBeFalse) // 说明没有真正的去 lookup,而是走的缓存
   173  			}
   174  		}
   175  	})
   176  }