github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/cookie.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"crypto/hmac"
    10  	"crypto/rand"
    11  	"sync"
    12  	"time"
    13  
    14  	"golang.org/x/crypto/blake2s"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  )
    17  
// CookieChecker holds the receiving-side state used to validate the mac1 and
// mac2 fields of incoming messages (CheckMAC1, CheckMAC2) and to build cookie
// reply messages (CreateReply). The embedded RWMutex guards every field; all
// methods in this file lock it before touching the state.
type CookieChecker struct {
	sync.RWMutex
	mac1 struct {
		// BLAKE2s-256(WGLabelMAC1 || pk), where pk is the key passed to Init;
		// used as the key of the BLAKE2s-128 mac1 computation.
		key [blake2s.Size]byte
	}
	mac2 struct {
		// Random secret from which per-source cookies are derived; CreateReply
		// regenerates it once it is older than CookieRefreshTime.
		secret        [blake2s.Size]byte
		// When secret was last generated. Init resets it to the zero time,
		// which both CheckMAC2 and CreateReply treat as expired.
		secretSet     time.Time
		// BLAKE2s-256(WGLabelCookie || pk), set by Init; XChaCha20-Poly1305 key
		// used by CreateReply to encrypt the cookie.
		encryptionKey [chacha20poly1305.KeySize]byte
	}
}
    29  
