github.com/amnezia-vpn/amnezia-wg@v0.1.8/device/noise_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "testing" 12 13 "github.com/amnezia-vpn/amnezia-wg/conn" 14 "github.com/amnezia-vpn/amnezia-wg/tun/tuntest" 15 ) 16 17 func TestCurveWrappers(t *testing.T) { 18 sk1, err := newPrivateKey() 19 assertNil(t, err) 20 21 sk2, err := newPrivateKey() 22 assertNil(t, err) 23 24 pk1 := sk1.publicKey() 25 pk2 := sk2.publicKey() 26 27 ss1, err1 := sk1.sharedSecret(pk2) 28 ss2, err2 := sk2.sharedSecret(pk1) 29 30 if ss1 != ss2 || err1 != nil || err2 != nil { 31 t.Fatal("Failed to compute shared secet") 32 } 33 } 34 35 func randDevice(t *testing.T) *Device { 36 sk, err := newPrivateKey() 37 if err != nil { 38 t.Fatal(err) 39 } 40 tun := tuntest.NewChannelTUN() 41 logger := NewLogger(LogLevelError, "") 42 device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) 43 device.SetPrivateKey(sk) 44 return device 45 } 46 47 func assertNil(t *testing.T, err error) { 48 if err != nil { 49 t.Fatal(err) 50 } 51 } 52 53 func assertEqual(t *testing.T, a, b []byte) { 54 if !bytes.Equal(a, b) { 55 t.Fatal(a, "!=", b) 56 } 57 } 58 59 func TestNoiseHandshake(t *testing.T) { 60 dev1 := randDevice(t) 61 dev2 := randDevice(t) 62 63 defer dev1.Close() 64 defer dev2.Close() 65 66 peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) 67 if err != nil { 68 t.Fatal(err) 69 } 70 peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) 71 if err != nil { 72 t.Fatal(err) 73 } 74 peer1.Start() 75 peer2.Start() 76 77 assertEqual( 78 t, 79 peer1.handshake.precomputedStaticStatic[:], 80 peer2.handshake.precomputedStaticStatic[:], 81 ) 82 83 /* simulate handshake */ 84 85 // initiation message 86 87 t.Log("exchange initiation message") 88 89 msg1, err := dev1.CreateMessageInitiation(peer2) 90 assertNil(t, err) 91 92 packet := make([]byte, 0, 256) 93 writer := bytes.NewBuffer(packet) 94 err = binary.Write(writer, binary.LittleEndian, msg1) 95 assertNil(t, err) 96 peer := dev2.ConsumeMessageInitiation(msg1) 97 if peer == nil { 98 t.Fatal("handshake failed at initiation message") 99 } 100 101 assertEqual( 102 t, 103 peer1.handshake.chainKey[:], 104 peer2.handshake.chainKey[:], 105 ) 106 107 assertEqual( 108 t, 109 peer1.handshake.hash[:], 110 peer2.handshake.hash[:], 111 ) 112 113 // response message 114 115 t.Log("exchange response message") 116 117 msg2, err := dev2.CreateMessageResponse(peer1) 118 assertNil(t, err) 119 120 peer = dev1.ConsumeMessageResponse(msg2) 121 if peer == nil { 122 t.Fatal("handshake failed at response message") 123 } 124 125 assertEqual( 126 t, 127 peer1.handshake.chainKey[:], 128 peer2.handshake.chainKey[:], 129 ) 130 131 assertEqual( 132 t, 133 peer1.handshake.hash[:], 134 peer2.handshake.hash[:], 135 ) 136 137 // key pairs 138 139 t.Log("deriving keys") 140 141 err = peer1.BeginSymmetricSession() 142 if err != nil { 143 t.Fatal("failed to derive keypair for peer 1", err) 144 } 145 146 err = peer2.BeginSymmetricSession() 147 if err != nil { 148 t.Fatal("failed to derive keypair for peer 2", err) 149 } 150 151 key1 := peer1.keypairs.next.Load() 152 key2 := peer2.keypairs.current 153 154 // encrypting / decryption test 155 156 t.Log("test key pairs") 157 158 func() { 159 testMsg := []byte("wireguard test message 1") 160 var err error 161 var out []byte 162 var nonce [12]byte 163 out = key1.send.Seal(out, nonce[:], testMsg, nil) 164 out, err = key2.receive.Open(out[:0], nonce[:], out, nil) 165 assertNil(t, err) 166 assertEqual(t, out, testMsg) 167 }() 168 169 func() { 170 testMsg := []byte("wireguard test message 2") 171 var err error 172 var out []byte 173 var nonce [12]byte 174 out = key2.send.Seal(out, nonce[:], testMsg, nil) 175 out, err = key1.receive.Open(out[:0], nonce[:], out, nil) 176 assertNil(t, err) 177 assertEqual(t, out, testMsg) 178 }() 179 }