github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/receive.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "sync" 14 "time" 15 16 "github.com/amnezia-vpn/amneziawg-go/conn" 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 ) 21 22 type QueueHandshakeElement struct { 23 msgType uint32 24 packet []byte 25 endpoint conn.Endpoint 26 buffer *[MaxMessageSize]byte 27 } 28 29 type QueueInboundElement struct { 30 buffer *[MaxMessageSize]byte 31 packet []byte 32 counter uint64 33 keypair *Keypair 34 endpoint conn.Endpoint 35 } 36 37 type QueueInboundElementsContainer struct { 38 sync.Mutex 39 elems []*QueueInboundElement 40 } 41 42 // clearPointers clears elem fields that contain pointers. 43 // This makes the garbage collector's life easier and 44 // avoids accidentally keeping other objects around unnecessarily. 45 // It also reduces the possible collateral damage from use-after-free bugs. 46 func (elem *QueueInboundElement) clearPointers() { 47 elem.buffer = nil 48 elem.packet = nil 49 elem.keypair = nil 50 elem.endpoint = nil 51 } 52 53 /* Called when a new authenticated message has been received 54 * 55 * NOTE: Not thread safe, but called by sequential receiver! 56 */ 57 func (peer *Peer) keepKeyFreshReceiving() { 58 if peer.timers.sentLastMinuteHandshake.Load() { 59 return 60 } 61 keypair := peer.keypairs.Current() 62 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 63 peer.timers.sentLastMinuteHandshake.Store(true) 64 peer.SendHandshakeInitiation(false) 65 } 66 } 67 68 /* Receives incoming datagrams for the device 69 * 70 * Every time the bind is updated a new routine is started for 71 * IPv4 and IPv6 (separately) 72 */ 73 func (device *Device) RoutineReceiveIncoming( 74 maxBatchSize int, 75 recv conn.ReceiveFunc, 76 ) { 77 recvName := recv.PrettyName() 78 defer func() { 79 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) 80 device.queue.decryption.wg.Done() 81 device.queue.handshake.wg.Done() 82 device.net.stopping.Done() 83 }() 84 85 device.log.Verbosef("Routine: receive incoming %s - started", recvName) 86 87 // receive datagrams until conn is closed 88 89 var ( 90 bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) 91 bufs = make([][]byte, maxBatchSize) 92 err error 93 sizes = make([]int, maxBatchSize) 94 count int 95 endpoints = make([]conn.Endpoint, maxBatchSize) 96 deathSpiral int 97 elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) 98 ) 99 100 for i := range bufsArrs { 101 bufsArrs[i] = device.GetMessageBuffer() 102 bufs[i] = bufsArrs[i][:] 103 } 104 105 defer func() { 106 for i := 0; i < maxBatchSize; i++ { 107 if bufsArrs[i] != nil { 108 device.PutMessageBuffer(bufsArrs[i]) 109 } 110 } 111 }() 112 113 for { 114 count, err = recv(bufs, sizes, endpoints) 115 if err != nil { 116 if errors.Is(err, net.ErrClosed) { 117 return 118 } 119 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) 120 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { 121 return 122 } 123 if deathSpiral < 10 { 124 deathSpiral++ 125 time.Sleep(time.Second / 3) 126 continue 127 } 128 return 129 } 130 deathSpiral = 0 131 132 device.aSecMux.RLock() 133 // handle each packet in the batch 134 for i, size := range sizes[:count] { 135 if size < MinMessageSize { 136 continue 137 } 138 139 // check size of packet 140 141 packet := bufsArrs[i][:size] 142 var msgType uint32 143 if device.isAdvancedSecurityOn() { 144 if assumedMsgType, ok := packetSizeToMsgType[size]; ok { 145 junkSize := msgTypeToJunkSize[assumedMsgType] 146 // transport size can align with other header types; 147 // making sure we have the right msgType 148 msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4]) 149 if msgType == assumedMsgType { 150 packet = packet[junkSize:] 151 } else { 152 device.log.Verbosef("Transport packet lined up with another msg type") 153 msgType = binary.LittleEndian.Uint32(packet[:4]) 154 } 155 } else { 156 msgType = binary.LittleEndian.Uint32(packet[:4]) 157 if msgType != MessageTransportType { 158 device.log.Verbosef("ASec: Received message with unknown type") 159 continue 160 } 161 } 162 } else { 163 msgType = binary.LittleEndian.Uint32(packet[:4]) 164 } 165 switch msgType { 166 167 // check if transport 168 169 case MessageTransportType: 170 171 // check size 172 173 if len(packet) < MessageTransportSize { 174 continue 175 } 176 177 // lookup key pair 178 179 receiver := binary.LittleEndian.Uint32( 180 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 181 ) 182 value := device.indexTable.Lookup(receiver) 183 keypair := value.keypair 184 if keypair == nil { 185 continue 186 } 187 188 // check keypair expiry 189 190 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 191 continue 192 } 193 194 // create work element 195 peer := value.peer 196 elem := device.GetInboundElement() 197 elem.packet = packet 198 elem.buffer = bufsArrs[i] 199 elem.keypair = keypair 200 elem.endpoint = endpoints[i] 201 elem.counter = 0 202 203 elemsForPeer, ok := elemsByPeer[peer] 204 if !ok { 205 elemsForPeer = device.GetInboundElementsContainer() 206 elemsForPeer.Lock() 207 elemsByPeer[peer] = elemsForPeer 208 } 209 elemsForPeer.elems = append(elemsForPeer.elems, elem) 210 bufsArrs[i] = device.GetMessageBuffer() 211 bufs[i] = bufsArrs[i][:] 212 continue 213 214 // otherwise it is a fixed size & handshake related packet 215 216 case MessageInitiationType: 217 if len(packet) != MessageInitiationSize { 218 continue 219 } 220 221 case MessageResponseType: 222 if len(packet) != MessageResponseSize { 223 continue 224 } 225 226 case MessageCookieReplyType: 227 if len(packet) != MessageCookieReplySize { 228 continue 229 } 230 231 default: 232 device.log.Verbosef("Received message with unknown type") 233 continue 234 } 235 236 select { 237 case device.queue.handshake.c <- QueueHandshakeElement{ 238 msgType: msgType, 239 buffer: bufsArrs[i], 240 packet: packet, 241 endpoint: endpoints[i], 242 }: 243 bufsArrs[i] = device.GetMessageBuffer() 244 bufs[i] = bufsArrs[i][:] 245 default: 246 } 247 } 248 device.aSecMux.RUnlock() 249 for peer, elemsContainer := range elemsByPeer { 250 if peer.isRunning.Load() { 251 peer.queue.inbound.c <- elemsContainer 252 device.queue.decryption.c <- elemsContainer 253 } else { 254 for _, elem := range elemsContainer.elems { 255 device.PutMessageBuffer(elem.buffer) 256 device.PutInboundElement(elem) 257 } 258 device.PutInboundElementsContainer(elemsContainer) 259 } 260 delete(elemsByPeer, peer) 261 } 262 } 263 } 264 265 func (device *Device) RoutineDecryption(id int) { 266 var nonce [chacha20poly1305.NonceSize]byte 267 268 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) 269 device.log.Verbosef("Routine: decryption worker %d - started", id) 270 271 for elemsContainer := range device.queue.decryption.c { 272 for _, elem := range elemsContainer.elems { 273 // split message into fields 274 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 275 content := elem.packet[MessageTransportOffsetContent:] 276 277 // decrypt and release to consumer 278 var err error 279 elem.counter = binary.LittleEndian.Uint64(counter) 280 // copy counter to nonce 281 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 282 elem.packet, err = elem.keypair.receive.Open( 283 content[:0], 284 nonce[:], 285 content, 286 nil, 287 ) 288 if err != nil { 289 elem.packet = nil 290 } 291 } 292 elemsContainer.Unlock() 293 } 294 } 295 296 /* Handles incoming packets related to handshake 297 */ 298 func (device *Device) RoutineHandshake(id int) { 299 defer func() { 300 device.log.Verbosef("Routine: handshake worker %d - stopped", id) 301 device.queue.encryption.wg.Done() 302 }() 303 device.log.Verbosef("Routine: handshake worker %d - started", id) 304 305 for elem := range device.queue.handshake.c { 306 307 device.aSecMux.RLock() 308 309 // handle cookie fields and ratelimiting 310 311 switch elem.msgType { 312 313 case MessageCookieReplyType: 314 315 // unmarshal packet 316 317 var reply MessageCookieReply 318 reader := bytes.NewReader(elem.packet) 319 err := binary.Read(reader, binary.LittleEndian, &reply) 320 if err != nil { 321 device.log.Verbosef("Failed to decode cookie reply") 322 goto skip 323 } 324 325 // lookup peer from index 326 327 entry := device.indexTable.Lookup(reply.Receiver) 328 329 if entry.peer == nil { 330 goto skip 331 } 332 333 // consume reply 334 335 if peer := entry.peer; peer.isRunning.Load() { 336 device.log.Verbosef( 337 "Receiving cookie response from %s", 338 elem.endpoint.DstToString(), 339 ) 340 if !peer.cookieGenerator.ConsumeReply(&reply) { 341 device.log.Verbosef( 342 "Could not decrypt invalid cookie response", 343 ) 344 } 345 } 346 347 goto skip 348 349 case MessageInitiationType, MessageResponseType: 350 351 // check mac fields and maybe ratelimit 352 353 if !device.cookieChecker.CheckMAC1(elem.packet) { 354 device.log.Verbosef("Received packet with invalid mac1") 355 goto skip 356 } 357 358 // endpoints destination address is the source of the datagram 359 360 if device.IsUnderLoad() { 361 362 // verify MAC2 field 363 364 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 365 device.SendHandshakeCookie(&elem) 366 goto skip 367 } 368 369 // check ratelimiter 370 371 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { 372 goto skip 373 } 374 } 375 376 default: 377 device.log.Errorf("Invalid packet ended up in the handshake queue") 378 goto skip 379 } 380 381 // handle handshake initiation/response content 382 383 switch elem.msgType { 384 case MessageInitiationType: 385 // unmarshal 386 var msg MessageInitiation 387 reader := bytes.NewReader(elem.packet) 388 err := binary.Read(reader, binary.LittleEndian, &msg) 389 if err != nil { 390 device.log.Errorf("Failed to decode initiation message") 391 goto skip 392 } 393 394 // consume initiation 395 peer := device.ConsumeMessageInitiation(&msg) 396 if peer == nil { 397 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) 398 goto skip 399 } 400 401 // update timers 402 403 peer.timersAnyAuthenticatedPacketTraversal() 404 peer.timersAnyAuthenticatedPacketReceived() 405 406 // update endpoint 407 peer.SetEndpointFromPacket(elem.endpoint) 408 409 device.log.Verbosef("%v - Received handshake initiation", peer) 410 peer.rxBytes.Add(uint64(len(elem.packet))) 411 412 peer.SendHandshakeResponse() 413 414 case MessageResponseType: 415 416 // unmarshal 417 418 var msg MessageResponse 419 reader := bytes.NewReader(elem.packet) 420 err := binary.Read(reader, binary.LittleEndian, &msg) 421 if err != nil { 422 device.log.Errorf("Failed to decode response message") 423 goto skip 424 } 425 426 // consume response 427 428 peer := device.ConsumeMessageResponse(&msg) 429 if peer == nil { 430 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) 431 goto skip 432 } 433 434 // update endpoint 435 peer.SetEndpointFromPacket(elem.endpoint) 436 437 device.log.Verbosef("%v - Received handshake response", peer) 438 peer.rxBytes.Add(uint64(len(elem.packet))) 439 440 // update timers 441 442 peer.timersAnyAuthenticatedPacketTraversal() 443 peer.timersAnyAuthenticatedPacketReceived() 444 445 // derive keypair 446 447 err = peer.BeginSymmetricSession() 448 449 if err != nil { 450 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 451 goto skip 452 } 453 454 peer.timersSessionDerived() 455 peer.timersHandshakeComplete() 456 peer.SendKeepalive() 457 } 458 skip: 459 device.aSecMux.RUnlock() 460 device.PutMessageBuffer(elem.buffer) 461 } 462 } 463 464 func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { 465 device := peer.device 466 defer func() { 467 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) 468 peer.stopping.Done() 469 }() 470 device.log.Verbosef("%v - Routine: sequential receiver - started", peer) 471 472 bufs := make([][]byte, 0, maxBatchSize) 473 474 for elemsContainer := range peer.queue.inbound.c { 475 if elemsContainer == nil { 476 return 477 } 478 elemsContainer.Lock() 479 validTailPacket := -1 480 dataPacketReceived := false 481 rxBytesLen := uint64(0) 482 for i, elem := range elemsContainer.elems { 483 if elem.packet == nil { 484 // decryption failed 485 continue 486 } 487 488 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 489 continue 490 } 491 492 validTailPacket = i 493 if peer.ReceivedWithKeypair(elem.keypair) { 494 peer.SetEndpointFromPacket(elem.endpoint) 495 peer.timersHandshakeComplete() 496 peer.SendStagedPackets() 497 } 498 rxBytesLen += uint64(len(elem.packet) + MinMessageSize) 499 500 if len(elem.packet) == 0 { 501 device.log.Verbosef("%v - Receiving keepalive packet", peer) 502 continue 503 } 504 dataPacketReceived = true 505 506 switch elem.packet[0] >> 4 { 507 case 4: 508 if len(elem.packet) < ipv4.HeaderLen { 509 continue 510 } 511 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] 512 length := binary.BigEndian.Uint16(field) 513 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { 514 continue 515 } 516 elem.packet = elem.packet[:length] 517 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] 518 if device.allowedips.Lookup(src) != peer { 519 device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) 520 continue 521 } 522 523 case 6: 524 if len(elem.packet) < ipv6.HeaderLen { 525 continue 526 } 527 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] 528 length := binary.BigEndian.Uint16(field) 529 length += ipv6.HeaderLen 530 if int(length) > len(elem.packet) { 531 continue 532 } 533 elem.packet = elem.packet[:length] 534 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] 535 if device.allowedips.Lookup(src) != peer { 536 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) 537 continue 538 } 539 540 default: 541 device.log.Verbosef( 542 "Packet with invalid IP version from %v", 543 peer, 544 ) 545 continue 546 } 547 548 bufs = append( 549 bufs, 550 elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], 551 ) 552 } 553 554 peer.rxBytes.Add(rxBytesLen) 555 if validTailPacket >= 0 { 556 peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) 557 peer.keepKeyFreshReceiving() 558 peer.timersAnyAuthenticatedPacketTraversal() 559 peer.timersAnyAuthenticatedPacketReceived() 560 } 561 if dataPacketReceived { 562 peer.timersDataReceived() 563 } 564 if len(bufs) > 0 { 565 _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) 566 if err != nil && !device.isClosed() { 567 device.log.Errorf("Failed to write packets to TUN device: %v", err) 568 } 569 } 570 for _, elem := range elemsContainer.elems { 571 device.PutMessageBuffer(elem.buffer) 572 device.PutInboundElement(elem) 573 } 574 bufs = bufs[:0] 575 device.PutInboundElementsContainer(elemsContainer) 576 } 577 }