github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/device_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"io"
    13  	"math/rand"
    14  	"net/netip"
    15  	"os"
    16  	"runtime"
    17  	"runtime/pprof"
    18  	"sync"
    19  	"sync/atomic"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/amnezia-vpn/amneziawg-go/conn"
    24  	"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
    25  	"github.com/amnezia-vpn/amneziawg-go/tun"
    26  	"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
    27  )
    28  
    29  // uapiCfg returns a string that contains cfg formatted use with IpcSet.
    30  // cfg is a series of alternating key/value strings.
    31  // uapiCfg exists because editors and humans like to insert
    32  // whitespace into configs, which can cause failures, some of which are silent.
    33  // For example, a leading blank newline causes the remainder
    34  // of the config to be silently ignored.
    35  func uapiCfg(cfg ...string) string {
    36  	if len(cfg)%2 != 0 {
    37  		panic("odd number of args to uapiReader")
    38  	}
    39  	buf := new(bytes.Buffer)
    40  	for i, s := range cfg {
    41  		buf.WriteString(s)
    42  		sep := byte('\n')
    43  		if i%2 == 0 {
    44  			sep = '='
    45  		}
    46  		buf.WriteByte(sep)
    47  	}
    48  	return buf.String()
    49  }
    50  
    51  // genConfigs generates a pair of configs that connect to each other.
    52  // The configs use distinct, probably-usable ports.
    53  func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
    54  	var key1, key2 NoisePrivateKey
    55  	_, err := rand.Read(key1[:])
    56  	if err != nil {
    57  		tb.Errorf("unable to generate private key random bytes: %v", err)
    58  	}
    59  	_, err = rand.Read(key2[:])
    60  	if err != nil {
    61  		tb.Errorf("unable to generate private key random bytes: %v", err)
    62  	}
    63  	pub1, pub2 := key1.publicKey(), key2.publicKey()
    64  
    65  	cfgs[0] = uapiCfg(
    66  		"private_key", hex.EncodeToString(key1[:]),
    67  		"listen_port", "0",
    68  		"replace_peers", "true",
    69  		"public_key", hex.EncodeToString(pub2[:]),
    70  		"protocol_version", "1",
    71  		"replace_allowed_ips", "true",
    72  		"allowed_ip", "1.0.0.2/32",
    73  	)
    74  	endpointCfgs[0] = uapiCfg(
    75  		"public_key", hex.EncodeToString(pub2[:]),
    76  		"endpoint", "127.0.0.1:%d",
    77  	)
    78  	cfgs[1] = uapiCfg(
    79  		"private_key", hex.EncodeToString(key2[:]),
    80  		"listen_port", "0",
    81  		"replace_peers", "true",
    82  		"public_key", hex.EncodeToString(pub1[:]),
    83  		"protocol_version", "1",
    84  		"replace_allowed_ips", "true",
    85  		"allowed_ip", "1.0.0.1/32",
    86  	)
    87  	endpointCfgs[1] = uapiCfg(
    88  		"public_key", hex.EncodeToString(pub1[:]),
    89  		"endpoint", "127.0.0.1:%d",
    90  	)
    91  	return
    92  }
    93  
    94  func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
    95  	var key1, key2 NoisePrivateKey
    96  	_, err := rand.Read(key1[:])
    97  	if err != nil {
    98  		tb.Errorf("unable to generate private key random bytes: %v", err)
    99  	}
   100  	_, err = rand.Read(key2[:])
   101  	if err != nil {
   102  		tb.Errorf("unable to generate private key random bytes: %v", err)
   103  	}
   104  	pub1, pub2 := key1.publicKey(), key2.publicKey()
   105  
   106  	cfgs[0] = uapiCfg(
   107  		"private_key", hex.EncodeToString(key1[:]),
   108  		"listen_port", "0",
   109  		"replace_peers", "true",
   110  		"jc", "5",
   111  		"jmin", "500",
   112  		"jmax", "501",
   113  		"s1", "30",
   114  		"s2", "40",
   115  		"h1", "123456",
   116  		"h2", "67543",
   117  		"h4", "32345",
   118  		"h3", "123123",
   119  		"public_key", hex.EncodeToString(pub2[:]),
   120  		"protocol_version", "1",
   121  		"replace_allowed_ips", "true",
   122  		"allowed_ip", "1.0.0.2/32",
   123  	)
   124  	endpointCfgs[0] = uapiCfg(
   125  		"public_key", hex.EncodeToString(pub2[:]),
   126  		"endpoint", "127.0.0.1:%d",
   127  	)
   128  	cfgs[1] = uapiCfg(
   129  		"private_key", hex.EncodeToString(key2[:]),
   130  		"listen_port", "0",
   131  		"replace_peers", "true",
   132  		"jc", "5",
   133  		"jmin", "500",
   134  		"jmax", "501",
   135  		"s1", "30",
   136  		"s2", "40",
   137  		"h1", "123456",
   138  		"h2", "67543",
   139  		"h4", "32345",
   140  		"h3", "123123",
   141  		"public_key", hex.EncodeToString(pub1[:]),
   142  		"protocol_version", "1",
   143  		"replace_allowed_ips", "true",
   144  		"allowed_ip", "1.0.0.1/32",
   145  	)
   146  	endpointCfgs[1] = uapiCfg(
   147  		"public_key", hex.EncodeToString(pub1[:]),
   148  		"endpoint", "127.0.0.1:%d",
   149  	)
   150  	return
   151  }
   152  
   153  // A testPair is a pair of testPeers.
   154  type testPair [2]testPeer
   155  
   156  // A testPeer is a peer used for testing.
   157  type testPeer struct {
   158  	tun *tuntest.ChannelTUN
   159  	dev *Device
   160  	ip  netip.Addr
   161  }
   162  
   163  type SendDirection bool
   164  
   165  const (
   166  	Ping SendDirection = true
   167  	Pong SendDirection = false
   168  )
   169  
   170  func (d SendDirection) String() string {
   171  	if d == Ping {
   172  		return "ping"
   173  	}
   174  	return "pong"
   175  }
   176  
   177  func (pair *testPair) Send(
   178  	tb testing.TB,
   179  	ping SendDirection,
   180  	done chan struct{},
   181  ) {
   182  	tb.Helper()
   183  	p0, p1 := pair[0], pair[1]
   184  	if !ping {
   185  		// pong is the new ping
   186  		p0, p1 = p1, p0
   187  	}
   188  	msg := tuntest.Ping(p0.ip, p1.ip)
   189  	p1.tun.Outbound <- msg
   190  	timer := time.NewTimer(5 * time.Second)
   191  	defer timer.Stop()
   192  	var err error
   193  	select {
   194  	case msgRecv := <-p0.tun.Inbound:
   195  		if !bytes.Equal(msg, msgRecv) {
   196  			err = fmt.Errorf("%s did not transit correctly", ping)
   197  		}
   198  	case <-timer.C:
   199  		err = fmt.Errorf("%s did not transit", ping)
   200  	case <-done:
   201  	}
   202  	if err != nil {
   203  		// The error may have occurred because the test is done.
   204  		select {
   205  		case <-done:
   206  			return
   207  		default:
   208  		}
   209  		// Real error.
   210  		tb.Error(err)
   211  	}
   212  }
   213  
   214  // genTestPair creates a testPair.
   215  func genTestPair(
   216  	tb testing.TB,
   217  	realSocket, withASecurity bool,
   218  ) (pair testPair) {
   219  	var cfg, endpointCfg [2]string
   220  	if withASecurity {
   221  		cfg, endpointCfg = genASecurityConfigs(tb)
   222  	} else {
   223  		cfg, endpointCfg = genConfigs(tb)
   224  	}
   225  	var binds [2]conn.Bind
   226  	if realSocket {
   227  		binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
   228  	} else {
   229  		binds = bindtest.NewChannelBinds()
   230  	}
   231  	// Bring up a ChannelTun for each config.
   232  	for i := range pair {
   233  		p := &pair[i]
   234  		p.tun = tuntest.NewChannelTUN()
   235  		p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
   236  		level := LogLevelVerbose
   237  		if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
   238  			level = LogLevelError
   239  		}
   240  		p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
   241  		if err := p.dev.IpcSet(cfg[i]); err != nil {
   242  			tb.Errorf("failed to configure device %d: %v", i, err)
   243  			p.dev.Close()
   244  			continue
   245  		}
   246  		if err := p.dev.Up(); err != nil {
   247  			tb.Errorf("failed to bring up device %d: %v", i, err)
   248  			p.dev.Close()
   249  			continue
   250  		}
   251  		endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
   252  	}
   253  	for i := range pair {
   254  		p := &pair[i]
   255  		if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
   256  			tb.Errorf("failed to configure device endpoint %d: %v", i, err)
   257  			p.dev.Close()
   258  			continue
   259  		}
   260  		// The device is ready. Close it when the test completes.
   261  		tb.Cleanup(p.dev.Close)
   262  	}
   263  	return
   264  }
   265  
   266  func TestTwoDevicePing(t *testing.T) {
   267  	goroutineLeakCheck(t)
   268  	pair := genTestPair(t, true, false)
   269  	t.Run("ping 1.0.0.1", func(t *testing.T) {
   270  		pair.Send(t, Ping, nil)
   271  	})
   272  	t.Run("ping 1.0.0.2", func(t *testing.T) {
   273  		pair.Send(t, Pong, nil)
   274  	})
   275  }
   276  
   277  func TestTwoDevicePingASecurity(t *testing.T) {
   278  	goroutineLeakCheck(t)
   279  	pair := genTestPair(t, true, true)
   280  	t.Run("ping 1.0.0.1", func(t *testing.T) {
   281  		pair.Send(t, Ping, nil)
   282  	})
   283  	t.Run("ping 1.0.0.2", func(t *testing.T) {
   284  		pair.Send(t, Pong, nil)
   285  	})
   286  }
   287  
   288  func TestUpDown(t *testing.T) {
   289  	goroutineLeakCheck(t)
   290  	const itrials = 50
   291  	const otrials = 10
   292  
   293  	for n := 0; n < otrials; n++ {
   294  		pair := genTestPair(t, false, false)
   295  		for i := range pair {
   296  			for k := range pair[i].dev.peers.keyMap {
   297  				pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
   298  			}
   299  		}
   300  		var wg sync.WaitGroup
   301  		wg.Add(len(pair))
   302  		for i := range pair {
   303  			go func(d *Device) {
   304  				defer wg.Done()
   305  				for i := 0; i < itrials; i++ {
   306  					if err := d.Up(); err != nil {
   307  						t.Errorf("failed up bring up device: %v", err)
   308  					}
   309  					time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
   310  					if err := d.Down(); err != nil {
   311  						t.Errorf("failed to bring down device: %v", err)
   312  					}
   313  					time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
   314  				}
   315  			}(pair[i].dev)
   316  		}
   317  		wg.Wait()
   318  		for i := range pair {
   319  			pair[i].dev.Up()
   320  			pair[i].dev.Close()
   321  		}
   322  	}
   323  }
   324  
   325  // TestConcurrencySafety does other things concurrently with tunnel use.
   326  // It is intended to be used with the race detector to catch data races.
   327  func TestConcurrencySafety(t *testing.T) {
   328  	pair := genTestPair(t, true, false)
   329  	done := make(chan struct{})
   330  
   331  	const warmupIters = 10
   332  	var warmup sync.WaitGroup
   333  	warmup.Add(warmupIters)
   334  	go func() {
   335  		// Send data continuously back and forth until we're done.
   336  		// Note that we may continue to attempt to send data
   337  		// even after done is closed.
   338  		i := warmupIters
   339  		for ping := Ping; ; ping = !ping {
   340  			pair.Send(t, ping, done)
   341  			select {
   342  			case <-done:
   343  				return
   344  			default:
   345  			}
   346  			if i > 0 {
   347  				warmup.Done()
   348  				i--
   349  			}
   350  		}
   351  	}()
   352  	warmup.Wait()
   353  
   354  	applyCfg := func(cfg string) {
   355  		err := pair[0].dev.IpcSet(cfg)
   356  		if err != nil {
   357  			t.Fatal(err)
   358  		}
   359  	}
   360  
   361  	// Change persistent_keepalive_interval concurrently with tunnel use.
   362  	t.Run("persistentKeepaliveInterval", func(t *testing.T) {
   363  		var pub NoisePublicKey
   364  		for key := range pair[0].dev.peers.keyMap {
   365  			pub = key
   366  			break
   367  		}
   368  		cfg := uapiCfg(
   369  			"public_key", hex.EncodeToString(pub[:]),
   370  			"persistent_keepalive_interval", "1",
   371  		)
   372  		for i := 0; i < 1000; i++ {
   373  			applyCfg(cfg)
   374  		}
   375  	})
   376  
   377  	// Change private keys concurrently with tunnel use.
   378  	t.Run("privateKey", func(t *testing.T) {
   379  		bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
   380  		good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
   381  		// Set iters to a large number like 1000 to flush out data races quickly.
   382  		// Don't leave it large. That can cause logical races
   383  		// in which the handshake is interleaved with key changes
   384  		// such that the private key appears to be unchanging but
   385  		// other state gets reset, which can cause handshake failures like
   386  		// "Received packet with invalid mac1".
   387  		const iters = 1
   388  		for i := 0; i < iters; i++ {
   389  			applyCfg(bad)
   390  			applyCfg(good)
   391  		}
   392  	})
   393  
   394  	// Perform bind updates and keepalive sends concurrently with tunnel use.
   395  	t.Run("bindUpdate and keepalive", func(t *testing.T) {
   396  		const iters = 10
   397  		for i := 0; i < iters; i++ {
   398  			for _, peer := range pair {
   399  				peer.dev.BindUpdate()
   400  				peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
   401  			}
   402  		}
   403  	})
   404  
   405  	close(done)
   406  }
   407  
   408  func BenchmarkLatency(b *testing.B) {
   409  	pair := genTestPair(b, true, false)
   410  
   411  	// Establish a connection.
   412  	pair.Send(b, Ping, nil)
   413  	pair.Send(b, Pong, nil)
   414  
   415  	b.ResetTimer()
   416  	for i := 0; i < b.N; i++ {
   417  		pair.Send(b, Ping, nil)
   418  		pair.Send(b, Pong, nil)
   419  	}
   420  }
   421  
   422  func BenchmarkThroughput(b *testing.B) {
   423  	pair := genTestPair(b, true, false)
   424  
   425  	// Establish a connection.
   426  	pair.Send(b, Ping, nil)
   427  	pair.Send(b, Pong, nil)
   428  
   429  	// Measure how long it takes to receive b.N packets,
   430  	// starting when we receive the first packet.
   431  	var recv atomic.Uint64
   432  	var elapsed time.Duration
   433  	var wg sync.WaitGroup
   434  	wg.Add(1)
   435  	go func() {
   436  		defer wg.Done()
   437  		var start time.Time
   438  		for {
   439  			<-pair[0].tun.Inbound
   440  			new := recv.Add(1)
   441  			if new == 1 {
   442  				start = time.Now()
   443  			}
   444  			// Careful! Don't change this to else if; b.N can be equal to 1.
   445  			if new == uint64(b.N) {
   446  				elapsed = time.Since(start)
   447  				return
   448  			}
   449  		}
   450  	}()
   451  
   452  	// Send packets as fast as we can until we've received enough.
   453  	ping := tuntest.Ping(pair[0].ip, pair[1].ip)
   454  	pingc := pair[1].tun.Outbound
   455  	var sent uint64
   456  	for recv.Load() != uint64(b.N) {
   457  		sent++
   458  		pingc <- ping
   459  	}
   460  	wg.Wait()
   461  
   462  	b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
   463  	b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
   464  }
   465  
   466  func BenchmarkUAPIGet(b *testing.B) {
   467  	pair := genTestPair(b, true, false)
   468  	pair.Send(b, Ping, nil)
   469  	pair.Send(b, Pong, nil)
   470  	b.ReportAllocs()
   471  	b.ResetTimer()
   472  	for i := 0; i < b.N; i++ {
   473  		pair[0].dev.IpcGetOperation(io.Discard)
   474  	}
   475  }
   476  
   477  func goroutineLeakCheck(t *testing.T) {
   478  	goroutines := func() (int, []byte) {
   479  		p := pprof.Lookup("goroutine")
   480  		b := new(bytes.Buffer)
   481  		p.WriteTo(b, 1)
   482  		return p.Count(), b.Bytes()
   483  	}
   484  
   485  	startGoroutines, startStacks := goroutines()
   486  	t.Cleanup(func() {
   487  		if t.Failed() {
   488  			return
   489  		}
   490  		// Give goroutines time to exit, if they need it.
   491  		for i := 0; i < 10000; i++ {
   492  			if runtime.NumGoroutine() <= startGoroutines {
   493  				return
   494  			}
   495  			time.Sleep(1 * time.Millisecond)
   496  		}
   497  		endGoroutines, endStacks := goroutines()
   498  		t.Logf("starting stacks:\n%s\n", startStacks)
   499  		t.Logf("ending stacks:\n%s\n", endStacks)
   500  		t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
   501  	})
   502  }
   503  
   504  type fakeBindSized struct {
   505  	size int
   506  }
   507  
   508  func (b *fakeBindSized) Open(
   509  	port uint16,
   510  ) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
   511  	return nil, 0, nil
   512  }
   513  
   514  func (b *fakeBindSized) Close() error { return nil }
   515  
   516  func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
   517  
   518  func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
   519  
   520  func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
   521  
   522  func (b *fakeBindSized) BatchSize() int { return b.size }
   523  
   524  type fakeTUNDeviceSized struct {
   525  	size int
   526  }
   527  
   528  func (t *fakeTUNDeviceSized) File() *os.File { return nil }
   529  
   530  func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
   531  	return 0, nil
   532  }
   533  
   534  func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
   535  
   536  func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
   537  
   538  func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
   539  
   540  func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
   541  
   542  func (t *fakeTUNDeviceSized) Close() error { return nil }
   543  
   544  func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
   545  
   546  func TestBatchSize(t *testing.T) {
   547  	d := Device{}
   548  
   549  	d.net.bind = &fakeBindSized{1}
   550  	d.tun.device = &fakeTUNDeviceSized{1}
   551  	if want, got := 1, d.BatchSize(); got != want {
   552  		t.Errorf("expected batch size %d, got %d", want, got)
   553  	}
   554  
   555  	d.net.bind = &fakeBindSized{1}
   556  	d.tun.device = &fakeTUNDeviceSized{128}
   557  	if want, got := 128, d.BatchSize(); got != want {
   558  		t.Errorf("expected batch size %d, got %d", want, got)
   559  	}
   560  
   561  	d.net.bind = &fakeBindSized{128}
   562  	d.tun.device = &fakeTUNDeviceSized{1}
   563  	if want, got := 128, d.BatchSize(); got != want {
   564  		t.Errorf("expected batch size %d, got %d", want, got)
   565  	}
   566  
   567  	d.net.bind = &fakeBindSized{128}
   568  	d.tun.device = &fakeTUNDeviceSized{128}
   569  	if want, got := 128, d.BatchSize(); got != want {
   570  		t.Errorf("expected batch size %d, got %d", want, got)
   571  	}
   572  }