github.com/slackhq/nebula@v1.9.0/connection_manager_test.go (about)

     1  package nebula
     2  
     3  import (
     4  	"context"
     5  	"crypto/ed25519"
     6  	"crypto/rand"
     7  	"net"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/flynn/noise"
    12  	"github.com/slackhq/nebula/cert"
    13  	"github.com/slackhq/nebula/config"
    14  	"github.com/slackhq/nebula/iputil"
    15  	"github.com/slackhq/nebula/test"
    16  	"github.com/slackhq/nebula/udp"
    17  	"github.com/stretchr/testify/assert"
    18  )
    19  
    20  var vpnIp iputil.VpnIp
    21  
    22  func newTestLighthouse() *LightHouse {
    23  	lh := &LightHouse{
    24  		l:         test.NewLogger(),
    25  		addrMap:   map[iputil.VpnIp]*RemoteList{},
    26  		queryChan: make(chan iputil.VpnIp, 10),
    27  	}
    28  	lighthouses := map[iputil.VpnIp]struct{}{}
    29  	staticList := map[iputil.VpnIp]struct{}{}
    30  
    31  	lh.lighthouses.Store(&lighthouses)
    32  	lh.staticList.Store(&staticList)
    33  
    34  	return lh
    35  }
    36  
    37  func Test_NewConnectionManagerTest(t *testing.T) {
    38  	l := test.NewLogger()
    39  	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
    40  	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
    41  	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
    42  	vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
    43  	preferredRanges := []*net.IPNet{localrange}
    44  
    45  	// Very incomplete mock objects
    46  	hostMap := newHostMap(l, vpncidr)
    47  	hostMap.preferredRanges.Store(&preferredRanges)
    48  
    49  	cs := &CertState{
    50  		RawCertificate:      []byte{},
    51  		PrivateKey:          []byte{},
    52  		Certificate:         &cert.NebulaCertificate{},
    53  		RawCertificateNoKey: []byte{},
    54  	}
    55  
    56  	lh := newTestLighthouse()
    57  	ifce := &Interface{
    58  		hostMap:          hostMap,
    59  		inside:           &test.NoopTun{},
    60  		outside:          &udp.NoopConn{},
    61  		firewall:         &Firewall{},
    62  		lightHouse:       lh,
    63  		pki:              &PKI{},
    64  		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
    65  		l:                l,
    66  	}
    67  	ifce.pki.cs.Store(cs)
    68  
    69  	// Create manager
    70  	ctx, cancel := context.WithCancel(context.Background())
    71  	defer cancel()
    72  	punchy := NewPunchyFromConfig(l, config.NewC(l))
    73  	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
    74  	p := []byte("")
    75  	nb := make([]byte, 12, 12)
    76  	out := make([]byte, mtu)
    77  
    78  	// Add an ip we have established a connection w/ to hostmap
    79  	hostinfo := &HostInfo{
    80  		vpnIp:         vpnIp,
    81  		localIndexId:  1099,
    82  		remoteIndexId: 9901,
    83  	}
    84  	hostinfo.ConnectionState = &ConnectionState{
    85  		myCert: &cert.NebulaCertificate{},
    86  		H:      &noise.HandshakeState{},
    87  	}
    88  	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
    89  
    90  	// We saw traffic out to vpnIp
    91  	nc.Out(hostinfo.localIndexId)
    92  	nc.In(hostinfo.localIndexId)
    93  	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
    94  	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
    95  	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
    96  	assert.Contains(t, nc.out, hostinfo.localIndexId)
    97  
    98  	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
    99  	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
   100  	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
   101  	assert.NotContains(t, nc.out, hostinfo.localIndexId)
   102  	assert.NotContains(t, nc.in, hostinfo.localIndexId)
   103  
   104  	// Do another traffic check tick, this host should be pending deletion now
   105  	nc.Out(hostinfo.localIndexId)
   106  	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
   107  	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
   108  	assert.NotContains(t, nc.out, hostinfo.localIndexId)
   109  	assert.NotContains(t, nc.in, hostinfo.localIndexId)
   110  	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
   111  	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
   112  
   113  	// Do a final traffic check tick, the host should now be removed
   114  	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
   115  	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
   116  	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
   117  	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
   118  }
   119  
   120  func Test_NewConnectionManagerTest2(t *testing.T) {
   121  	l := test.NewLogger()
   122  	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
   123  	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
   124  	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
   125  	preferredRanges := []*net.IPNet{localrange}
   126  
   127  	// Very incomplete mock objects
   128  	hostMap := newHostMap(l, vpncidr)
   129  	hostMap.preferredRanges.Store(&preferredRanges)
   130  
   131  	cs := &CertState{
   132  		RawCertificate:      []byte{},
   133  		PrivateKey:          []byte{},
   134  		Certificate:         &cert.NebulaCertificate{},
   135  		RawCertificateNoKey: []byte{},
   136  	}
   137  
   138  	lh := newTestLighthouse()
   139  	ifce := &Interface{
   140  		hostMap:          hostMap,
   141  		inside:           &test.NoopTun{},
   142  		outside:          &udp.NoopConn{},
   143  		firewall:         &Firewall{},
   144  		lightHouse:       lh,
   145  		pki:              &PKI{},
   146  		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
   147  		l:                l,
   148  	}
   149  	ifce.pki.cs.Store(cs)
   150  
   151  	// Create manager
   152  	ctx, cancel := context.WithCancel(context.Background())
   153  	defer cancel()
   154  	punchy := NewPunchyFromConfig(l, config.NewC(l))
   155  	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
   156  	p := []byte("")
   157  	nb := make([]byte, 12, 12)
   158  	out := make([]byte, mtu)
   159  
   160  	// Add an ip we have established a connection w/ to hostmap
   161  	hostinfo := &HostInfo{
   162  		vpnIp:         vpnIp,
   163  		localIndexId:  1099,
   164  		remoteIndexId: 9901,
   165  	}
   166  	hostinfo.ConnectionState = &ConnectionState{
   167  		myCert: &cert.NebulaCertificate{},
   168  		H:      &noise.HandshakeState{},
   169  	}
   170  	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
   171  
   172  	// We saw traffic out to vpnIp
   173  	nc.Out(hostinfo.localIndexId)
   174  	nc.In(hostinfo.localIndexId)
   175  	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
   176  	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
   177  	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
   178  
   179  	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
   180  	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
   181  	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
   182  	assert.NotContains(t, nc.out, hostinfo.localIndexId)
   183  	assert.NotContains(t, nc.in, hostinfo.localIndexId)
   184  
   185  	// Do another traffic check tick, this host should be pending deletion now
   186  	nc.Out(hostinfo.localIndexId)
   187  	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
   188  	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
   189  	assert.NotContains(t, nc.out, hostinfo.localIndexId)
   190  	assert.NotContains(t, nc.in, hostinfo.localIndexId)
   191  	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
   192  	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
   193  
   194  	// We saw traffic, should no longer be pending deletion
   195  	nc.In(hostinfo.localIndexId)
   196  	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
   197  	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
   198  	assert.NotContains(t, nc.out, hostinfo.localIndexId)
   199  	assert.NotContains(t, nc.in, hostinfo.localIndexId)
   200  	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
   201  	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
   202  }
   203  
   204  // Check if we can disconnect the peer.
   205  // Validate if the peer's certificate is invalid (expired, etc.)
   206  // Disconnect only if disconnectInvalid: true is set.
   207  func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
   208  	now := time.Now()
   209  	l := test.NewLogger()
   210  	ipNet := net.IPNet{
   211  		IP:   net.IPv4(172, 1, 1, 2),
   212  		Mask: net.IPMask{255, 255, 255, 0},
   213  	}
   214  	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
   215  	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
   216  	preferredRanges := []*net.IPNet{localrange}
   217  	hostMap := newHostMap(l, vpncidr)
   218  	hostMap.preferredRanges.Store(&preferredRanges)
   219  
   220  	// Generate keys for CA and peer's cert.
   221  	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
   222  	caCert := cert.NebulaCertificate{
   223  		Details: cert.NebulaCertificateDetails{
   224  			Name:      "ca",
   225  			NotBefore: now,
   226  			NotAfter:  now.Add(1 * time.Hour),
   227  			IsCA:      true,
   228  			PublicKey: pubCA,
   229  		},
   230  	}
   231  
   232  	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
   233  	ncp := &cert.NebulaCAPool{
   234  		CAs: cert.NewCAPool().CAs,
   235  	}
   236  	ncp.CAs["ca"] = &caCert
   237  
   238  	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
   239  	peerCert := cert.NebulaCertificate{
   240  		Details: cert.NebulaCertificateDetails{
   241  			Name:      "host",
   242  			Ips:       []*net.IPNet{&ipNet},
   243  			Subnets:   []*net.IPNet{},
   244  			NotBefore: now,
   245  			NotAfter:  now.Add(60 * time.Second),
   246  			PublicKey: pubCrt,
   247  			IsCA:      false,
   248  			Issuer:    "ca",
   249  		},
   250  	}
   251  	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
   252  
   253  	cs := &CertState{
   254  		RawCertificate:      []byte{},
   255  		PrivateKey:          []byte{},
   256  		Certificate:         &cert.NebulaCertificate{},
   257  		RawCertificateNoKey: []byte{},
   258  	}
   259  
   260  	lh := newTestLighthouse()
   261  	ifce := &Interface{
   262  		hostMap:          hostMap,
   263  		inside:           &test.NoopTun{},
   264  		outside:          &udp.NoopConn{},
   265  		firewall:         &Firewall{},
   266  		lightHouse:       lh,
   267  		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
   268  		l:                l,
   269  		pki:              &PKI{},
   270  	}
   271  	ifce.pki.cs.Store(cs)
   272  	ifce.pki.caPool.Store(ncp)
   273  	ifce.disconnectInvalid.Store(true)
   274  
   275  	// Create manager
   276  	ctx, cancel := context.WithCancel(context.Background())
   277  	defer cancel()
   278  	punchy := NewPunchyFromConfig(l, config.NewC(l))
   279  	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
   280  	ifce.connectionManager = nc
   281  
   282  	hostinfo := &HostInfo{
   283  		vpnIp: vpnIp,
   284  		ConnectionState: &ConnectionState{
   285  			myCert:   &cert.NebulaCertificate{},
   286  			peerCert: &peerCert,
   287  			H:        &noise.HandshakeState{},
   288  		},
   289  	}
   290  	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
   291  
   292  	// Move ahead 45s.
   293  	// Check if to disconnect with invalid certificate.
   294  	// Should be alive.
   295  	nextTick := now.Add(45 * time.Second)
   296  	invalid := nc.isInvalidCertificate(nextTick, hostinfo)
   297  	assert.False(t, invalid)
   298  
   299  	// Move ahead 61s.
   300  	// Check if to disconnect with invalid certificate.
   301  	// Should be disconnected.
   302  	nextTick = now.Add(61 * time.Second)
   303  	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
   304  	assert.True(t, invalid)
   305  }