// CookieGenerator holds the sending-side state used to fill in the mac1 and
// mac2 fields of outgoing messages (AddMacs) and to accept cookie replies
// (ConsumeReply). The embedded RWMutex guards every field; all methods in this
// file lock it before touching the state.
type CookieGenerator struct {
	sync.RWMutex
	mac1 struct {
		// BLAKE2s-256(WGLabelMAC1 || pk), where pk is the key passed to Init;
		// used as the key of the BLAKE2s-128 mac1 computation.
		key [blake2s.Size]byte
	}
	mac2 struct {
		// Cookie most recently decrypted by ConsumeReply; AddMacs uses it as
		// the key of the BLAKE2s-128 mac2 computation.
		cookie        [blake2s.Size128]byte
		// When cookie was stored. Init resets it to the zero time, which
		// AddMacs treats as expired (no mac2 is computed).
		cookieSet     time.Time
		// Whether lastMAC1 has been populated by AddMacs yet.
		hasLastMAC1   bool
		// mac1 of the most recent message passed to AddMacs; ConsumeReply uses
		// it as the AEAD additional data when decrypting a cookie reply.
		lastMAC1      [blake2s.Size128]byte
		// BLAKE2s-256(WGLabelCookie || pk), set by Init; XChaCha20-Poly1305 key
		// used by ConsumeReply to decrypt cookies.
		encryptionKey [chacha20poly1305.KeySize]byte
	}
}
    43  
    44  func (st *CookieChecker) Init(pk NoisePublicKey) {
    45  	st.Lock()
    46  	defer st.Unlock()
    47  
    48  	// mac1 state
    49  
    50  	func() {
    51  		hash, _ := blake2s.New256(nil)
    52  		hash.Write([]byte(WGLabelMAC1))
    53  		hash.Write(pk[:])
    54  		hash.Sum(st.mac1.key[:0])
    55  	}()
    56  
    57  	// mac2 state
    58  
    59  	func() {
    60  		hash, _ := blake2s.New256(nil)
    61  		hash.Write([]byte(WGLabelCookie))
    62  		hash.Write(pk[:])
    63  		hash.Sum(st.mac2.encryptionKey[:0])
    64  	}()
    65  
    66  	st.mac2.secretSet = time.Time{}
    67  }
    68  
    69  func (st *CookieChecker) CheckMAC1(msg []byte) bool {
    70  	st.RLock()
    71  	defer st.RUnlock()
    72  
    73  	size := len(msg)
    74  	smac2 := size - blake2s.Size128
    75  	smac1 := smac2 - blake2s.Size128
    76  
    77  	var mac1 [blake2s.Size128]byte
    78  
    79  	mac, _ := blake2s.New128(st.mac1.key[:])
    80  	mac.Write(msg[:smac1])
    81  	mac.Sum(mac1[:0])
    82  
    83  	return hmac.Equal(mac1[:], msg[smac1:smac2])
    84  }
    85  
    86  func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
    87  	st.RLock()
    88  	defer st.RUnlock()
    89  
    90  	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
    91  		return false
    92  	}
    93  
    94  	// derive cookie key
    95  
    96  	var cookie [blake2s.Size128]byte
    97  	func() {
    98  		mac, _ := blake2s.New128(st.mac2.secret[:])
    99  		mac.Write(src)
   100  		mac.Sum(cookie[:0])
   101  	}()
   102  
   103  	// calculate mac of packet (including mac1)
   104  
   105  	smac2 := len(msg) - blake2s.Size128
   106  
   107  	var mac2 [blake2s.Size128]byte
   108  	func() {
   109  		mac, _ := blake2s.New128(cookie[:])
   110  		mac.Write(msg[:smac2])
   111  		mac.Sum(mac2[:0])
   112  	}()
   113  
   114  	return hmac.Equal(mac2[:], msg[smac2:])
   115  }
   116  
   117  func (st *CookieChecker) CreateReply(
   118  	msg []byte,
   119  	recv uint32,
   120  	src []byte,
   121  ) (*MessageCookieReply, error) {
   122  
   123  	st.RLock()
   124  
   125  	// refresh cookie secret
   126  
   127  	if time.Since(st.mac2.secretSet) > CookieRefreshTime {
   128  		st.RUnlock()
   129  		st.Lock()
   130  		_, err := rand.Read(st.mac2.secret[:])
   131  		if err != nil {
   132  			st.Unlock()
   133  			return nil, err
   134  		}
   135  		st.mac2.secretSet = time.Now()
   136  		st.Unlock()
   137  		st.RLock()
   138  	}
   139  
   140  	// derive cookie
   141  
   142  	var cookie [blake2s.Size128]byte
   143  	func() {
   144  		mac, _ := blake2s.New128(st.mac2.secret[:])
   145  		mac.Write(src)
   146  		mac.Sum(cookie[:0])
   147  	}()
   148  
   149  	// encrypt cookie
   150  
   151  	size := len(msg)
   152  
   153  	smac2 := size - blake2s.Size128
   154  	smac1 := smac2 - blake2s.Size128
   155  
   156  	reply := new(MessageCookieReply)
   157  	reply.Type = MessageCookieReplyType
   158  	reply.Receiver = recv
   159  
   160  	_, err := rand.Read(reply.Nonce[:])
   161  	if err != nil {
   162  		st.RUnlock()
   163  		return nil, err
   164  	}
   165  
   166  	xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
   167  	xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
   168  
   169  	st.RUnlock()
   170  
   171  	return reply, nil
   172  }
   173  
   174  func (st *CookieGenerator) Init(pk NoisePublicKey) {
   175  	st.Lock()
   176  	defer st.Unlock()
   177  
   178  	func() {
   179  		hash, _ := blake2s.New256(nil)
   180  		hash.Write([]byte(WGLabelMAC1))
   181  		hash.Write(pk[:])
   182  		hash.Sum(st.mac1.key[:0])
   183  	}()
   184  
   185  	func() {
   186  		hash, _ := blake2s.New256(nil)
   187  		hash.Write([]byte(WGLabelCookie))
   188  		hash.Write(pk[:])
   189  		hash.Sum(st.mac2.encryptionKey[:0])
   190  	}()
   191  
   192  	st.mac2.cookieSet = time.Time{}
   193  }
   194  
   195  func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
   196  	st.Lock()
   197  	defer st.Unlock()
   198  
   199  	if !st.mac2.hasLastMAC1 {
   200  		return false
   201  	}
   202  
   203  	var cookie [blake2s.Size128]byte
   204  
   205  	xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
   206  	_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
   207  
   208  	if err != nil {
   209  		return false
   210  	}
   211  
   212  	st.mac2.cookieSet = time.Now()
   213  	st.mac2.cookie = cookie
   214  	return true
   215  }
   216  
   217  func (st *CookieGenerator) AddMacs(msg []byte) {
   218  
   219  	size := len(msg)
   220  
   221  	smac2 := size - blake2s.Size128
   222  	smac1 := smac2 - blake2s.Size128
   223  
   224  	mac1 := msg[smac1:smac2]
   225  	mac2 := msg[smac2:]
   226  
   227  	st.Lock()
   228  	defer st.Unlock()
   229  
   230  	// set mac1
   231  
   232  	func() {
   233  		mac, _ := blake2s.New128(st.mac1.key[:])
   234  		mac.Write(msg[:smac1])
   235  		mac.Sum(mac1[:0])
   236  	}()
   237  	copy(st.mac2.lastMAC1[:], mac1)
   238  	st.mac2.hasLastMAC1 = true
   239  
   240  	// set mac2
   241  
   242  	if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
   243  		return
   244  	}
   245  
   246  	func() {
   247  		mac, _ := blake2s.New128(st.mac2.cookie[:])
   248  		mac.Write(msg[:smac2])
   249  		mac.Sum(mac2[:0])
   250  	}()
   251  }