gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/receive.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "sync" 14 "time" 15 16 "gitee.com/aurawing/surguard-go/conn" 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 ) 21 22 type QueueHandshakeElement struct { 23 msgType uint32 24 packet []byte 25 endpoint conn.Endpoint 26 buffer *[MaxMessageSize]byte 27 } 28 29 type QueueInboundElement struct { 30 buffer *[MaxMessageSize]byte 31 packet []byte 32 counter uint64 33 keypair *Keypair 34 endpoint conn.Endpoint 35 } 36 37 type QueueInboundElementsContainer struct { 38 sync.Mutex 39 elems []*QueueInboundElement 40 } 41 42 // clearPointers clears elem fields that contain pointers. 43 // This makes the garbage collector's life easier and 44 // avoids accidentally keeping other objects around unnecessarily. 45 // It also reduces the possible collateral damage from use-after-free bugs. 46 func (elem *QueueInboundElement) clearPointers() { 47 elem.buffer = nil 48 elem.packet = nil 49 elem.keypair = nil 50 elem.endpoint = nil 51 } 52 53 /* Called when a new authenticated message has been received 54 * 55 * NOTE: Not thread safe, but called by sequential receiver! 56 */ 57 func (peer *Peer) keepKeyFreshReceiving() { 58 if peer.timers.sentLastMinuteHandshake.Load() { 59 return 60 } 61 keypair := peer.keypairs.Current() 62 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 63 peer.timers.sentLastMinuteHandshake.Store(true) 64 peer.SendHandshakeInitiation(false) 65 } 66 } 67 68 /* Receives incoming datagrams for the device 69 * 70 * Every time the bind is updated a new routine is started for 71 * IPv4 and IPv6 (separately) 72 */ 73 func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { 74 recvName := recv.PrettyName() 75 defer func() { 76 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) 77 device.queue.decryption.wg.Done() 78 device.queue.handshake.wg.Done() 79 device.net.stopping.Done() 80 }() 81 82 device.log.Verbosef("Routine: receive incoming %s - started", recvName) 83 84 // receive datagrams until conn is closed 85 86 var ( 87 bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) 88 bufs = make([][]byte, maxBatchSize) 89 err error 90 sizes = make([]int, maxBatchSize) 91 count int 92 endpoints = make([]conn.Endpoint, maxBatchSize) 93 deathSpiral int 94 elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) 95 ) 96 97 for i := range bufsArrs { 98 bufsArrs[i] = device.GetMessageBuffer() 99 bufs[i] = bufsArrs[i][:] 100 } 101 102 defer func() { 103 for i := 0; i < maxBatchSize; i++ { 104 if bufsArrs[i] != nil { 105 device.PutMessageBuffer(bufsArrs[i]) 106 } 107 } 108 }() 109 110 for { 111 count, err = recv(bufs, sizes, endpoints) 112 if err != nil { 113 if errors.Is(err, net.ErrClosed) { 114 return 115 } 116 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) 117 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { 118 return 119 } 120 if deathSpiral < 10 { 121 deathSpiral++ 122 time.Sleep(time.Second / 3) 123 continue 124 } 125 return 126 } 127 deathSpiral = 0 128 129 // handle each packet in the batch 130 for i, size := range sizes[:count] { 131 if size < MinMessageSize { 132 continue 133 } 134 135 // check size of packet 136 137 packet := bufsArrs[i][:size] 138 msgType := binary.LittleEndian.Uint32(packet[:4]) 139 140 switch msgType { 141 142 // check if transport 143 144 case MessageTransportType: 145 146 // check size 147 148 if len(packet) < MessageTransportSize { 149 continue 150 } 151 152 // lookup key pair 153 154 receiver := binary.LittleEndian.Uint32( 155 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 156 ) 157 value := device.indexTable.Lookup(receiver) 158 keypair := value.keypair 159 if keypair == nil { 160 continue 161 } 162 163 // check keypair expiry 164 165 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 166 continue 167 } 168 169 // create work element 170 peer := value.peer 171 elem := device.GetInboundElement() 172 elem.packet = packet 173 elem.buffer = bufsArrs[i] 174 elem.keypair = keypair 175 elem.endpoint = endpoints[i] 176 elem.counter = 0 177 178 elemsForPeer, ok := elemsByPeer[peer] 179 if !ok { 180 elemsForPeer = device.GetInboundElementsContainer() 181 elemsForPeer.Lock() 182 elemsByPeer[peer] = elemsForPeer 183 } 184 elemsForPeer.elems = append(elemsForPeer.elems, elem) 185 bufsArrs[i] = device.GetMessageBuffer() 186 bufs[i] = bufsArrs[i][:] 187 continue 188 189 // otherwise it is a fixed size & handshake related packet 190 191 case MessageInitiationType: 192 if len(packet) != MessageInitiationSize { 193 continue 194 } 195 196 case MessageResponseType: 197 if len(packet) != MessageResponseSize { 198 continue 199 } 200 201 case MessageCookieReplyType: 202 if len(packet) != MessageCookieReplySize { 203 continue 204 } 205 206 default: 207 device.log.Verbosef("Received message with unknown type") 208 continue 209 } 210 211 select { 212 case device.queue.handshake.c <- QueueHandshakeElement{ 213 msgType: msgType, 214 buffer: bufsArrs[i], 215 packet: packet, 216 endpoint: endpoints[i], 217 }: 218 bufsArrs[i] = device.GetMessageBuffer() 219 bufs[i] = bufsArrs[i][:] 220 default: 221 } 222 } 223 for peer, elemsContainer := range elemsByPeer { 224 if peer.isRunning.Load() { 225 peer.queue.inbound.c <- elemsContainer 226 device.queue.decryption.c <- elemsContainer 227 } else { 228 for _, elem := range elemsContainer.elems { 229 device.PutMessageBuffer(elem.buffer) 230 device.PutInboundElement(elem) 231 } 232 device.PutInboundElementsContainer(elemsContainer) 233 } 234 delete(elemsByPeer, peer) 235 } 236 } 237 } 238 239 func (device *Device) RoutineDecryption(id int) { 240 var nonce [chacha20poly1305.NonceSize]byte 241 242 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) 243 device.log.Verbosef("Routine: decryption worker %d - started", id) 244 245 for elemsContainer := range device.queue.decryption.c { 246 for _, elem := range elemsContainer.elems { 247 // split message into fields 248 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 249 content := elem.packet[MessageTransportOffsetContent:] 250 251 // decrypt and release to consumer 252 var err error 253 elem.counter = binary.LittleEndian.Uint64(counter) 254 // copy counter to nonce 255 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 256 elem.packet, err = elem.keypair.receive.Open( 257 content[:0], 258 nonce[:], 259 content, 260 nil, 261 ) 262 if err != nil { 263 elem.packet = nil 264 } 265 } 266 elemsContainer.Unlock() 267 } 268 } 269 270 /* Handles incoming packets related to handshake 271 */ 272 func (device *Device) RoutineHandshake(id int) { 273 defer func() { 274 device.log.Verbosef("Routine: handshake worker %d - stopped", id) 275 device.queue.encryption.wg.Done() 276 }() 277 device.log.Verbosef("Routine: handshake worker %d - started", id) 278 279 for elem := range device.queue.handshake.c { 280 281 // handle cookie fields and ratelimiting 282 283 switch elem.msgType { 284 285 case MessageCookieReplyType: 286 287 // unmarshal packet 288 289 var reply MessageCookieReply 290 reader := bytes.NewReader(elem.packet) 291 err := binary.Read(reader, binary.LittleEndian, &reply) 292 if err != nil { 293 device.log.Verbosef("Failed to decode cookie reply") 294 goto skip 295 } 296 297 // lookup peer from index 298 299 entry := device.indexTable.Lookup(reply.Receiver) 300 301 if entry.peer == nil { 302 goto skip 303 } 304 305 // consume reply 306 307 if peer := entry.peer; peer.isRunning.Load() { 308 device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) 309 if !peer.cookieGenerator.ConsumeReply(&reply) { 310 device.log.Verbosef("Could not decrypt invalid cookie response") 311 } 312 } 313 314 goto skip 315 316 case MessageInitiationType, MessageResponseType: 317 318 // check mac fields and maybe ratelimit 319 320 if !device.cookieChecker.CheckMAC1(elem.packet) { 321 device.log.Verbosef("Received packet with invalid mac1") 322 goto skip 323 } 324 325 // endpoints destination address is the source of the datagram 326 327 if device.IsUnderLoad() { 328 329 // verify MAC2 field 330 331 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 332 device.SendHandshakeCookie(&elem) 333 goto skip 334 } 335 336 // check ratelimiter 337 338 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { 339 goto skip 340 } 341 } 342 343 default: 344 device.log.Errorf("Invalid packet ended up in the handshake queue") 345 goto skip 346 } 347 348 // handle handshake initiation/response content 349 350 switch elem.msgType { 351 case MessageInitiationType: 352 353 // unmarshal 354 355 var msg MessageInitiation 356 reader := bytes.NewReader(elem.packet) 357 err := binary.Read(reader, binary.LittleEndian, &msg) 358 if err != nil { 359 device.log.Errorf("Failed to decode initiation message") 360 goto skip 361 } 362 363 // consume initiation 364 365 peer := device.ConsumeMessageInitiation(&msg) 366 if peer == nil { 367 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) 368 goto skip 369 } 370 371 // update timers 372 373 peer.timersAnyAuthenticatedPacketTraversal() 374 peer.timersAnyAuthenticatedPacketReceived() 375 376 // update endpoint 377 peer.SetEndpointFromPacket(elem.endpoint) 378 379 device.log.Verbosef("%v - Received handshake initiation", peer) 380 peer.rxBytes.Add(uint64(len(elem.packet))) 381 382 peer.SendHandshakeResponse() 383 384 case MessageResponseType: 385 386 // unmarshal 387 388 var msg MessageResponse 389 reader := bytes.NewReader(elem.packet) 390 err := binary.Read(reader, binary.LittleEndian, &msg) 391 if err != nil { 392 device.log.Errorf("Failed to decode response message") 393 goto skip 394 } 395 396 // consume response 397 398 peer := device.ConsumeMessageResponse(&msg) 399 if peer == nil { 400 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) 401 goto skip 402 } 403 404 // update endpoint 405 peer.SetEndpointFromPacket(elem.endpoint) 406 407 device.log.Verbosef("%v - Received handshake response", peer) 408 peer.rxBytes.Add(uint64(len(elem.packet))) 409 410 // update timers 411 412 peer.timersAnyAuthenticatedPacketTraversal() 413 peer.timersAnyAuthenticatedPacketReceived() 414 415 // derive keypair 416 417 err = peer.BeginSymmetricSession() 418 419 if err != nil { 420 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 421 goto skip 422 } 423 424 peer.timersSessionDerived() 425 peer.timersHandshakeComplete() 426 peer.SendKeepalive() 427 } 428 skip: 429 device.PutMessageBuffer(elem.buffer) 430 } 431 } 432 433 func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { 434 device := peer.device 435 defer func() { 436 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) 437 peer.stopping.Done() 438 }() 439 device.log.Verbosef("%v - Routine: sequential receiver - started", peer) 440 441 bufs := make([][]byte, 0, maxBatchSize) 442 443 for elemsContainer := range peer.queue.inbound.c { 444 if elemsContainer == nil { 445 return 446 } 447 elemsContainer.Lock() 448 validTailPacket := -1 449 dataPacketReceived := false 450 rxBytesLen := uint64(0) 451 for i, elem := range elemsContainer.elems { 452 if elem.packet == nil { 453 // decryption failed 454 continue 455 } 456 457 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 458 continue 459 } 460 461 validTailPacket = i 462 if peer.ReceivedWithKeypair(elem.keypair) { 463 peer.SetEndpointFromPacket(elem.endpoint) 464 peer.timersHandshakeComplete() 465 peer.SendStagedPackets() 466 } 467 rxBytesLen += uint64(len(elem.packet) + MinMessageSize) 468 469 if len(elem.packet) == 0 { 470 device.log.Verbosef("%v - Receiving keepalive packet", peer) 471 continue 472 } 473 dataPacketReceived = true 474 475 switch elem.packet[0] >> 4 { 476 case 4: 477 if len(elem.packet) < ipv4.HeaderLen { 478 continue 479 } 480 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] 481 length := binary.BigEndian.Uint16(field) 482 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { 483 continue 484 } 485 elem.packet = elem.packet[:length] 486 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] 487 if device.allowedips.Lookup(src) != peer { 488 device.log.Verbosef("IPv4 packet with disallowed source address from %v: %d.%d.%d.%d", peer, src[0], src[1], src[2], src[3]) 489 continue 490 } 491 492 dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] 493 if ifBindInterface && zkCli != nil && interfaceIP != "" { //&& src[0] == 169 && src[1] == 254 && src[2] == interfaceIndex { 494 zkCli.RLock() 495 for i, ipaddr := range interfaceIPArr { 496 if bytes.Equal(ipaddr[:], dst) { 497 dst[0] = 169 498 dst[1] = 254 499 dst[2] = interfaceIndex 500 dst[3] = byte(i + 1) 501 break 502 } 503 } 504 elem.packet[10] = 0 505 elem.packet[11] = 0 506 checksum := IPv4CheckSum(elem.packet[0:20]) 507 binary.BigEndian.PutUint16(elem.packet[10:12], checksum) 508 zkCli.RUnlock() 509 protoType := elem.packet[IPv4offsetProtoType] 510 switch protoType { 511 case 6: 512 TCPv4CheckSum(elem.packet) 513 case 17: 514 UDPv4CheckSum(elem.packet) 515 } 516 } 517 518 case 6: 519 if len(elem.packet) < ipv6.HeaderLen { 520 continue 521 } 522 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] 523 length := binary.BigEndian.Uint16(field) 524 length += ipv6.HeaderLen 525 if int(length) > len(elem.packet) { 526 continue 527 } 528 elem.packet = elem.packet[:length] 529 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] 530 if device.allowedips.Lookup(src) != peer { 531 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) 532 continue 533 } 534 535 default: 536 device.log.Verbosef("Packet with invalid IP version from %v", peer) 537 continue 538 } 539 540 bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) 541 } 542 543 peer.rxBytes.Add(rxBytesLen) 544 if validTailPacket >= 0 { 545 peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) 546 peer.keepKeyFreshReceiving() 547 peer.timersAnyAuthenticatedPacketTraversal() 548 peer.timersAnyAuthenticatedPacketReceived() 549 } 550 if dataPacketReceived { 551 peer.timersDataReceived() 552 } 553 if len(bufs) > 0 { 554 _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) 555 if err != nil && !device.isClosed() { 556 device.log.Errorf("Failed to write packets to TUN device: %v", err) 557 } 558 } 559 for _, elem := range elemsContainer.elems { 560 device.PutMessageBuffer(elem.buffer) 561 device.PutInboundElement(elem) 562 } 563 bufs = bufs[:0] 564 device.PutInboundElementsContainer(elemsContainer) 565 } 566 }