github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/send.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "os" 14 "sync" 15 "time" 16 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 "github.com/koomox/wireguard-go/tun" 21 ) 22 23 /* Outbound flow 24 * 25 * 1. TUN queue 26 * 2. Routing (sequential) 27 * 3. Nonce assignment (sequential) 28 * 4. Encryption (parallel) 29 * 5. Transmission (sequential) 30 * 31 * The functions in this file occur (roughly) in the order in 32 * which the packets are processed. 33 * 34 * Locking, Producers and Consumers 35 * 36 * The order of packets (per peer) must be maintained, 37 * but encryption of packets happen out-of-order: 38 * 39 * The sequential consumers will attempt to take the lock, 40 * workers release lock when they have completed work (encryption) on the packet. 41 * 42 * If the element is inserted into the "encryption queue", 43 * the content is preceded by enough "junk" to contain the transport header 44 * (to allow the construction of transport messages in-place) 45 */ 46 47 type QueueOutboundElement struct { 48 sync.Mutex 49 buffer *[MaxMessageSize]byte // slice holding the packet data 50 packet []byte // slice of "buffer" (always!) 51 nonce uint64 // nonce for encryption 52 keypair *Keypair // keypair for encryption 53 peer *Peer // related peer 54 } 55 56 func (device *Device) NewOutboundElement() *QueueOutboundElement { 57 elem := device.GetOutboundElement() 58 elem.buffer = device.GetMessageBuffer() 59 elem.Mutex = sync.Mutex{} 60 elem.nonce = 0 61 // keypair and peer were cleared (if necessary) by clearPointers. 62 return elem 63 } 64 65 // clearPointers clears elem fields that contain pointers. 66 // This makes the garbage collector's life easier and 67 // avoids accidentally keeping other objects around unnecessarily. 68 // It also reduces the possible collateral damage from use-after-free bugs. 69 func (elem *QueueOutboundElement) clearPointers() { 70 elem.buffer = nil 71 elem.packet = nil 72 elem.keypair = nil 73 elem.peer = nil 74 } 75 76 /* Queues a keepalive if no packets are queued for peer 77 */ 78 func (peer *Peer) SendKeepalive() { 79 if len(peer.queue.staged) == 0 && peer.isRunning.Load() { 80 elem := peer.device.NewOutboundElement() 81 elems := peer.device.GetOutboundElementsSlice() 82 *elems = append(*elems, elem) 83 select { 84 case peer.queue.staged <- elems: 85 peer.device.log.Verbosef("%v - Sending keepalive packet", peer) 86 default: 87 peer.device.PutMessageBuffer(elem.buffer) 88 peer.device.PutOutboundElement(elem) 89 peer.device.PutOutboundElementsSlice(elems) 90 } 91 } 92 peer.SendStagedPackets() 93 } 94 95 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { 96 if !isRetry { 97 peer.timers.handshakeAttempts.Store(0) 98 } 99 100 peer.handshake.mutex.RLock() 101 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { 102 peer.handshake.mutex.RUnlock() 103 return nil 104 } 105 peer.handshake.mutex.RUnlock() 106 107 peer.handshake.mutex.Lock() 108 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { 109 peer.handshake.mutex.Unlock() 110 return nil 111 } 112 peer.handshake.lastSentHandshake = time.Now() 113 peer.handshake.mutex.Unlock() 114 115 peer.device.log.Verbosef("%v - Sending handshake initiation", peer) 116 117 msg, err := peer.device.CreateMessageInitiation(peer) 118 if err != nil { 119 peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) 120 return err 121 } 122 123 var buf [MessageInitiationSize]byte 124 writer := bytes.NewBuffer(buf[:0]) 125 binary.Write(writer, binary.LittleEndian, msg) 126 packet := writer.Bytes() 127 peer.cookieGenerator.AddMacs(packet) 128 129 peer.timersAnyAuthenticatedPacketTraversal() 130 peer.timersAnyAuthenticatedPacketSent() 131 132 err = peer.SendBuffers([][]byte{packet}) 133 if err != nil { 134 peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) 135 } 136 peer.timersHandshakeInitiated() 137 138 return err 139 } 140 141 func (peer *Peer) SendHandshakeResponse() error { 142 peer.handshake.mutex.Lock() 143 peer.handshake.lastSentHandshake = time.Now() 144 peer.handshake.mutex.Unlock() 145 146 peer.device.log.Verbosef("%v - Sending handshake response", peer) 147 148 response, err := peer.device.CreateMessageResponse(peer) 149 if err != nil { 150 peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) 151 return err 152 } 153 154 var buf [MessageResponseSize]byte 155 writer := bytes.NewBuffer(buf[:0]) 156 binary.Write(writer, binary.LittleEndian, response) 157 packet := writer.Bytes() 158 peer.cookieGenerator.AddMacs(packet) 159 160 err = peer.BeginSymmetricSession() 161 if err != nil { 162 peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 163 return err 164 } 165 166 peer.timersSessionDerived() 167 peer.timersAnyAuthenticatedPacketTraversal() 168 peer.timersAnyAuthenticatedPacketSent() 169 170 // TODO: allocation could be avoided 171 err = peer.SendBuffers([][]byte{packet}) 172 if err != nil { 173 peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) 174 } 175 return err 176 } 177 178 func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { 179 device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) 180 181 sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) 182 reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) 183 if err != nil { 184 device.log.Errorf("Failed to create cookie reply: %v", err) 185 return err 186 } 187 188 var buf [MessageCookieReplySize]byte 189 writer := bytes.NewBuffer(buf[:0]) 190 binary.Write(writer, binary.LittleEndian, reply) 191 // TODO: allocation could be avoided 192 device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) 193 return nil 194 } 195 196 func (peer *Peer) keepKeyFreshSending() { 197 keypair := peer.keypairs.Current() 198 if keypair == nil { 199 return 200 } 201 nonce := keypair.sendNonce.Load() 202 if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { 203 peer.SendHandshakeInitiation(false) 204 } 205 } 206 207 func (device *Device) RoutineReadFromTUN() { 208 defer func() { 209 device.log.Verbosef("Routine: TUN reader - stopped") 210 device.state.stopping.Done() 211 device.queue.encryption.wg.Done() 212 }() 213 214 device.log.Verbosef("Routine: TUN reader - started") 215 216 var ( 217 batchSize = device.BatchSize() 218 readErr error 219 elems = make([]*QueueOutboundElement, batchSize) 220 bufs = make([][]byte, batchSize) 221 elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) 222 count = 0 223 sizes = make([]int, batchSize) 224 offset = MessageTransportHeaderSize 225 ) 226 227 for i := range elems { 228 elems[i] = device.NewOutboundElement() 229 bufs[i] = elems[i].buffer[:] 230 } 231 232 defer func() { 233 for _, elem := range elems { 234 if elem != nil { 235 device.PutMessageBuffer(elem.buffer) 236 device.PutOutboundElement(elem) 237 } 238 } 239 }() 240 241 for { 242 // read packets 243 count, readErr = device.tun.device.Read(bufs, sizes, offset) 244 for i := 0; i < count; i++ { 245 if sizes[i] < 1 { 246 continue 247 } 248 249 elem := elems[i] 250 elem.packet = bufs[i][offset : offset+sizes[i]] 251 252 // lookup peer 253 var peer *Peer 254 switch elem.packet[0] >> 4 { 255 case 4: 256 if len(elem.packet) < ipv4.HeaderLen { 257 continue 258 } 259 dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] 260 peer = device.allowedips.Lookup(dst) 261 262 case 6: 263 if len(elem.packet) < ipv6.HeaderLen { 264 continue 265 } 266 dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] 267 peer = device.allowedips.Lookup(dst) 268 269 default: 270 device.log.Verbosef("Received packet with unknown IP version") 271 } 272 273 if peer == nil { 274 continue 275 } 276 elemsForPeer, ok := elemsByPeer[peer] 277 if !ok { 278 elemsForPeer = device.GetOutboundElementsSlice() 279 elemsByPeer[peer] = elemsForPeer 280 } 281 *elemsForPeer = append(*elemsForPeer, elem) 282 elems[i] = device.NewOutboundElement() 283 bufs[i] = elems[i].buffer[:] 284 } 285 286 for peer, elemsForPeer := range elemsByPeer { 287 if peer.isRunning.Load() { 288 peer.StagePackets(elemsForPeer) 289 peer.SendStagedPackets() 290 } else { 291 for _, elem := range *elemsForPeer { 292 device.PutMessageBuffer(elem.buffer) 293 device.PutOutboundElement(elem) 294 } 295 device.PutOutboundElementsSlice(elemsForPeer) 296 } 297 delete(elemsByPeer, peer) 298 } 299 300 if readErr != nil { 301 if errors.Is(readErr, tun.ErrTooManySegments) { 302 // TODO: record stat for this 303 // This will happen if MSS is surprisingly small (< 576) 304 // coincident with reasonably high throughput. 305 device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) 306 continue 307 } 308 if !device.isClosed() { 309 if !errors.Is(readErr, os.ErrClosed) { 310 device.log.Errorf("Failed to read packet from TUN device: %v", readErr) 311 } 312 go device.Close() 313 } 314 return 315 } 316 } 317 } 318 319 func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { 320 for { 321 select { 322 case peer.queue.staged <- elems: 323 return 324 default: 325 } 326 select { 327 case tooOld := <-peer.queue.staged: 328 for _, elem := range *tooOld { 329 peer.device.PutMessageBuffer(elem.buffer) 330 peer.device.PutOutboundElement(elem) 331 } 332 peer.device.PutOutboundElementsSlice(tooOld) 333 default: 334 } 335 } 336 } 337 338 func (peer *Peer) SendStagedPackets() { 339 top: 340 if len(peer.queue.staged) == 0 || !peer.device.isUp() { 341 return 342 } 343 344 keypair := peer.keypairs.Current() 345 if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { 346 peer.SendHandshakeInitiation(false) 347 return 348 } 349 350 for { 351 var elemsOOO *[]*QueueOutboundElement 352 select { 353 case elems := <-peer.queue.staged: 354 i := 0 355 for _, elem := range *elems { 356 elem.peer = peer 357 elem.nonce = keypair.sendNonce.Add(1) - 1 358 if elem.nonce >= RejectAfterMessages { 359 keypair.sendNonce.Store(RejectAfterMessages) 360 if elemsOOO == nil { 361 elemsOOO = peer.device.GetOutboundElementsSlice() 362 } 363 *elemsOOO = append(*elemsOOO, elem) 364 continue 365 } else { 366 (*elems)[i] = elem 367 i++ 368 } 369 370 elem.keypair = keypair 371 elem.Lock() 372 } 373 *elems = (*elems)[:i] 374 375 if elemsOOO != nil { 376 peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans 377 } 378 379 if len(*elems) == 0 { 380 peer.device.PutOutboundElementsSlice(elems) 381 goto top 382 } 383 384 // add to parallel and sequential queue 385 if peer.isRunning.Load() { 386 peer.queue.outbound.c <- elems 387 for _, elem := range *elems { 388 peer.device.queue.encryption.c <- elem 389 } 390 } else { 391 for _, elem := range *elems { 392 peer.device.PutMessageBuffer(elem.buffer) 393 peer.device.PutOutboundElement(elem) 394 } 395 peer.device.PutOutboundElementsSlice(elems) 396 } 397 398 if elemsOOO != nil { 399 goto top 400 } 401 default: 402 return 403 } 404 } 405 } 406 407 func (peer *Peer) FlushStagedPackets() { 408 for { 409 select { 410 case elems := <-peer.queue.staged: 411 for _, elem := range *elems { 412 peer.device.PutMessageBuffer(elem.buffer) 413 peer.device.PutOutboundElement(elem) 414 } 415 peer.device.PutOutboundElementsSlice(elems) 416 default: 417 return 418 } 419 } 420 } 421 422 func calculatePaddingSize(packetSize, mtu int) int { 423 lastUnit := packetSize 424 if mtu == 0 { 425 return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit 426 } 427 if lastUnit > mtu { 428 lastUnit %= mtu 429 } 430 paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) 431 if paddedSize > mtu { 432 paddedSize = mtu 433 } 434 return paddedSize - lastUnit 435 } 436 437 /* Encrypts the elements in the queue 438 * and marks them for sequential consumption (by releasing the mutex) 439 * 440 * Obs. One instance per core 441 */ 442 func (device *Device) RoutineEncryption(id int) { 443 var paddingZeros [PaddingMultiple]byte 444 var nonce [chacha20poly1305.NonceSize]byte 445 446 defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) 447 device.log.Verbosef("Routine: encryption worker %d - started", id) 448 449 for elem := range device.queue.encryption.c { 450 // populate header fields 451 header := elem.buffer[:MessageTransportHeaderSize] 452 453 fieldType := header[0:4] 454 fieldReceiver := header[4:8] 455 fieldNonce := header[8:16] 456 457 binary.LittleEndian.PutUint32(fieldType, MessageTransportType) 458 binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) 459 binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) 460 461 // pad content to multiple of 16 462 paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) 463 elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) 464 465 // encrypt content and release to consumer 466 467 binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) 468 elem.packet = elem.keypair.send.Seal( 469 header, 470 nonce[:], 471 elem.packet, 472 nil, 473 ) 474 elem.Unlock() 475 } 476 } 477 478 func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { 479 device := peer.device 480 defer func() { 481 defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) 482 peer.stopping.Done() 483 }() 484 device.log.Verbosef("%v - Routine: sequential sender - started", peer) 485 486 bufs := make([][]byte, 0, maxBatchSize) 487 488 for elems := range peer.queue.outbound.c { 489 bufs = bufs[:0] 490 if elems == nil { 491 return 492 } 493 if !peer.isRunning.Load() { 494 // peer has been stopped; return re-usable elems to the shared pool. 495 // This is an optimization only. It is possible for the peer to be stopped 496 // immediately after this check, in which case, elem will get processed. 497 // The timers and SendBuffers code are resilient to a few stragglers. 498 // TODO: rework peer shutdown order to ensure 499 // that we never accidentally keep timers alive longer than necessary. 500 for _, elem := range *elems { 501 elem.Lock() 502 device.PutMessageBuffer(elem.buffer) 503 device.PutOutboundElement(elem) 504 } 505 continue 506 } 507 dataSent := false 508 for _, elem := range *elems { 509 elem.Lock() 510 if len(elem.packet) != MessageKeepaliveSize { 511 dataSent = true 512 } 513 bufs = append(bufs, elem.packet) 514 } 515 516 peer.timersAnyAuthenticatedPacketTraversal() 517 peer.timersAnyAuthenticatedPacketSent() 518 519 err := peer.SendBuffers(bufs) 520 if dataSent { 521 peer.timersDataSent() 522 } 523 for _, elem := range *elems { 524 device.PutMessageBuffer(elem.buffer) 525 device.PutOutboundElement(elem) 526 } 527 device.PutOutboundElementsSlice(elems) 528 if err != nil { 529 device.log.Errorf("%v - Failed to send data packets: %v", peer, err) 530 continue 531 } 532 533 peer.keepKeyFreshSending() 534 } 535 }