/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"bytes"
	"encoding/binary"
	"errors"
	"net"
	"os"
	"sync"
	"time"

	"github.com/sagernet/wireguard-go/conn"
	"github.com/sagernet/wireguard-go/tun"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

/* Outbound flow
 *
 * 1. TUN queue
 * 2. Routing (sequential)
 * 3. Nonce assignment (sequential)
 * 4. Encryption (parallel)
 * 5. Transmission (sequential)
 *
 * The functions in this file occur (roughly) in the order in
 * which the packets are processed.
 *
 * Locking, Producers and Consumers
 *
 * The order of packets (per peer) must be maintained,
 * but encryption of packets happen out-of-order:
 *
 * The sequential consumers will attempt to take the lock,
 * workers release lock when they have completed work (encryption) on the packet.
 *
 * If the element is inserted into the "encryption queue",
 * the content is preceded by enough "junk" to contain the transport header
 * (to allow the construction of transport messages in-place)
 */

// QueueOutboundElement is a single outbound packet travelling through the
// nonce-assignment -> encryption -> transmission pipeline.
type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // slice holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}

// QueueOutboundElementsContainer is a batch of outbound elements. Its mutex is
// held from nonce assignment until an encryption worker finishes the batch
// (the worker's Unlock is what releases the sequential sender).
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}

// NewOutboundElement returns a pooled element with a fresh message buffer
// attached and the nonce reset.
func (device *Device) NewOutboundElement() *QueueOutboundElement {
	elem := device.GetOutboundElement()
	elem.buffer = device.GetMessageBuffer()
	elem.nonce = 0
	// keypair and peer were cleared (if necessary) by clearPointers.
	return elem
}

// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueOutboundElement) clearPointers() {
	elem.buffer = nil
	elem.packet = nil
	elem.keypair = nil
	elem.peer = nil
}

/* Queues a keepalive if no packets are queued for peer
 */
func (peer *Peer) SendKeepalive() {
	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
		// A keepalive is an element whose packet slice is left empty;
		// it is staged like any other outbound batch.
		elem := peer.device.NewOutboundElement()
		elemsContainer := peer.device.GetOutboundElementsContainer()
		elemsContainer.elems = append(elemsContainer.elems, elem)
		select {
		case peer.queue.staged <- elemsContainer:
			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
		default:
			// Staged queue is full; drop the keepalive and return resources.
			peer.device.PutMessageBuffer(elem.buffer)
			peer.device.PutOutboundElement(elem)
			peer.device.PutOutboundElementsContainer(elemsContainer)
		}
	}
	peer.SendStagedPackets()
}

// SendHandshakeInitiation builds, MACs and transmits a handshake initiation
// message, rate-limited to at most one per RekeyTimeout. isRetry indicates a
// timer-driven retransmission (the attempt counter is only reset otherwise).
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
	if !isRetry {
		peer.timers.handshakeAttempts.Store(0)
	}

	// Cheap read-locked check first; re-check under the write lock because
	// another goroutine may have sent an initiation in between.
	peer.handshake.mutex.RLock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.RUnlock()
		return nil
	}
	peer.handshake.mutex.RUnlock()

	peer.handshake.mutex.Lock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.Unlock()
		return nil
	}
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)

	msg, err := peer.device.CreateMessageInitiation(peer)
	if err != nil {
		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
		return err
	}

	var buf [MessageInitiationSize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, msg)
	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
	}
	peer.timersHandshakeInitiated()

	return err
}

// SendHandshakeResponse builds, MACs and transmits a handshake response, and
// derives the new symmetric keypair (the responder side keys first).
func (peer *Peer) SendHandshakeResponse() error {
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	peer.device.log.Verbosef("%v - Sending handshake response", peer)

	response, err := peer.device.CreateMessageResponse(peer)
	if err != nil {
		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
		return err
	}

	var buf [MessageResponseSize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, response)
	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	err = peer.BeginSymmetricSession()
	if err != nil {
		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
		return err
	}

	peer.timersSessionDerived()
	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	// TODO: allocation could be avoided
	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
	}
	return err
}

// SendHandshakeCookie replies to a handshake message that was denied (under
// load) with a cookie reply, so the initiator can prove endpoint ownership.
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())

	// Sender index sits at bytes 4..8 of both initiation and response messages.
	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
	if err != nil {
		device.log.Errorf("Failed to create cookie reply: %v", err)
		return err
	}

	var buf [MessageCookieReplySize]byte
	writer := bytes.NewBuffer(buf[:0])
	binary.Write(writer, binary.LittleEndian, reply)
	// TODO: allocation could be avoided
	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
	return nil
}

// keepKeyFreshSending initiates a rekey when the current keypair is nearing
// its message-count or (for the initiator) age limit.
func (peer *Peer) keepKeyFreshSending() {
	keypair := peer.keypairs.Current()
	if keypair == nil {
		return
	}
	nonce := keypair.sendNonce.Load()
	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
		peer.SendHandshakeInitiation(false)
	}
}

// RoutineReadFromTUN reads batches of packets from the TUN device, routes each
// packet to a peer via the allowed-IPs table, and stages them for sending.
// There is a single instance of this routine per device.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		offset      = MessageTransportHeaderSize // leave room to build the transport header in place
	)

	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	defer func() {
		// Return any elements still owned by this routine to the pools.
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
		for i := 0; i < count; i++ {
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			// Ownership of elem moved to the container; replace our slot.
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			if !device.isClosed() {
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				go device.Close()
			}
			return
		}
	}
}

