github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/timers.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
    34  import (
    35  	"log/slog"
    36  	"sync"
    37  	"time"
    38  	_ "unsafe"
    39  )
    40  
    41  //go:linkname fastrandn runtime.fastrandn
    42  func fastrandn(n uint32) uint32
    43  
    44  // A Timer manages time-based aspects of the WireGuard protocol.
    45  // Timer roughly copies the interface of the Linux kernel's struct timer_list.
    46  type Timer struct {
    47  	*time.Timer
    48  	modifyingLock sync.RWMutex
    49  	runningLock   sync.Mutex
    50  	isPending     bool
    51  }
    52  
    53  func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
    54  	timer := &Timer{}
    55  	timer.Timer = time.AfterFunc(time.Hour, func() {
    56  		timer.runningLock.Lock()
    57  		defer timer.runningLock.Unlock()
    58  
    59  		timer.modifyingLock.Lock()
    60  		if !timer.isPending {
    61  			timer.modifyingLock.Unlock()
    62  			return
    63  		}
    64  		timer.isPending = false
    65  		timer.modifyingLock.Unlock()
    66  
    67  		expirationFunction(peer)
    68  	})
    69  	timer.Stop()
    70  	return timer
    71  }
    72  
    73  func (timer *Timer) Mod(d time.Duration) {
    74  	timer.modifyingLock.Lock()
    75  	timer.isPending = true
    76  	timer.Reset(d)
    77  	timer.modifyingLock.Unlock()
    78  }
    79  
    80  func (timer *Timer) Del() {
    81  	timer.modifyingLock.Lock()
    82  	timer.isPending = false
    83  	timer.Stop()
    84  	timer.modifyingLock.Unlock()
    85  }
    86  
    87  func (timer *Timer) DelSync() {
    88  	timer.Del()
    89  	timer.runningLock.Lock()
    90  	timer.Del()
    91  	timer.runningLock.Unlock()
    92  }
    93  
    94  func (timer *Timer) IsPending() bool {
    95  	timer.modifyingLock.RLock()
    96  	defer timer.modifyingLock.RUnlock()
    97  	return timer.isPending
    98  }
    99  
   100  func (peer *Peer) timersActive() bool {
   101  	return peer.isRunning.Load() && peer.transport != nil && peer.transport.isUp()
   102  }
   103  
   104  func expiredRetransmitHandshake(peer *Peer) {
   105  	if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
   106  		peer.transport.logger.Error("Handshake did not complete after multiple attempts, giving up",
   107  			slog.String("peer", peer.String()), slog.Int("maxAttempts", MaxTimerHandshakes+2))
   108  
   109  		if peer.timersActive() {
   110  			peer.timers.sendKeepalive.Del()
   111  		}
   112  
   113  		/* We drop all packets without a keypair and don't try again,
   114  		 * if we try unsuccessfully for too long to make a handshake.
   115  		 */
   116  		peer.FlushStagedPackets()
   117  
   118  		/* We set a timer for destroying any residue that might be left
   119  		 * of a partial exchange.
   120  		 */
   121  		if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
   122  			peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
   123  		}
   124  	} else {
   125  		peer.timers.handshakeAttempts.Add(1)
   126  		peer.transport.logger.Warn("Handshake did not complete within timeout, retrying",
   127  			slog.String("peer", peer.String()), slog.Int("timeout", int(RekeyTimeout.Seconds())),
   128  			slog.Int("try", int(peer.timers.handshakeAttempts.Load()+1)))
   129  
   130  		if err := peer.SendHandshakeInitiation(true); err != nil {
   131  			peer.transport.logger.Error("Failed to retransmit handshake initiation",
   132  				slog.String("peer", peer.String()), slog.Any("error", err))
   133  		}
   134  	}
   135  }
   136  
   137  func expiredSendKeepalive(peer *Peer) {
   138  	if err := peer.SendKeepalive(); err != nil {
   139  		peer.transport.logger.Error("Failed to send keepalive",
   140  			slog.String("peer", peer.String()), slog.Any("error", err))
   141  	}
   142  
   143  	if peer.timers.needAnotherKeepalive.Load() {
   144  		peer.timers.needAnotherKeepalive.Store(false)
   145  		if peer.timersActive() {
   146  			peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
   147  		}
   148  	}
   149  }
   150  
   151  func expiredNewHandshake(peer *Peer) {
   152  	peer.transport.logger.Debug("Retrying handshake because we stopped hearing back",
   153  		slog.String("peer", peer.String()), slog.Int("timeout", int(KeepaliveTimeout.Seconds())))
   154  	if err := peer.SendHandshakeInitiation(false); err != nil {
   155  		peer.transport.logger.Error("Failed to retransmit handshake initiation",
   156  			slog.String("peer", peer.String()), slog.Any("error", err))
   157  	}
   158  }
   159  
   160  func expiredZeroKeyMaterial(peer *Peer) {
   161  	peer.transport.logger.Debug("Removing all keys, since we haven't received a new one in time",
   162  		slog.String("peer", peer.String()), slog.Int("timeout", int((RejectAfterTime*3).Seconds())))
   163  	peer.ZeroAndFlushAll()
   164  }
   165  
   166  func expiredKeepAlive(peer *Peer) {
   167  	if peer.keepAliveInterval.Load() > 0 {
   168  		if err := peer.SendKeepalive(); err != nil {
   169  			peer.transport.logger.Error("Failed to send keepalive",
   170  				slog.String("peer", peer.String()), slog.Any("error", err))
   171  		}
   172  	}
   173  }
   174  
   175  /* Should be called after an authenticated data packet is sent. */
   176  func (peer *Peer) timersDataSent() {
   177  	if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
   178  		peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
   179  	}
   180  }
   181  
   182  /* Should be called after an authenticated data packet is received. */
   183  func (peer *Peer) timersDataReceived() {
   184  	if peer.timersActive() {
   185  		if !peer.timers.sendKeepalive.IsPending() {
   186  			peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
   187  		} else {
   188  			peer.timers.needAnotherKeepalive.Store(true)
   189  		}
   190  	}
   191  }
   192  
   193  /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
   194  func (peer *Peer) timersAnyAuthenticatedPacketSent() {
   195  	if peer.timersActive() {
   196  		peer.timers.sendKeepalive.Del()
   197  	}
   198  }
   199  
   200  /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
   201  func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
   202  	if peer.timersActive() {
   203  		peer.timers.newHandshake.Del()
   204  	}
   205  }
   206  
   207  /* Should be called after a handshake initiation message is sent. */
   208  func (peer *Peer) timersHandshakeInitiated() {
   209  	if peer.timersActive() {
   210  		peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
   211  	}
   212  }
   213  
   214  /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
   215  func (peer *Peer) timersHandshakeComplete() {
   216  	if peer.timersActive() {
   217  		peer.timers.retransmitHandshake.Del()
   218  	}
   219  	peer.timers.handshakeAttempts.Store(0)
   220  	peer.timers.sentLastMinuteHandshake.Store(false)
   221  	peer.lastHandshakeNano.Store(time.Now().UnixNano())
   222  }
   223  
   224  /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
   225  func (peer *Peer) timersSessionDerived() {
   226  	if peer.timersActive() {
   227  		peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
   228  	}
   229  }
   230  
   231  /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
   232  func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
   233  	keepAlive := peer.keepAliveInterval.Load()
   234  	if keepAlive > 0 && peer.timersActive() {
   235  		peer.timers.keepAlive.Mod(time.Duration(keepAlive) * time.Second)
   236  	}
   237  }
   238  
   239  func (peer *Peer) timersInit() {
   240  	peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
   241  	peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
   242  	peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
   243  	peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
   244  	peer.timers.keepAlive = peer.NewTimer(expiredKeepAlive)
   245  }
   246  
   247  func (peer *Peer) timersStart() {
   248  	peer.timers.handshakeAttempts.Store(0)
   249  	peer.timers.sentLastMinuteHandshake.Store(false)
   250  	peer.timers.needAnotherKeepalive.Store(false)
   251  }
   252  
   253  func (peer *Peer) timersStop() {
   254  	peer.timers.retransmitHandshake.DelSync()
   255  	peer.timers.sendKeepalive.DelSync()
   256  	peer.timers.newHandshake.DelSync()
   257  	peer.timers.zeroKeyMaterial.DelSync()
   258  	peer.timers.keepAlive.DelSync()
   259  }