github.com/amnezia-vpn/amnezia-wg@v0.1.8/device/send.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "math/rand" 13 "net" 14 "os" 15 "sync" 16 "time" 17 18 "github.com/amnezia-vpn/amnezia-wg/conn" 19 "github.com/amnezia-vpn/amnezia-wg/tun" 20 "golang.org/x/crypto/chacha20poly1305" 21 "golang.org/x/net/ipv4" 22 "golang.org/x/net/ipv6" 23 ) 24 25 /* Outbound flow 26 * 27 * 1. TUN queue 28 * 2. Routing (sequential) 29 * 3. Nonce assignment (sequential) 30 * 4. Encryption (parallel) 31 * 5. Transmission (sequential) 32 * 33 * The functions in this file occur (roughly) in the order in 34 * which the packets are processed. 35 * 36 * Locking, Producers and Consumers 37 * 38 * The order of packets (per peer) must be maintained, 39 * but encryption of packets happen out-of-order: 40 * 41 * The sequential consumers will attempt to take the lock, 42 * workers release lock when they have completed work (encryption) on the packet. 43 * 44 * If the element is inserted into the "encryption queue", 45 * the content is preceded by enough "junk" to contain the transport header 46 * (to allow the construction of transport messages in-place) 47 */ 48 49 type QueueOutboundElement struct { 50 buffer *[MaxMessageSize]byte // slice holding the packet data 51 packet []byte // slice of "buffer" (always!) 52 nonce uint64 // nonce for encryption 53 keypair *Keypair // keypair for encryption 54 peer *Peer // related peer 55 } 56 57 type QueueOutboundElementsContainer struct { 58 sync.Mutex 59 elems []*QueueOutboundElement 60 } 61 62 func (device *Device) NewOutboundElement() *QueueOutboundElement { 63 elem := device.GetOutboundElement() 64 elem.buffer = device.GetMessageBuffer() 65 elem.nonce = 0 66 // keypair and peer were cleared (if necessary) by clearPointers. 67 return elem 68 } 69 70 // clearPointers clears elem fields that contain pointers. 71 // This makes the garbage collector's life easier and 72 // avoids accidentally keeping other objects around unnecessarily. 73 // It also reduces the possible collateral damage from use-after-free bugs. 74 func (elem *QueueOutboundElement) clearPointers() { 75 elem.buffer = nil 76 elem.packet = nil 77 elem.keypair = nil 78 elem.peer = nil 79 } 80 81 /* Queues a keepalive if no packets are queued for peer 82 */ 83 func (peer *Peer) SendKeepalive() { 84 if len(peer.queue.staged) == 0 && peer.isRunning.Load() { 85 elem := peer.device.NewOutboundElement() 86 elemsContainer := peer.device.GetOutboundElementsContainer() 87 elemsContainer.elems = append(elemsContainer.elems, elem) 88 select { 89 case peer.queue.staged <- elemsContainer: 90 peer.device.log.Verbosef("%v - Sending keepalive packet", peer) 91 default: 92 peer.device.PutMessageBuffer(elem.buffer) 93 peer.device.PutOutboundElement(elem) 94 peer.device.PutOutboundElementsContainer(elemsContainer) 95 } 96 } 97 peer.SendStagedPackets() 98 } 99 100 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { 101 if !isRetry { 102 peer.timers.handshakeAttempts.Store(0) 103 } 104 105 peer.handshake.mutex.RLock() 106 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { 107 peer.handshake.mutex.RUnlock() 108 return nil 109 } 110 peer.handshake.mutex.RUnlock() 111 112 peer.handshake.mutex.Lock() 113 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { 114 peer.handshake.mutex.Unlock() 115 return nil 116 } 117 peer.handshake.lastSentHandshake = time.Now() 118 peer.handshake.mutex.Unlock() 119 120 peer.device.log.Verbosef("%v - Sending handshake initiation", peer) 121 122 msg, err := peer.device.CreateMessageInitiation(peer) 123 if err != nil { 124 peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) 125 return err 126 } 127 var sendBuffer [][]byte 128 // so only packet processed for cookie generation 129 var junkedHeader []byte 130 if peer.device.isAdvancedSecurityOn() { 131 peer.device.aSecMux.RLock() 132 junks, err := peer.createJunkPackets() 133 peer.device.aSecMux.RUnlock() 134 135 if err != nil { 136 peer.device.log.Errorf("%v - %v", peer, err) 137 return err 138 } 139 140 if len(junks) > 0 { 141 err = peer.SendBuffers(junks) 142 143 if err != nil { 144 peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err) 145 return err 146 } 147 } 148 149 peer.device.aSecMux.RLock() 150 if peer.device.aSecCfg.initPacketJunkSize != 0 { 151 buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) 152 writer := bytes.NewBuffer(buf[:0]) 153 err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) 154 if err != nil { 155 peer.device.log.Errorf("%v - %v", peer, err) 156 peer.device.aSecMux.RUnlock() 157 return err 158 } 159 junkedHeader = writer.Bytes() 160 } 161 peer.device.aSecMux.RUnlock() 162 } 163 164 var buf [MessageInitiationSize]byte 165 writer := bytes.NewBuffer(buf[:0]) 166 binary.Write(writer, binary.LittleEndian, msg) 167 packet := writer.Bytes() 168 peer.cookieGenerator.AddMacs(packet) 169 junkedHeader = append(junkedHeader, packet...) 170 171 peer.timersAnyAuthenticatedPacketTraversal() 172 peer.timersAnyAuthenticatedPacketSent() 173 174 sendBuffer = append(sendBuffer, junkedHeader) 175 176 err = peer.SendBuffers(sendBuffer) 177 if err != nil { 178 peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) 179 } 180 peer.timersHandshakeInitiated() 181 182 return err 183 } 184 185 func (peer *Peer) SendHandshakeResponse() error { 186 peer.handshake.mutex.Lock() 187 peer.handshake.lastSentHandshake = time.Now() 188 peer.handshake.mutex.Unlock() 189 190 peer.device.log.Verbosef("%v - Sending handshake response", peer) 191 192 response, err := peer.device.CreateMessageResponse(peer) 193 if err != nil { 194 peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) 195 return err 196 } 197 var junkedHeader []byte 198 if peer.device.isAdvancedSecurityOn() { 199 peer.device.aSecMux.RLock() 200 if peer.device.aSecCfg.responsePacketJunkSize != 0 { 201 buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) 202 writer := bytes.NewBuffer(buf[:0]) 203 err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) 204 if err != nil { 205 peer.device.aSecMux.RUnlock() 206 peer.device.log.Errorf("%v - %v", peer, err) 207 return err 208 } 209 junkedHeader = writer.Bytes() 210 } 211 peer.device.aSecMux.RUnlock() 212 } 213 var buf [MessageResponseSize]byte 214 writer := bytes.NewBuffer(buf[:0]) 215 216 binary.Write(writer, binary.LittleEndian, response) 217 packet := writer.Bytes() 218 peer.cookieGenerator.AddMacs(packet) 219 junkedHeader = append(junkedHeader, packet...) 220 221 err = peer.BeginSymmetricSession() 222 if err != nil { 223 peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 224 return err 225 } 226 227 peer.timersSessionDerived() 228 peer.timersAnyAuthenticatedPacketTraversal() 229 peer.timersAnyAuthenticatedPacketSent() 230 231 // TODO: allocation could be avoided 232 err = peer.SendBuffers([][]byte{junkedHeader}) 233 if err != nil { 234 peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) 235 } 236 return err 237 } 238 239 func (device *Device) SendHandshakeCookie( 240 initiatingElem *QueueHandshakeElement, 241 ) error { 242 device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) 243 244 sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) 245 reply, err := device.cookieChecker.CreateReply( 246 initiatingElem.packet, 247 sender, 248 initiatingElem.endpoint.DstToBytes(), 249 ) 250 if err != nil { 251 device.log.Errorf("Failed to create cookie reply: %v", err) 252 return err 253 } 254 255 var buf [MessageCookieReplySize]byte 256 writer := bytes.NewBuffer(buf[:0]) 257 binary.Write(writer, binary.LittleEndian, reply) 258 // TODO: allocation could be avoided 259 device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) 260 return nil 261 } 262 263 func (peer *Peer) keepKeyFreshSending() { 264 keypair := peer.keypairs.Current() 265 if keypair == nil { 266 return 267 } 268 nonce := keypair.sendNonce.Load() 269 if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { 270 peer.SendHandshakeInitiation(false) 271 } 272 } 273 274 func (device *Device) RoutineReadFromTUN() { 275 defer func() { 276 device.log.Verbosef("Routine: TUN reader - stopped") 277 device.state.stopping.Done() 278 device.queue.encryption.wg.Done() 279 }() 280 281 device.log.Verbosef("Routine: TUN reader - started") 282 283 var ( 284 batchSize = device.BatchSize() 285 readErr error 286 elems = make([]*QueueOutboundElement, batchSize) 287 bufs = make([][]byte, batchSize) 288 elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) 289 count = 0 290 sizes = make([]int, batchSize) 291 offset = MessageTransportHeaderSize 292 ) 293 294 for i := range elems { 295 elems[i] = device.NewOutboundElement() 296 bufs[i] = elems[i].buffer[:] 297 } 298 299 defer func() { 300 for _, elem := range elems { 301 if elem != nil { 302 device.PutMessageBuffer(elem.buffer) 303 device.PutOutboundElement(elem) 304 } 305 } 306 }() 307 308 for { 309 // read packets 310 count, readErr = device.tun.device.Read(bufs, sizes, offset) 311 for i := 0; i < count; i++ { 312 if sizes[i] < 1 { 313 continue 314 } 315 316 elem := elems[i] 317 elem.packet = bufs[i][offset : offset+sizes[i]] 318 319 // lookup peer 320 var peer *Peer 321 switch elem.packet[0] >> 4 { 322 case 4: 323 if len(elem.packet) < ipv4.HeaderLen { 324 continue 325 } 326 dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] 327 peer = device.allowedips.Lookup(dst) 328 329 case 6: 330 if len(elem.packet) < ipv6.HeaderLen { 331 continue 332 } 333 dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] 334 peer = device.allowedips.Lookup(dst) 335 336 default: 337 device.log.Verbosef("Received packet with unknown IP version") 338 } 339 340 if peer == nil { 341 continue 342 } 343 elemsForPeer, ok := elemsByPeer[peer] 344 if !ok { 345 elemsForPeer = device.GetOutboundElementsContainer() 346 elemsByPeer[peer] = elemsForPeer 347 } 348 elemsForPeer.elems = append(elemsForPeer.elems, elem) 349 elems[i] = device.NewOutboundElement() 350 bufs[i] = elems[i].buffer[:] 351 } 352 353 for peer, elemsForPeer := range elemsByPeer { 354 if peer.isRunning.Load() { 355 peer.StagePackets(elemsForPeer) 356 peer.SendStagedPackets() 357 } else { 358 for _, elem := range elemsForPeer.elems { 359 device.PutMessageBuffer(elem.buffer) 360 device.PutOutboundElement(elem) 361 } 362 device.PutOutboundElementsContainer(elemsForPeer) 363 } 364 delete(elemsByPeer, peer) 365 } 366 367 if readErr != nil { 368 if errors.Is(readErr, tun.ErrTooManySegments) { 369 // TODO: record stat for this 370 // This will happen if MSS is surprisingly small (< 576) 371 // coincident with reasonably high throughput. 372 device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) 373 continue 374 } 375 if !device.isClosed() { 376 if !errors.Is(readErr, os.ErrClosed) { 377 device.log.Errorf("Failed to read packet from TUN device: %v", readErr) 378 } 379 go device.Close() 380 } 381 return 382 } 383 } 384 } 385 386 func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { 387 for { 388 select { 389 case peer.queue.staged <- elems: 390 return 391 default: 392 } 393 select { 394 case tooOld := <-peer.queue.staged: 395 for _, elem := range tooOld.elems { 396 peer.device.PutMessageBuffer(elem.buffer) 397 peer.device.PutOutboundElement(elem) 398 } 399 peer.device.PutOutboundElementsContainer(tooOld) 400 default: 401 } 402 } 403 } 404 405 func (peer *Peer) SendStagedPackets() { 406 top: 407 if len(peer.queue.staged) == 0 || !peer.device.isUp() { 408 return 409 } 410 411 keypair := peer.keypairs.Current() 412 if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { 413 peer.SendHandshakeInitiation(false) 414 return 415 } 416 417 for { 418 var elemsContainerOOO *QueueOutboundElementsContainer 419 select { 420 case elemsContainer := <-peer.queue.staged: 421 i := 0 422 for _, elem := range elemsContainer.elems { 423 elem.peer = peer 424 elem.nonce = keypair.sendNonce.Add(1) - 1 425 if elem.nonce >= RejectAfterMessages { 426 keypair.sendNonce.Store(RejectAfterMessages) 427 if elemsContainerOOO == nil { 428 elemsContainerOOO = peer.device.GetOutboundElementsContainer() 429 } 430 elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) 431 continue 432 } else { 433 elemsContainer.elems[i] = elem 434 i++ 435 } 436 437 elem.keypair = keypair 438 } 439 elemsContainer.Lock() 440 elemsContainer.elems = elemsContainer.elems[:i] 441 442 if elemsContainerOOO != nil { 443 peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans 444 } 445 446 if len(elemsContainer.elems) == 0 { 447 peer.device.PutOutboundElementsContainer(elemsContainer) 448 goto top 449 } 450 451 // add to parallel and sequential queue 452 if peer.isRunning.Load() { 453 peer.queue.outbound.c <- elemsContainer 454 peer.device.queue.encryption.c <- elemsContainer 455 } else { 456 for _, elem := range elemsContainer.elems { 457 peer.device.PutMessageBuffer(elem.buffer) 458 peer.device.PutOutboundElement(elem) 459 } 460 peer.device.PutOutboundElementsContainer(elemsContainer) 461 } 462 463 if elemsContainerOOO != nil { 464 goto top 465 } 466 default: 467 return 468 } 469 } 470 } 471 472 func (peer *Peer) createJunkPackets() ([][]byte, error) { 473 if peer.device.aSecCfg.junkPacketCount == 0 { 474 return nil, nil 475 } 476 477 junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount) 478 for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ { 479 packetSize := rand.Intn( 480 peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize, 481 ) + peer.device.aSecCfg.junkPacketMinSize 482 483 junk, err := randomJunkWithSize(packetSize) 484 if err != nil { 485 peer.device.log.Errorf( 486 "%v - Failed to create junk packet: %v", 487 peer, 488 err, 489 ) 490 return nil, err 491 } 492 junks = append(junks, junk) 493 } 494 return junks, nil 495 } 496 497 func (peer *Peer) FlushStagedPackets() { 498 for { 499 select { 500 case elemsContainer := <-peer.queue.staged: 501 for _, elem := range elemsContainer.elems { 502 peer.device.PutMessageBuffer(elem.buffer) 503 peer.device.PutOutboundElement(elem) 504 } 505 peer.device.PutOutboundElementsContainer(elemsContainer) 506 default: 507 return 508 } 509 } 510 } 511 512 func calculatePaddingSize(packetSize, mtu int) int { 513 lastUnit := packetSize 514 if mtu == 0 { 515 return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit 516 } 517 if lastUnit > mtu { 518 lastUnit %= mtu 519 } 520 paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) 521 if paddedSize > mtu { 522 paddedSize = mtu 523 } 524 return paddedSize - lastUnit 525 } 526 527 /* Encrypts the elements in the queue 528 * and marks them for sequential consumption (by releasing the mutex) 529 * 530 * Obs. One instance per core 531 */ 532 func (device *Device) RoutineEncryption(id int) { 533 var paddingZeros [PaddingMultiple]byte 534 var nonce [chacha20poly1305.NonceSize]byte 535 536 defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) 537 device.log.Verbosef("Routine: encryption worker %d - started", id) 538 539 for elemsContainer := range device.queue.encryption.c { 540 for _, elem := range elemsContainer.elems { 541 // populate header fields 542 header := elem.buffer[:MessageTransportHeaderSize] 543 544 fieldType := header[0:4] 545 fieldReceiver := header[4:8] 546 fieldNonce := header[8:16] 547 548 binary.LittleEndian.PutUint32(fieldType, MessageTransportType) 549 binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) 550 binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) 551 552 // pad content to multiple of 16 553 paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) 554 elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) 555 556 // encrypt content and release to consumer 557 558 binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) 559 elem.packet = elem.keypair.send.Seal( 560 header, 561 nonce[:], 562 elem.packet, 563 nil, 564 ) 565 } 566 elemsContainer.Unlock() 567 } 568 } 569 570 func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { 571 device := peer.device 572 defer func() { 573 defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) 574 peer.stopping.Done() 575 }() 576 device.log.Verbosef("%v - Routine: sequential sender - started", peer) 577 578 bufs := make([][]byte, 0, maxBatchSize) 579 580 for elemsContainer := range peer.queue.outbound.c { 581 bufs = bufs[:0] 582 if elemsContainer == nil { 583 return 584 } 585 if !peer.isRunning.Load() { 586 // peer has been stopped; return re-usable elems to the shared pool. 587 // This is an optimization only. It is possible for the peer to be stopped 588 // immediately after this check, in which case, elem will get processed. 589 // The timers and SendBuffers code are resilient to a few stragglers. 590 // TODO: rework peer shutdown order to ensure 591 // that we never accidentally keep timers alive longer than necessary. 592 elemsContainer.Lock() 593 for _, elem := range elemsContainer.elems { 594 device.PutMessageBuffer(elem.buffer) 595 device.PutOutboundElement(elem) 596 } 597 continue 598 } 599 dataSent := false 600 elemsContainer.Lock() 601 for _, elem := range elemsContainer.elems { 602 if len(elem.packet) != MessageKeepaliveSize { 603 dataSent = true 604 } 605 bufs = append(bufs, elem.packet) 606 } 607 608 peer.timersAnyAuthenticatedPacketTraversal() 609 peer.timersAnyAuthenticatedPacketSent() 610 611 err := peer.SendBuffers(bufs) 612 if dataSent { 613 peer.timersDataSent() 614 } 615 for _, elem := range elemsContainer.elems { 616 device.PutMessageBuffer(elem.buffer) 617 device.PutOutboundElement(elem) 618 } 619 device.PutOutboundElementsContainer(elemsContainer) 620 if err != nil { 621 var errGSO conn.ErrUDPGSODisabled 622 if errors.As(err, &errGSO) { 623 device.log.Verbosef(err.Error()) 624 err = errGSO.RetryErr 625 } 626 } 627 if err != nil { 628 device.log.Errorf("%v - Failed to send data packets: %v", peer, err) 629 continue 630 } 631 632 peer.keepKeyFreshSending() 633 } 634 }