// StagePackets places a batch on the staged queue, evicting (and recycling)
// the oldest staged batch if the queue is full, then retrying until the send
// succeeds.
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
	for {
		select {
		case peer.queue.staged <- elems:
			return
		default:
		}
		select {
		case tooOld := <-peer.queue.staged:
			for _, elem := range tooOld.elems {
				peer.device.PutMessageBuffer(elem.buffer)
				peer.device.PutOutboundElement(elem)
			}
			peer.device.PutOutboundElementsContainer(tooOld)
		default:
		}
	}
}

// SendStagedPackets assigns nonces to staged batches (sequentially, preserving
// per-peer ordering) and hands them to the encryption and sequential-sender
// queues. If no valid keypair is available, a handshake is initiated instead.
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		// Elements whose nonce allocation overflowed the keypair limit are
		// collected here and re-staged for the next keypair (out of order).
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			// Lock is released by the encryption worker once the batch is
			// encrypted; the sequential sender blocks on it until then.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			if elemsContainerOOO != nil {
				goto top
			}
		default:
			return
		}
	}
}

// FlushStagedPackets drains the staged queue and returns all resources to the
// pools, e.g. when the peer is being stopped.
func (peer *Peer) FlushStagedPackets() {
	for {
		select {
		case elemsContainer := <-peer.queue.staged:
			for _, elem := range elemsContainer.elems {
				peer.device.PutMessageBuffer(elem.buffer)
				peer.device.PutOutboundElement(elem)
			}
			peer.device.PutOutboundElementsContainer(elemsContainer)
		default:
			return
		}
	}
}

// calculatePaddingSize returns how many zero bytes must be appended so the
// packet length becomes a multiple of PaddingMultiple, without padding past
// the MTU (an mtu of 0 means "no MTU limit").
func calculatePaddingSize(packetSize, mtu int) int {
	lastUnit := packetSize
	if mtu == 0 {
		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
	}
	if lastUnit > mtu {
		// Only the final, partial MTU-sized unit needs padding.
		lastUnit %= mtu
	}
	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
	if paddedSize > mtu {
		paddedSize = mtu
	}
	return paddedSize - lastUnit
}

/* Encrypts the elements in the queue
 * and marks them for sequential consumption (by releasing the mutex)
 *
 * Obs. One instance per core
 */
func (device *Device) RoutineEncryption(id int) {
	var paddingZeros [PaddingMultiple]byte
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
	device.log.Verbosef("Routine: encryption worker %d - started", id)

	for elemsContainer := range device.queue.encryption.c {
		for _, elem := range elemsContainer.elems {
			// populate header fields
			header := elem.buffer[:MessageTransportHeaderSize]

			fieldType := header[0:4]
			fieldReceiver := header[4:8]
			fieldNonce := header[8:16]

			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)

			// pad content to multiple of 16
			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)

			// encrypt content and release to consumer

			// ChaCha20-Poly1305 nonce: 4 zero bytes then the 64-bit counter,
			// little-endian. Sealing appends in place after the header.
			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
			elem.packet = elem.keypair.send.Seal(
				header,
				nonce[:],
				elem.packet,
				nil,
			)
		}
		// Unblocks the sequential sender waiting on this batch.
		elemsContainer.Unlock()
	}
}

// RoutineSequentialSender transmits encrypted batches for one peer in order.
// It blocks on each container's mutex until the encryption worker has
// finished (and unlocked) that batch. One instance per peer.
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
	device := peer.device
	defer func() {
		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential sender - started", peer)

	bufs := make([][]byte, 0, maxBatchSize)

	for elemsContainer := range peer.queue.outbound.c {
		bufs = bufs[:0]
		if elemsContainer == nil {
			// nil is the shutdown sentinel for this queue.
			return
		}
		if !peer.isRunning.Load() {
			// peer has been stopped; return re-usable elems to the shared pool.
			// This is an optimization only. It is possible for the peer to be stopped
			// immediately after this check, in which case, elem will get processed.
			// The timers and SendBuffers code are resilient to a few stragglers.
			// TODO: rework peer shutdown order to ensure
			// that we never accidentally keep timers alive longer than necessary.
			// NOTE(review): the container itself is dropped here (not returned via
			// PutOutboundElementsContainer) and left locked — presumably deliberate,
			// since pooling a still-locked mutex could deadlock a reuser; confirm.
			elemsContainer.Lock()
			for _, elem := range elemsContainer.elems {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
			continue
		}
		dataSent := false
		// Blocks until the encryption worker has unlocked this batch.
		elemsContainer.Lock()
		for _, elem := range elemsContainer.elems {
			if len(elem.packet) != MessageKeepaliveSize {
				dataSent = true
			}
			bufs = append(bufs, elem.packet)
		}

		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketSent()

		err := peer.SendBuffers(bufs)
		if dataSent {
			peer.timersDataSent()
		}
		for _, elem := range elemsContainer.elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
		}
		device.PutOutboundElementsContainer(elemsContainer)
		if err != nil {
			var errGSO conn.ErrUDPGSODisabled
			if errors.As(err, &errGSO) {
				device.log.Verbosef(err.Error())
				err = errGSO.RetryErr
			}
		}
		if err != nil {
			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
			continue
		}

		peer.keepKeyFreshSending()
	}
}