github.com/slackhq/nebula@v1.9.0/control_test.go (about)

     1  package nebula
     2  
     3  import (
     4  	"net"
     5  	"reflect"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/sirupsen/logrus"
    10  	"github.com/slackhq/nebula/cert"
    11  	"github.com/slackhq/nebula/iputil"
    12  	"github.com/slackhq/nebula/test"
    13  	"github.com/slackhq/nebula/udp"
    14  	"github.com/stretchr/testify/assert"
    15  )
    16  
    17  func TestControl_GetHostInfoByVpnIp(t *testing.T) {
    18  	l := test.NewLogger()
    19  	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
    20  	// To properly ensure we are not exposing core memory to the caller
    21  	hm := newHostMap(l, &net.IPNet{})
    22  	hm.preferredRanges.Store(&[]*net.IPNet{})
    23  
    24  	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
    25  	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
    26  	ipNet := net.IPNet{
    27  		IP:   net.IPv4(1, 2, 3, 4),
    28  		Mask: net.IPMask{255, 255, 255, 0},
    29  	}
    30  
    31  	ipNet2 := net.IPNet{
    32  		IP:   net.ParseIP("1:2:3:4:5:6:7:8"),
    33  		Mask: net.IPMask{255, 255, 255, 0},
    34  	}
    35  
    36  	crt := &cert.NebulaCertificate{
    37  		Details: cert.NebulaCertificateDetails{
    38  			Name:           "test",
    39  			Ips:            []*net.IPNet{&ipNet},
    40  			Subnets:        []*net.IPNet{},
    41  			Groups:         []string{"default-group"},
    42  			NotBefore:      time.Unix(1, 0),
    43  			NotAfter:       time.Unix(2, 0),
    44  			PublicKey:      []byte{5, 6, 7, 8},
    45  			IsCA:           false,
    46  			Issuer:         "the-issuer",
    47  			InvertedGroups: map[string]struct{}{"default-group": {}},
    48  		},
    49  		Signature: []byte{1, 2, 1, 2, 1, 3},
    50  	}
    51  
    52  	remotes := NewRemoteList(nil)
    53  	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
    54  	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
    55  	hm.unlockedAddHostInfo(&HostInfo{
    56  		remote:  remote1,
    57  		remotes: remotes,
    58  		ConnectionState: &ConnectionState{
    59  			peerCert: crt,
    60  		},
    61  		remoteIndexId: 200,
    62  		localIndexId:  201,
    63  		vpnIp:         iputil.Ip2VpnIp(ipNet.IP),
    64  		relayState: RelayState{
    65  			relays:        map[iputil.VpnIp]struct{}{},
    66  			relayForByIp:  map[iputil.VpnIp]*Relay{},
    67  			relayForByIdx: map[uint32]*Relay{},
    68  		},
    69  	}, &Interface{})
    70  
    71  	hm.unlockedAddHostInfo(&HostInfo{
    72  		remote:  remote1,
    73  		remotes: remotes,
    74  		ConnectionState: &ConnectionState{
    75  			peerCert: nil,
    76  		},
    77  		remoteIndexId: 200,
    78  		localIndexId:  201,
    79  		vpnIp:         iputil.Ip2VpnIp(ipNet2.IP),
    80  		relayState: RelayState{
    81  			relays:        map[iputil.VpnIp]struct{}{},
    82  			relayForByIp:  map[iputil.VpnIp]*Relay{},
    83  			relayForByIdx: map[uint32]*Relay{},
    84  		},
    85  	}, &Interface{})
    86  
    87  	c := Control{
    88  		f: &Interface{
    89  			hostMap: hm,
    90  		},
    91  		l: logrus.New(),
    92  	}
    93  
    94  	thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
    95  
    96  	expectedInfo := ControlHostInfo{
    97  		VpnIp:                  net.IPv4(1, 2, 3, 4).To4(),
    98  		LocalIndex:             201,
    99  		RemoteIndex:            200,
   100  		RemoteAddrs:            []*udp.Addr{remote2, remote1},
   101  		Cert:                   crt.Copy(),
   102  		MessageCounter:         0,
   103  		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
   104  		CurrentRelaysToMe:      []iputil.VpnIp{},
   105  		CurrentRelaysThroughMe: []iputil.VpnIp{},
   106  	}
   107  
   108  	// Make sure we don't have any unexpected fields
   109  	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
   110  	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
   111  
   112  	// Make sure we don't panic if the host info doesn't have a cert yet
   113  	assert.NotPanics(t, func() {
   114  		thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
   115  	})
   116  }
   117  
   118  func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
   119  	val := reflect.ValueOf(actualStruct).Elem()
   120  	fields := make([]string, val.NumField())
   121  	for i := 0; i < val.NumField(); i++ {
   122  		fields[i] = val.Type().Field(i).Name
   123  	}
   124  
   125  	assert.Equal(t, expected, fields)
   126  }