github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/noise_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"testing"
    12  
    13  	"github.com/amnezia-vpn/amneziawg-go/conn"
    14  	"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
    15  )
    16  
    17  func TestCurveWrappers(t *testing.T) {
    18  	sk1, err := newPrivateKey()
    19  	assertNil(t, err)
    20  
    21  	sk2, err := newPrivateKey()
    22  	assertNil(t, err)
    23  
    24  	pk1 := sk1.publicKey()
    25  	pk2 := sk2.publicKey()
    26  
    27  	ss1, err1 := sk1.sharedSecret(pk2)
    28  	ss2, err2 := sk2.sharedSecret(pk1)
    29  
    30  	if ss1 != ss2 || err1 != nil || err2 != nil {
    31  		t.Fatal("Failed to compute shared secet")
    32  	}
    33  }
    34  
    35  func randDevice(t *testing.T) *Device {
    36  	sk, err := newPrivateKey()
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  	tun := tuntest.NewChannelTUN()
    41  	logger := NewLogger(LogLevelError, "")
    42  	device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
    43  	device.SetPrivateKey(sk)
    44  	return device
    45  }
    46  
    47  func assertNil(t *testing.T, err error) {
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  }
    52  
    53  func assertEqual(t *testing.T, a, b []byte) {
    54  	if !bytes.Equal(a, b) {
    55  		t.Fatal(a, "!=", b)
    56  	}
    57  }
    58  
    59  func TestNoiseHandshake(t *testing.T) {
    60  	dev1 := randDevice(t)
    61  	dev2 := randDevice(t)
    62  
    63  	defer dev1.Close()
    64  	defer dev2.Close()
    65  
    66  	peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	peer1.Start()
    75  	peer2.Start()
    76  
    77  	assertEqual(
    78  		t,
    79  		peer1.handshake.precomputedStaticStatic[:],
    80  		peer2.handshake.precomputedStaticStatic[:],
    81  	)
    82  
    83  	/* simulate handshake */
    84  
    85  	// initiation message
    86  
    87  	t.Log("exchange initiation message")
    88  
    89  	msg1, err := dev1.CreateMessageInitiation(peer2)
    90  	assertNil(t, err)
    91  
    92  	packet := make([]byte, 0, 256)
    93  	writer := bytes.NewBuffer(packet)
    94  	err = binary.Write(writer, binary.LittleEndian, msg1)
    95  	assertNil(t, err)
    96  	peer := dev2.ConsumeMessageInitiation(msg1)
    97  	if peer == nil {
    98  		t.Fatal("handshake failed at initiation message")
    99  	}
   100  
   101  	assertEqual(
   102  		t,
   103  		peer1.handshake.chainKey[:],
   104  		peer2.handshake.chainKey[:],
   105  	)
   106  
   107  	assertEqual(
   108  		t,
   109  		peer1.handshake.hash[:],
   110  		peer2.handshake.hash[:],
   111  	)
   112  
   113  	// response message
   114  
   115  	t.Log("exchange response message")
   116  
   117  	msg2, err := dev2.CreateMessageResponse(peer1)
   118  	assertNil(t, err)
   119  
   120  	peer = dev1.ConsumeMessageResponse(msg2)
   121  	if peer == nil {
   122  		t.Fatal("handshake failed at response message")
   123  	}
   124  
   125  	assertEqual(
   126  		t,
   127  		peer1.handshake.chainKey[:],
   128  		peer2.handshake.chainKey[:],
   129  	)
   130  
   131  	assertEqual(
   132  		t,
   133  		peer1.handshake.hash[:],
   134  		peer2.handshake.hash[:],
   135  	)
   136  
   137  	// key pairs
   138  
   139  	t.Log("deriving keys")
   140  
   141  	err = peer1.BeginSymmetricSession()
   142  	if err != nil {
   143  		t.Fatal("failed to derive keypair for peer 1", err)
   144  	}
   145  
   146  	err = peer2.BeginSymmetricSession()
   147  	if err != nil {
   148  		t.Fatal("failed to derive keypair for peer 2", err)
   149  	}
   150  
   151  	key1 := peer1.keypairs.next.Load()
   152  	key2 := peer2.keypairs.current
   153  
   154  	// encrypting / decryption test
   155  
   156  	t.Log("test key pairs")
   157  
   158  	func() {
   159  		testMsg := []byte("wireguard test message 1")
   160  		var err error
   161  		var out []byte
   162  		var nonce [12]byte
   163  		out = key1.send.Seal(out, nonce[:], testMsg, nil)
   164  		out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
   165  		assertNil(t, err)
   166  		assertEqual(t, out, testMsg)
   167  	}()
   168  
   169  	func() {
   170  		testMsg := []byte("wireguard test message 2")
   171  		var err error
   172  		var out []byte
   173  		var nonce [12]byte
   174  		out = key2.send.Seal(out, nonce[:], testMsg, nil)
   175  		out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
   176  		assertNil(t, err)
   177  		assertEqual(t, out, testMsg)
   178  	}()
   179  }