github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/cookie.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"crypto/hmac"
    10  	"crypto/rand"
    11  	"sync"
    12  	"time"
    13  
    14  	"golang.org/x/crypto/blake2s"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  )
    17  
    18  type CookieChecker struct {
    19  	sync.RWMutex
    20  	mac1 struct {
    21  		key [blake2s.Size]byte
    22  	}
    23  	mac2 struct {
    24  		secret        [blake2s.Size]byte
    25  		secretSet     time.Time
    26  		encryptionKey [chacha20poly1305.KeySize]byte
    27  	}
    28  }
    29  
    30  type CookieGenerator struct {
    31  	sync.RWMutex
    32  	mac1 struct {
    33  		key [blake2s.Size]byte
    34  	}
    35  	mac2 struct {
    36  		cookie        [blake2s.Size128]byte
    37  		cookieSet     time.Time
    38  		hasLastMAC1   bool
    39  		lastMAC1      [blake2s.Size128]byte
    40  		encryptionKey [chacha20poly1305.KeySize]byte
    41  	}
    42  }
    43  
    44  func (st *CookieChecker) Init(pk NoisePublicKey) {
    45  	st.Lock()
    46  	defer st.Unlock()
    47  
    48  	// mac1 state
    49  
    50  	func() {
    51  		hash, _ := blake2s.New256(nil)
    52  		hash.Write([]byte(WGLabelMAC1))
    53  		hash.Write(pk[:])
    54  		hash.Sum(st.mac1.key[:0])
    55  	}()
    56  
    57  	// mac2 state
    58  
    59  	func() {
    60  		hash, _ := blake2s.New256(nil)
    61  		hash.Write([]byte(WGLabelCookie))
    62  		hash.Write(pk[:])
    63  		hash.Sum(st.mac2.encryptionKey[:0])
    64  	}()
    65  
    66  	st.mac2.secretSet = time.Time{}
    67  }
    68  
    69  func (st *CookieChecker) CheckMAC1(msg []byte) bool {
    70  	st.RLock()
    71  	defer st.RUnlock()
    72  
    73  	size := len(msg)
    74  	smac2 := size - blake2s.Size128
    75  	smac1 := smac2 - blake2s.Size128
    76  
    77  	var mac1 [blake2s.Size128]byte
    78  
    79  	mac, _ := blake2s.New128(st.mac1.key[:])
    80  	mac.Write(msg[:smac1])
    81  	mac.Sum(mac1[:0])
    82  
    83  	return hmac.Equal(mac1[:], msg[smac1:smac2])
    84  }
    85  
    86  func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
    87  	st.RLock()
    88  	defer st.RUnlock()
    89  
    90  	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
    91  		return false
    92  	}
    93  
    94  	// derive cookie key
    95  
    96  	var cookie [blake2s.Size128]byte
    97  	func() {
    98  		mac, _ := blake2s.New128(st.mac2.secret[:])
    99  		mac.Write(src)
   100  		mac.Sum(cookie[:0])
   101  	}()
   102  
   103  	// calculate mac of packet (including mac1)
   104  
   105  	smac2 := len(msg) - blake2s.Size128
   106  
   107  	var mac2 [blake2s.Size128]byte
   108  	func() {
   109  		mac, _ := blake2s.New128(cookie[:])
   110  		mac.Write(msg[:smac2])
   111  		mac.Sum(mac2[:0])
   112  	}()
   113  
   114  	return hmac.Equal(mac2[:], msg[smac2:])
   115  }
   116  
   117  func (st *CookieChecker) CreateReply(
   118  	msg []byte,
   119  	recv uint32,
   120  	src []byte,
   121  ) (*MessageCookieReply, error) {
   122  	st.RLock()
   123  
   124  	// refresh cookie secret
   125  
   126  	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
   127  		st.RUnlock()
   128  		st.Lock()
   129  		_, err := rand.Read(st.mac2.secret[:])
   130  		if err != nil {
   131  			st.Unlock()
   132  			return nil, err
   133  		}
   134  		st.mac2.secretSet = time.Now()
   135  		st.Unlock()
   136  		st.RLock()
   137  	}
   138  
   139  	// derive cookie
   140  
   141  	var cookie [blake2s.Size128]byte
   142  	func() {
   143  		mac, _ := blake2s.New128(st.mac2.secret[:])
   144  		mac.Write(src)
   145  		mac.Sum(cookie[:0])
   146  	}()
   147  
   148  	// encrypt cookie
   149  
   150  	size := len(msg)
   151  
   152  	smac2 := size - blake2s.Size128
   153  	smac1 := smac2 - blake2s.Size128
   154  
   155  	reply := new(MessageCookieReply)
   156  	reply.Type = MessageCookieReplyType
   157  	reply.Receiver = recv
   158  
   159  	_, err := rand.Read(reply.Nonce[:])
   160  	if err != nil {
   161  		st.RUnlock()
   162  		return nil, err
   163  	}
   164  
   165  	xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
   166  	xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
   167  
   168  	st.RUnlock()
   169  
   170  	return reply, nil
   171  }
   172  
   173  func (st *CookieGenerator) Init(pk NoisePublicKey) {
   174  	st.Lock()
   175  	defer st.Unlock()
   176  
   177  	func() {
   178  		hash, _ := blake2s.New256(nil)
   179  		hash.Write([]byte(WGLabelMAC1))
   180  		hash.Write(pk[:])
   181  		hash.Sum(st.mac1.key[:0])
   182  	}()
   183  
   184  	func() {
   185  		hash, _ := blake2s.New256(nil)
   186  		hash.Write([]byte(WGLabelCookie))
   187  		hash.Write(pk[:])
   188  		hash.Sum(st.mac2.encryptionKey[:0])
   189  	}()
   190  
   191  	st.mac2.cookieSet = time.Time{}
   192  }
   193  
   194  func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
   195  	st.Lock()
   196  	defer st.Unlock()
   197  
   198  	if !st.mac2.hasLastMAC1 {
   199  		return false
   200  	}
   201  
   202  	var cookie [blake2s.Size128]byte
   203  
   204  	xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
   205  	_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
   206  	if err != nil {
   207  		return false
   208  	}
   209  
   210  	st.mac2.cookieSet = time.Now()
   211  	st.mac2.cookie = cookie
   212  	return true
   213  }
   214  
   215  func (st *CookieGenerator) AddMacs(msg []byte) {
   216  	size := len(msg)
   217  
   218  	smac2 := size - blake2s.Size128
   219  	smac1 := smac2 - blake2s.Size128
   220  
   221  	mac1 := msg[smac1:smac2]
   222  	mac2 := msg[smac2:]
   223  
   224  	st.Lock()
   225  	defer st.Unlock()
   226  
   227  	// set mac1
   228  
   229  	func() {
   230  		mac, _ := blake2s.New128(st.mac1.key[:])
   231  		mac.Write(msg[:smac1])
   232  		mac.Sum(mac1[:0])
   233  	}()
   234  	copy(st.mac2.lastMAC1[:], mac1)
   235  	st.mac2.hasLastMAC1 = true
   236  
   237  	// set mac2
   238  
   239  	if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
   240  		return
   241  	}
   242  
   243  	func() {
   244  		mac, _ := blake2s.New128(st.mac2.cookie[:])
   245  		mac.Write(msg[:smac2])
   246  		mac.Sum(mac2[:0])
   247  	}()
   248  }