github.com/netdata/go.d.plugin@v0.58.1/modules/wireguard/wireguard_test.go (about)

     1  // SPDX-License-Identifier: GPL-3.0-or-later
     2  
     3  package wireguard
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/netdata/go.d.plugin/agent/module"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
    17  )
    18  
    19  func TestWireGuard_Init(t *testing.T) {
    20  	assert.True(t, New().Init())
    21  }
    22  
    23  func TestWireGuard_Charts(t *testing.T) {
    24  	assert.Len(t, *New().Charts(), 0)
    25  
    26  }
    27  
    28  func TestWireGuard_Cleanup(t *testing.T) {
    29  	tests := map[string]struct {
    30  		prepare   func(w *WireGuard)
    31  		wantClose bool
    32  	}{
    33  		"after New": {
    34  			wantClose: false,
    35  			prepare:   func(w *WireGuard) {},
    36  		},
    37  		"after Init": {
    38  			wantClose: false,
    39  			prepare:   func(w *WireGuard) { w.Init() },
    40  		},
    41  		"after Check": {
    42  			wantClose: true,
    43  			prepare:   func(w *WireGuard) { w.Init(); w.Check() },
    44  		},
    45  		"after Collect": {
    46  			wantClose: true,
    47  			prepare:   func(w *WireGuard) { w.Init(); w.Collect() },
    48  		},
    49  	}
    50  
    51  	for name, test := range tests {
    52  		t.Run(name, func(t *testing.T) {
    53  			w := New()
    54  			m := &mockClient{}
    55  			w.newWGClient = func() (wgClient, error) { return m, nil }
    56  
    57  			test.prepare(w)
    58  
    59  			require.NotPanics(t, w.Cleanup)
    60  
    61  			if test.wantClose {
    62  				assert.True(t, m.closeCalled)
    63  			} else {
    64  				assert.False(t, m.closeCalled)
    65  			}
    66  		})
    67  	}
    68  }
    69  
    70  func TestWireGuard_Check(t *testing.T) {
    71  	tests := map[string]struct {
    72  		wantFail bool
    73  		prepare  func(w *WireGuard)
    74  	}{
    75  		"success when devices and peers found": {
    76  			wantFail: false,
    77  			prepare: func(w *WireGuard) {
    78  				m := &mockClient{}
    79  				d1 := prepareDevice(1)
    80  				d1.Peers = append(d1.Peers, preparePeer("11"))
    81  				d1.Peers = append(d1.Peers, preparePeer("12"))
    82  				m.devices = append(m.devices, d1)
    83  				w.client = m
    84  			},
    85  		},
    86  		"success when devices and no peers found": {
    87  			wantFail: false,
    88  			prepare: func(w *WireGuard) {
    89  				m := &mockClient{}
    90  				m.devices = append(m.devices, prepareDevice(1))
    91  				w.client = m
    92  			},
    93  		},
    94  		"fail when no devices and no peers found": {
    95  			wantFail: true,
    96  			prepare: func(w *WireGuard) {
    97  				w.client = &mockClient{}
    98  			},
    99  		},
   100  		"fail when error on retrieving devices": {
   101  			wantFail: true,
   102  			prepare: func(w *WireGuard) {
   103  				w.client = &mockClient{errOnDevices: true}
   104  			},
   105  		},
   106  		"fail when error on creating client": {
   107  			wantFail: true,
   108  			prepare: func(w *WireGuard) {
   109  				w.newWGClient = func() (wgClient, error) { return nil, errors.New("mock.newWGClient() error") }
   110  			},
   111  		},
   112  	}
   113  
   114  	for name, test := range tests {
   115  		t.Run(name, func(t *testing.T) {
   116  			w := New()
   117  			require.True(t, w.Init())
   118  			test.prepare(w)
   119  
   120  			if test.wantFail {
   121  				assert.False(t, w.Check())
   122  			} else {
   123  				assert.True(t, w.Check())
   124  			}
   125  		})
   126  	}
   127  }
   128  
   129  func TestWireGuard_Collect(t *testing.T) {
   130  	type testCaseStep struct {
   131  		prepareMock func(m *mockClient)
   132  		check       func(t *testing.T, w *WireGuard)
   133  	}
   134  	tests := map[string][]testCaseStep{
   135  		"several devices no peers": {
   136  			{
   137  				prepareMock: func(m *mockClient) {
   138  					m.devices = append(m.devices, prepareDevice(1))
   139  					m.devices = append(m.devices, prepareDevice(2))
   140  				},
   141  				check: func(t *testing.T, w *WireGuard) {
   142  					mx := w.Collect()
   143  
   144  					expected := map[string]int64{
   145  						"device_wg1_peers":    0,
   146  						"device_wg1_receive":  0,
   147  						"device_wg1_transmit": 0,
   148  						"device_wg2_peers":    0,
   149  						"device_wg2_receive":  0,
   150  						"device_wg2_transmit": 0,
   151  					}
   152  
   153  					copyLatestHandshake(mx, expected)
   154  					assert.Equal(t, expected, mx)
   155  					assert.Equal(t, len(deviceChartsTmpl)*2, len(*w.Charts()))
   156  				},
   157  			},
   158  		},
   159  		"several devices several peers each": {
   160  			{
   161  				prepareMock: func(m *mockClient) {
   162  					d1 := prepareDevice(1)
   163  					d1.Peers = append(d1.Peers, preparePeer("11"))
   164  					d1.Peers = append(d1.Peers, preparePeer("12"))
   165  					m.devices = append(m.devices, d1)
   166  
   167  					d2 := prepareDevice(2)
   168  					d2.Peers = append(d2.Peers, preparePeer("21"))
   169  					d2.Peers = append(d2.Peers, preparePeer("22"))
   170  					m.devices = append(m.devices, d2)
   171  				},
   172  				check: func(t *testing.T, w *WireGuard) {
   173  					mx := w.Collect()
   174  
   175  					expected := map[string]int64{
   176  						"device_wg1_peers":    2,
   177  						"device_wg1_receive":  0,
   178  						"device_wg1_transmit": 0,
   179  						"device_wg2_peers":    2,
   180  						"device_wg2_receive":  0,
   181  						"device_wg2_transmit": 0,
   182  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   183  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   184  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   185  						"peer_wg1_cGVlcjEyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   186  						"peer_wg1_cGVlcjEyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   187  						"peer_wg1_cGVlcjEyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   188  						"peer_wg2_cGVlcjIxAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   189  						"peer_wg2_cGVlcjIxAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   190  						"peer_wg2_cGVlcjIxAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   191  						"peer_wg2_cGVlcjIyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   192  						"peer_wg2_cGVlcjIyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   193  						"peer_wg2_cGVlcjIyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   194  					}
   195  
   196  					copyLatestHandshake(mx, expected)
   197  					assert.Equal(t, expected, mx)
   198  					assert.Equal(t, len(deviceChartsTmpl)*2+len(peerChartsTmpl)*4, len(*w.Charts()))
   199  				},
   200  			},
   201  		},
   202  		"peers without last handshake time": {
   203  			{
   204  				prepareMock: func(m *mockClient) {
   205  					d1 := prepareDevice(1)
   206  					d1.Peers = append(d1.Peers, preparePeer("11"))
   207  					d1.Peers = append(d1.Peers, preparePeer("12"))
   208  					d1.Peers = append(d1.Peers, prepareNoLastHandshakePeer("13"))
   209  					d1.Peers = append(d1.Peers, prepareNoLastHandshakePeer("14"))
   210  					m.devices = append(m.devices, d1)
   211  				},
   212  				check: func(t *testing.T, w *WireGuard) {
   213  					mx := w.Collect()
   214  
   215  					expected := map[string]int64{
   216  						"device_wg1_peers":    4,
   217  						"device_wg1_receive":  0,
   218  						"device_wg1_transmit": 0,
   219  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   220  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   221  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   222  						"peer_wg1_cGVlcjEyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   223  						"peer_wg1_cGVlcjEyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   224  						"peer_wg1_cGVlcjEyAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   225  					}
   226  
   227  					copyLatestHandshake(mx, expected)
   228  					assert.Equal(t, expected, mx)
   229  					assert.Equal(t, len(deviceChartsTmpl)+len(peerChartsTmpl)*2, len(*w.Charts()))
   230  				},
   231  			},
   232  		},
   233  		"device added at runtime": {
   234  			{
   235  				prepareMock: func(m *mockClient) {
   236  					m.devices = append(m.devices, prepareDevice(1))
   237  				},
   238  				check: func(t *testing.T, w *WireGuard) {
   239  					_ = w.Collect()
   240  					assert.Equal(t, len(deviceChartsTmpl)*1, len(*w.Charts()))
   241  				},
   242  			},
   243  			{
   244  				prepareMock: func(m *mockClient) {
   245  					m.devices = append(m.devices, prepareDevice(2))
   246  				},
   247  				check: func(t *testing.T, w *WireGuard) {
   248  					mx := w.Collect()
   249  
   250  					expected := map[string]int64{
   251  						"device_wg1_peers":    0,
   252  						"device_wg1_receive":  0,
   253  						"device_wg1_transmit": 0,
   254  						"device_wg2_peers":    0,
   255  						"device_wg2_receive":  0,
   256  						"device_wg2_transmit": 0,
   257  					}
   258  					copyLatestHandshake(mx, expected)
   259  					assert.Equal(t, expected, mx)
   260  					assert.Equal(t, len(deviceChartsTmpl)*2, len(*w.Charts()))
   261  
   262  				},
   263  			},
   264  		},
   265  		"device removed at run time, no cleanup occurred": {
   266  			{
   267  				prepareMock: func(m *mockClient) {
   268  					m.devices = append(m.devices, prepareDevice(1))
   269  					m.devices = append(m.devices, prepareDevice(2))
   270  				},
   271  				check: func(t *testing.T, w *WireGuard) {
   272  					_ = w.Collect()
   273  				},
   274  			},
   275  			{
   276  				prepareMock: func(m *mockClient) {
   277  					m.devices = m.devices[:len(m.devices)-1]
   278  				},
   279  				check: func(t *testing.T, w *WireGuard) {
   280  					_ = w.Collect()
   281  					assert.Equal(t, len(deviceChartsTmpl)*2, len(*w.Charts()))
   282  					assert.Equal(t, 0, calcObsoleteCharts(w.Charts()))
   283  				},
   284  			},
   285  		},
   286  		"device removed at run time, cleanup occurred": {
   287  			{
   288  				prepareMock: func(m *mockClient) {
   289  					m.devices = append(m.devices, prepareDevice(1))
   290  					m.devices = append(m.devices, prepareDevice(2))
   291  				},
   292  				check: func(t *testing.T, w *WireGuard) {
   293  					_ = w.Collect()
   294  				},
   295  			},
   296  			{
   297  				prepareMock: func(m *mockClient) {
   298  					m.devices = m.devices[:len(m.devices)-1]
   299  				},
   300  				check: func(t *testing.T, w *WireGuard) {
   301  					w.cleanupEvery = time.Second
   302  					time.Sleep(time.Second)
   303  					_ = w.Collect()
   304  					assert.Equal(t, len(deviceChartsTmpl)*2, len(*w.Charts()))
   305  					assert.Equal(t, len(deviceChartsTmpl)*1, calcObsoleteCharts(w.Charts()))
   306  				},
   307  			},
   308  		},
   309  		"peer added at runtime": {
   310  			{
   311  				prepareMock: func(m *mockClient) {
   312  					m.devices = append(m.devices, prepareDevice(1))
   313  				},
   314  				check: func(t *testing.T, w *WireGuard) {
   315  					_ = w.Collect()
   316  					assert.Equal(t, len(deviceChartsTmpl)*1, len(*w.Charts()))
   317  				},
   318  			},
   319  			{
   320  				prepareMock: func(m *mockClient) {
   321  					d1 := m.devices[0]
   322  					d1.Peers = append(d1.Peers, preparePeer("11"))
   323  				},
   324  				check: func(t *testing.T, w *WireGuard) {
   325  					mx := w.Collect()
   326  
   327  					expected := map[string]int64{
   328  						"device_wg1_peers":    1,
   329  						"device_wg1_receive":  0,
   330  						"device_wg1_transmit": 0,
   331  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_latest_handshake_ago": 60,
   332  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_receive":              0,
   333  						"peer_wg1_cGVlcjExAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=_transmit":             0,
   334  					}
   335  					copyLatestHandshake(mx, expected)
   336  					assert.Equal(t, expected, mx)
   337  					assert.Equal(t, len(deviceChartsTmpl)*1+len(peerChartsTmpl)*1, len(*w.Charts()))
   338  				},
   339  			},
   340  		},
   341  		"peer removed at run time, no cleanup occurred": {
   342  			{
   343  				prepareMock: func(m *mockClient) {
   344  					d1 := prepareDevice(1)
   345  					d1.Peers = append(d1.Peers, preparePeer("11"))
   346  					d1.Peers = append(d1.Peers, preparePeer("12"))
   347  					m.devices = append(m.devices, d1)
   348  				},
   349  				check: func(t *testing.T, w *WireGuard) {
   350  					_ = w.Collect()
   351  				},
   352  			},
   353  			{
   354  				prepareMock: func(m *mockClient) {
   355  					d1 := m.devices[0]
   356  					d1.Peers = d1.Peers[:len(d1.Peers)-1]
   357  				},
   358  				check: func(t *testing.T, w *WireGuard) {
   359  					_ = w.Collect()
   360  					assert.Equal(t, len(deviceChartsTmpl)*1+len(peerChartsTmpl)*2, len(*w.Charts()))
   361  					assert.Equal(t, 0, calcObsoleteCharts(w.Charts()))
   362  				},
   363  			},
   364  		},
   365  		"peer removed at run time, cleanup occurred": {
   366  			{
   367  				prepareMock: func(m *mockClient) {
   368  					d1 := prepareDevice(1)
   369  					d1.Peers = append(d1.Peers, preparePeer("11"))
   370  					d1.Peers = append(d1.Peers, preparePeer("12"))
   371  					m.devices = append(m.devices, d1)
   372  				},
   373  				check: func(t *testing.T, w *WireGuard) {
   374  					_ = w.Collect()
   375  				},
   376  			},
   377  			{
   378  				prepareMock: func(m *mockClient) {
   379  					d1 := m.devices[0]
   380  					d1.Peers = d1.Peers[:len(d1.Peers)-1]
   381  				},
   382  				check: func(t *testing.T, w *WireGuard) {
   383  					w.cleanupEvery = time.Second
   384  					time.Sleep(time.Second)
   385  					_ = w.Collect()
   386  					assert.Equal(t, len(deviceChartsTmpl)*1+len(peerChartsTmpl)*2, len(*w.Charts()))
   387  					assert.Equal(t, len(peerChartsTmpl)*1, calcObsoleteCharts(w.Charts()))
   388  				},
   389  			},
   390  		},
   391  		"fails if no devices found": {
   392  			{
   393  				prepareMock: func(m *mockClient) {},
   394  				check: func(t *testing.T, w *WireGuard) {
   395  					assert.Equal(t, map[string]int64(nil), w.Collect())
   396  				},
   397  			},
   398  		},
   399  		"fails if error on getting devices list": {
   400  			{
   401  				prepareMock: func(m *mockClient) {
   402  					m.errOnDevices = true
   403  				},
   404  				check: func(t *testing.T, w *WireGuard) {
   405  					assert.Equal(t, map[string]int64(nil), w.Collect())
   406  				},
   407  			},
   408  		},
   409  	}
   410  
   411  	for name, test := range tests {
   412  		t.Run(name, func(t *testing.T) {
   413  			w := New()
   414  			require.True(t, w.Init())
   415  			m := &mockClient{}
   416  			w.client = m
   417  
   418  			for i, step := range test {
   419  				t.Run(fmt.Sprintf("step[%d]", i), func(t *testing.T) {
   420  					step.prepareMock(m)
   421  					step.check(t, w)
   422  				})
   423  			}
   424  		})
   425  	}
   426  }
   427  
   428  type mockClient struct {
   429  	devices      []*wgtypes.Device
   430  	errOnDevices bool
   431  	closeCalled  bool
   432  }
   433  
   434  func (m *mockClient) Devices() ([]*wgtypes.Device, error) {
   435  	if m.errOnDevices {
   436  		return nil, errors.New("mock.Devices() error")
   437  	}
   438  	return m.devices, nil
   439  }
   440  
   441  func (m *mockClient) Close() error {
   442  	m.closeCalled = true
   443  	return nil
   444  }
   445  
   446  func prepareDevice(num uint8) *wgtypes.Device {
   447  	return &wgtypes.Device{
   448  		Name: fmt.Sprintf("wg%d", num),
   449  	}
   450  }
   451  
   452  func preparePeer(s string) wgtypes.Peer {
   453  	b := make([]byte, 32)
   454  	b = append(b[:0], fmt.Sprintf("peer%s", s)...)
   455  	k, _ := wgtypes.NewKey(b[:32])
   456  
   457  	return wgtypes.Peer{
   458  		PublicKey:         k,
   459  		LastHandshakeTime: time.Now().Add(-time.Minute),
   460  		ReceiveBytes:      0,
   461  		TransmitBytes:     0,
   462  	}
   463  }
   464  
   465  func prepareNoLastHandshakePeer(s string) wgtypes.Peer {
   466  	p := preparePeer(s)
   467  	var lh time.Time
   468  	p.LastHandshakeTime = lh
   469  	return p
   470  }
   471  
   472  func copyLatestHandshake(dst, src map[string]int64) {
   473  	for k, v := range src {
   474  		if strings.HasSuffix(k, "latest_handshake_ago") {
   475  			if _, ok := dst[k]; ok {
   476  				dst[k] = v
   477  			}
   478  		}
   479  	}
   480  }
   481  
   482  func calcObsoleteCharts(charts *module.Charts) int {
   483  	var num int
   484  	for _, c := range *charts {
   485  		if c.Obsolete {
   486  			num++
   487  		}
   488  	}
   489  	return num
   490  }