github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/noise_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"testing"
    12  
    13  	"github.com/tailscale/wireguard-go/conn"
    14  	"github.com/tailscale/wireguard-go/tun/tuntest"
    15  )
    16  
    17  func TestCurveWrappers(t *testing.T) {
    18  	sk1, err := newPrivateKey()
    19  	assertNil(t, err)
    20  
    21  	sk2, err := newPrivateKey()
    22  	assertNil(t, err)
    23  
    24  	pk1 := sk1.publicKey()
    25  	pk2 := sk2.publicKey()
    26  
    27  	ss1 := sk1.sharedSecret(pk2)
    28  	ss2 := sk2.sharedSecret(pk1)
    29  
    30  	if ss1 != ss2 {
    31  		t.Fatal("Failed to compute shared secet")
    32  	}
    33  }
    34  
    35  func randDevice(t *testing.T) *Device {
    36  	sk, err := newPrivateKey()
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  	tun := tuntest.NewChannelTUN()
    41  	logger := NewLogger(LogLevelError, "")
    42  	device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
    43  	device.SetPrivateKey(sk)
    44  	return device
    45  }
    46  
    47  func assertNil(t *testing.T, err error) {
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  }
    52  
    53  func assertEqual(t *testing.T, a, b []byte) {
    54  	if !bytes.Equal(a, b) {
    55  		t.Fatal(a, "!=", b)
    56  	}
    57  }
    58  
    59  func TestNoiseHandshake(t *testing.T) {
    60  	dev1 := randDevice(t)
    61  	dev2 := randDevice(t)
    62  
    63  	defer dev1.Close()
    64  	defer dev2.Close()
    65  
    66  	peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  
    75  	assertEqual(
    76  		t,
    77  		peer1.handshake.precomputedStaticStatic[:],
    78  		peer2.handshake.precomputedStaticStatic[:],
    79  	)
    80  
    81  	/* simulate handshake */
    82  
    83  	// initiation message
    84  
    85  	t.Log("exchange initiation message")
    86  
    87  	msg1, err := dev1.CreateMessageInitiation(peer2)
    88  	assertNil(t, err)
    89  
    90  	packet := make([]byte, 0, 256)
    91  	writer := bytes.NewBuffer(packet)
    92  	err = binary.Write(writer, binary.LittleEndian, msg1)
    93  	assertNil(t, err)
    94  	peer := dev2.ConsumeMessageInitiation(msg1)
    95  	if peer == nil {
    96  		t.Fatal("handshake failed at initiation message")
    97  	}
    98  
    99  	assertEqual(
   100  		t,
   101  		peer1.handshake.chainKey[:],
   102  		peer2.handshake.chainKey[:],
   103  	)
   104  
   105  	assertEqual(
   106  		t,
   107  		peer1.handshake.hash[:],
   108  		peer2.handshake.hash[:],
   109  	)
   110  
   111  	// response message
   112  
   113  	t.Log("exchange response message")
   114  
   115  	msg2, err := dev2.CreateMessageResponse(peer1)
   116  	assertNil(t, err)
   117  
   118  	peer = dev1.ConsumeMessageResponse(msg2)
   119  	if peer == nil {
   120  		t.Fatal("handshake failed at response message")
   121  	}
   122  
   123  	assertEqual(
   124  		t,
   125  		peer1.handshake.chainKey[:],
   126  		peer2.handshake.chainKey[:],
   127  	)
   128  
   129  	assertEqual(
   130  		t,
   131  		peer1.handshake.hash[:],
   132  		peer2.handshake.hash[:],
   133  	)
   134  
   135  	// key pairs
   136  
   137  	t.Log("deriving keys")
   138  
   139  	err = peer1.BeginSymmetricSession()
   140  	if err != nil {
   141  		t.Fatal("failed to derive keypair for peer 1", err)
   142  	}
   143  
   144  	err = peer2.BeginSymmetricSession()
   145  	if err != nil {
   146  		t.Fatal("failed to derive keypair for peer 2", err)
   147  	}
   148  
   149  	key1 := peer1.keypairs.loadNext()
   150  	key2 := peer2.keypairs.current
   151  
   152  	// encrypting / decryption test
   153  
   154  	t.Log("test key pairs")
   155  
   156  	func() {
   157  		testMsg := []byte("wireguard test message 1")
   158  		var err error
   159  		var out []byte
   160  		var nonce [12]byte
   161  		out = key1.send.Seal(out, nonce[:], testMsg, nil)
   162  		out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
   163  		assertNil(t, err)
   164  		assertEqual(t, out, testMsg)
   165  	}()
   166  
   167  	func() {
   168  		testMsg := []byte("wireguard test message 2")
   169  		var err error
   170  		var out []byte
   171  		var nonce [12]byte
   172  		out = key2.send.Seal(out, nonce[:], testMsg, nil)
   173  		out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
   174  		assertNil(t, err)
   175  		assertEqual(t, out, testMsg)
   176  	}()
   177  }