github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/device_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"io"
    13  	"math/rand"
    14  	"net/netip"
    15  	"os"
    16  	"runtime"
    17  	"runtime/pprof"
    18  	"sync"
    19  	"sync/atomic"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/koomox/wireguard-go/conn"
    24  	"github.com/koomox/wireguard-go/conn/bindtest"
    25  	"github.com/koomox/wireguard-go/tun"
    26  	"github.com/koomox/wireguard-go/tun/tuntest"
    27  )
    28  
    29  // uapiCfg returns a string that contains cfg formatted use with IpcSet.
    30  // cfg is a series of alternating key/value strings.
    31  // uapiCfg exists because editors and humans like to insert
    32  // whitespace into configs, which can cause failures, some of which are silent.
    33  // For example, a leading blank newline causes the remainder
    34  // of the config to be silently ignored.
    35  func uapiCfg(cfg ...string) string {
    36  	if len(cfg)%2 != 0 {
    37  		panic("odd number of args to uapiReader")
    38  	}
    39  	buf := new(bytes.Buffer)
    40  	for i, s := range cfg {
    41  		buf.WriteString(s)
    42  		sep := byte('\n')
    43  		if i%2 == 0 {
    44  			sep = '='
    45  		}
    46  		buf.WriteByte(sep)
    47  	}
    48  	return buf.String()
    49  }
    50  
    51  // genConfigs generates a pair of configs that connect to each other.
    52  // The configs use distinct, probably-usable ports.
    53  func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
    54  	var key1, key2 NoisePrivateKey
    55  	_, err := rand.Read(key1[:])
    56  	if err != nil {
    57  		tb.Errorf("unable to generate private key random bytes: %v", err)
    58  	}
    59  	_, err = rand.Read(key2[:])
    60  	if err != nil {
    61  		tb.Errorf("unable to generate private key random bytes: %v", err)
    62  	}
    63  	pub1, pub2 := key1.publicKey(), key2.publicKey()
    64  
    65  	cfgs[0] = uapiCfg(
    66  		"private_key", hex.EncodeToString(key1[:]),
    67  		"listen_port", "0",
    68  		"replace_peers", "true",
    69  		"public_key", hex.EncodeToString(pub2[:]),
    70  		"protocol_version", "1",
    71  		"replace_allowed_ips", "true",
    72  		"allowed_ip", "1.0.0.2/32",
    73  	)
    74  	endpointCfgs[0] = uapiCfg(
    75  		"public_key", hex.EncodeToString(pub2[:]),
    76  		"endpoint", "127.0.0.1:%d",
    77  	)
    78  	cfgs[1] = uapiCfg(
    79  		"private_key", hex.EncodeToString(key2[:]),
    80  		"listen_port", "0",
    81  		"replace_peers", "true",
    82  		"public_key", hex.EncodeToString(pub1[:]),
    83  		"protocol_version", "1",
    84  		"replace_allowed_ips", "true",
    85  		"allowed_ip", "1.0.0.1/32",
    86  	)
    87  	endpointCfgs[1] = uapiCfg(
    88  		"public_key", hex.EncodeToString(pub1[:]),
    89  		"endpoint", "127.0.0.1:%d",
    90  	)
    91  	return
    92  }
    93  
    94  // A testPair is a pair of testPeers.
    95  type testPair [2]testPeer
    96  
    97  // A testPeer is a peer used for testing.
    98  type testPeer struct {
    99  	tun *tuntest.ChannelTUN
   100  	dev *Device
   101  	ip  netip.Addr
   102  }
   103  
   104  type SendDirection bool
   105  
   106  const (
   107  	Ping SendDirection = true
   108  	Pong SendDirection = false
   109  )
   110  
   111  func (d SendDirection) String() string {
   112  	if d == Ping {
   113  		return "ping"
   114  	}
   115  	return "pong"
   116  }
   117  
   118  func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
   119  	tb.Helper()
   120  	p0, p1 := pair[0], pair[1]
   121  	if !ping {
   122  		// pong is the new ping
   123  		p0, p1 = p1, p0
   124  	}
   125  	msg := tuntest.Ping(p0.ip, p1.ip)
   126  	p1.tun.Outbound <- msg
   127  	timer := time.NewTimer(5 * time.Second)
   128  	defer timer.Stop()
   129  	var err error
   130  	select {
   131  	case msgRecv := <-p0.tun.Inbound:
   132  		if !bytes.Equal(msg, msgRecv) {
   133  			err = fmt.Errorf("%s did not transit correctly", ping)
   134  		}
   135  	case <-timer.C:
   136  		err = fmt.Errorf("%s did not transit", ping)
   137  	case <-done:
   138  	}
   139  	if err != nil {
   140  		// The error may have occurred because the test is done.
   141  		select {
   142  		case <-done:
   143  			return
   144  		default:
   145  		}
   146  		// Real error.
   147  		tb.Error(err)
   148  	}
   149  }
   150  
   151  // genTestPair creates a testPair.
   152  func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
   153  	cfg, endpointCfg := genConfigs(tb)
   154  	var binds [2]conn.Bind
   155  	if realSocket {
   156  		binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
   157  	} else {
   158  		binds = bindtest.NewChannelBinds()
   159  	}
   160  	// Bring up a ChannelTun for each config.
   161  	for i := range pair {
   162  		p := &pair[i]
   163  		p.tun = tuntest.NewChannelTUN()
   164  		p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
   165  		level := LogLevelVerbose
   166  		if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
   167  			level = LogLevelError
   168  		}
   169  		p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
   170  		if err := p.dev.IpcSet(cfg[i]); err != nil {
   171  			tb.Errorf("failed to configure device %d: %v", i, err)
   172  			p.dev.Close()
   173  			continue
   174  		}
   175  		if err := p.dev.Up(); err != nil {
   176  			tb.Errorf("failed to bring up device %d: %v", i, err)
   177  			p.dev.Close()
   178  			continue
   179  		}
   180  		endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
   181  	}
   182  	for i := range pair {
   183  		p := &pair[i]
   184  		if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
   185  			tb.Errorf("failed to configure device endpoint %d: %v", i, err)
   186  			p.dev.Close()
   187  			continue
   188  		}
   189  		// The device is ready. Close it when the test completes.
   190  		tb.Cleanup(p.dev.Close)
   191  	}
   192  	return
   193  }
   194  
   195  func TestTwoDevicePing(t *testing.T) {
   196  	goroutineLeakCheck(t)
   197  	pair := genTestPair(t, true)
   198  	t.Run("ping 1.0.0.1", func(t *testing.T) {
   199  		pair.Send(t, Ping, nil)
   200  	})
   201  	t.Run("ping 1.0.0.2", func(t *testing.T) {
   202  		pair.Send(t, Pong, nil)
   203  	})
   204  }
   205  
   206  func TestUpDown(t *testing.T) {
   207  	goroutineLeakCheck(t)
   208  	const itrials = 50
   209  	const otrials = 10
   210  
   211  	for n := 0; n < otrials; n++ {
   212  		pair := genTestPair(t, false)
   213  		for i := range pair {
   214  			for k := range pair[i].dev.peers.keyMap {
   215  				pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
   216  			}
   217  		}
   218  		var wg sync.WaitGroup
   219  		wg.Add(len(pair))
   220  		for i := range pair {
   221  			go func(d *Device) {
   222  				defer wg.Done()
   223  				for i := 0; i < itrials; i++ {
   224  					if err := d.Up(); err != nil {
   225  						t.Errorf("failed up bring up device: %v", err)
   226  					}
   227  					time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
   228  					if err := d.Down(); err != nil {
   229  						t.Errorf("failed to bring down device: %v", err)
   230  					}
   231  					time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
   232  				}
   233  			}(pair[i].dev)
   234  		}
   235  		wg.Wait()
   236  		for i := range pair {
   237  			pair[i].dev.Up()
   238  			pair[i].dev.Close()
   239  		}
   240  	}
   241  }
   242  
   243  // TestConcurrencySafety does other things concurrently with tunnel use.
   244  // It is intended to be used with the race detector to catch data races.
   245  func TestConcurrencySafety(t *testing.T) {
   246  	pair := genTestPair(t, true)
   247  	done := make(chan struct{})
   248  
   249  	const warmupIters = 10
   250  	var warmup sync.WaitGroup
   251  	warmup.Add(warmupIters)
   252  	go func() {
   253  		// Send data continuously back and forth until we're done.
   254  		// Note that we may continue to attempt to send data
   255  		// even after done is closed.
   256  		i := warmupIters
   257  		for ping := Ping; ; ping = !ping {
   258  			pair.Send(t, ping, done)
   259  			select {
   260  			case <-done:
   261  				return
   262  			default:
   263  			}
   264  			if i > 0 {
   265  				warmup.Done()
   266  				i--
   267  			}
   268  		}
   269  	}()
   270  	warmup.Wait()
   271  
   272  	applyCfg := func(cfg string) {
   273  		err := pair[0].dev.IpcSet(cfg)
   274  		if err != nil {
   275  			t.Fatal(err)
   276  		}
   277  	}
   278  
   279  	// Change persistent_keepalive_interval concurrently with tunnel use.
   280  	t.Run("persistentKeepaliveInterval", func(t *testing.T) {
   281  		var pub NoisePublicKey
   282  		for key := range pair[0].dev.peers.keyMap {
   283  			pub = key
   284  			break
   285  		}
   286  		cfg := uapiCfg(
   287  			"public_key", hex.EncodeToString(pub[:]),
   288  			"persistent_keepalive_interval", "1",
   289  		)
   290  		for i := 0; i < 1000; i++ {
   291  			applyCfg(cfg)
   292  		}
   293  	})
   294  
   295  	// Change private keys concurrently with tunnel use.
   296  	t.Run("privateKey", func(t *testing.T) {
   297  		bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
   298  		good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
   299  		// Set iters to a large number like 1000 to flush out data races quickly.
   300  		// Don't leave it large. That can cause logical races
   301  		// in which the handshake is interleaved with key changes
   302  		// such that the private key appears to be unchanging but
   303  		// other state gets reset, which can cause handshake failures like
   304  		// "Received packet with invalid mac1".
   305  		const iters = 1
   306  		for i := 0; i < iters; i++ {
   307  			applyCfg(bad)
   308  			applyCfg(good)
   309  		}
   310  	})
   311  
   312  	// Perform bind updates and keepalive sends concurrently with tunnel use.
   313  	t.Run("bindUpdate and keepalive", func(t *testing.T) {
   314  		const iters = 10
   315  		for i := 0; i < iters; i++ {
   316  			for _, peer := range pair {
   317  				peer.dev.BindUpdate()
   318  				peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
   319  			}
   320  		}
   321  	})
   322  
   323  	close(done)
   324  }
   325  
   326  func BenchmarkLatency(b *testing.B) {
   327  	pair := genTestPair(b, true)
   328  
   329  	// Establish a connection.
   330  	pair.Send(b, Ping, nil)
   331  	pair.Send(b, Pong, nil)
   332  
   333  	b.ResetTimer()
   334  	for i := 0; i < b.N; i++ {
   335  		pair.Send(b, Ping, nil)
   336  		pair.Send(b, Pong, nil)
   337  	}
   338  }
   339  
   340  func BenchmarkThroughput(b *testing.B) {
   341  	pair := genTestPair(b, true)
   342  
   343  	// Establish a connection.
   344  	pair.Send(b, Ping, nil)
   345  	pair.Send(b, Pong, nil)
   346  
   347  	// Measure how long it takes to receive b.N packets,
   348  	// starting when we receive the first packet.
   349  	var recv atomic.Uint64
   350  	var elapsed time.Duration
   351  	var wg sync.WaitGroup
   352  	wg.Add(1)
   353  	go func() {
   354  		defer wg.Done()
   355  		var start time.Time
   356  		for {
   357  			<-pair[0].tun.Inbound
   358  			new := recv.Add(1)
   359  			if new == 1 {
   360  				start = time.Now()
   361  			}
   362  			// Careful! Don't change this to else if; b.N can be equal to 1.
   363  			if new == uint64(b.N) {
   364  				elapsed = time.Since(start)
   365  				return
   366  			}
   367  		}
   368  	}()
   369  
   370  	// Send packets as fast as we can until we've received enough.
   371  	ping := tuntest.Ping(pair[0].ip, pair[1].ip)
   372  	pingc := pair[1].tun.Outbound
   373  	var sent uint64
   374  	for recv.Load() != uint64(b.N) {
   375  		sent++
   376  		pingc <- ping
   377  	}
   378  	wg.Wait()
   379  
   380  	b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
   381  	b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
   382  }
   383  
   384  func BenchmarkUAPIGet(b *testing.B) {
   385  	pair := genTestPair(b, true)
   386  	pair.Send(b, Ping, nil)
   387  	pair.Send(b, Pong, nil)
   388  	b.ReportAllocs()
   389  	b.ResetTimer()
   390  	for i := 0; i < b.N; i++ {
   391  		pair[0].dev.IpcGetOperation(io.Discard)
   392  	}
   393  }
   394  
   395  func goroutineLeakCheck(t *testing.T) {
   396  	goroutines := func() (int, []byte) {
   397  		p := pprof.Lookup("goroutine")
   398  		b := new(bytes.Buffer)
   399  		p.WriteTo(b, 1)
   400  		return p.Count(), b.Bytes()
   401  	}
   402  
   403  	startGoroutines, startStacks := goroutines()
   404  	t.Cleanup(func() {
   405  		if t.Failed() {
   406  			return
   407  		}
   408  		// Give goroutines time to exit, if they need it.
   409  		for i := 0; i < 10000; i++ {
   410  			if runtime.NumGoroutine() <= startGoroutines {
   411  				return
   412  			}
   413  			time.Sleep(1 * time.Millisecond)
   414  		}
   415  		endGoroutines, endStacks := goroutines()
   416  		t.Logf("starting stacks:\n%s\n", startStacks)
   417  		t.Logf("ending stacks:\n%s\n", endStacks)
   418  		t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
   419  	})
   420  }
   421  
   422  type fakeBindSized struct {
   423  	size int
   424  }
   425  
   426  func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
   427  	return nil, 0, nil
   428  }
   429  func (b *fakeBindSized) Close() error                                  { return nil }
   430  func (b *fakeBindSized) SetMark(mark uint32) error                     { return nil }
   431  func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error    { return nil }
   432  func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
   433  func (b *fakeBindSized) BatchSize() int                                { return b.size }
   434  
   435  type fakeTUNDeviceSized struct {
   436  	size int
   437  }
   438  
   439  func (t *fakeTUNDeviceSized) File() *os.File { return nil }
   440  func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
   441  	return 0, nil
   442  }
   443  func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
   444  func (t *fakeTUNDeviceSized) MTU() (int, error)                            { return 0, nil }
   445  func (t *fakeTUNDeviceSized) Name() (string, error)                        { return "", nil }
   446  func (t *fakeTUNDeviceSized) Events() <-chan tun.Event                     { return nil }
   447  func (t *fakeTUNDeviceSized) Close() error                                 { return nil }
   448  func (t *fakeTUNDeviceSized) BatchSize() int                               { return t.size }
   449  
   450  func TestBatchSize(t *testing.T) {
   451  	d := Device{}
   452  
   453  	d.net.bind = &fakeBindSized{1}
   454  	d.tun.device = &fakeTUNDeviceSized{1}
   455  	if want, got := 1, d.BatchSize(); got != want {
   456  		t.Errorf("expected batch size %d, got %d", want, got)
   457  	}
   458  
   459  	d.net.bind = &fakeBindSized{1}
   460  	d.tun.device = &fakeTUNDeviceSized{128}
   461  	if want, got := 128, d.BatchSize(); got != want {
   462  		t.Errorf("expected batch size %d, got %d", want, got)
   463  	}
   464  
   465  	d.net.bind = &fakeBindSized{128}
   466  	d.tun.device = &fakeTUNDeviceSized{1}
   467  	if want, got := 128, d.BatchSize(); got != want {
   468  		t.Errorf("expected batch size %d, got %d", want, got)
   469  	}
   470  
   471  	d.net.bind = &fakeBindSized{128}
   472  	d.tun.device = &fakeTUNDeviceSized{128}
   473  	if want, got := 128, d.BatchSize(); got != want {
   474  		t.Errorf("expected batch size %d, got %d", want, got)
   475  	}
   476  }