github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/receive.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "sync" 14 "sync/atomic" 15 "time" 16 "github.com/aead/chacha20" 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 "github.com/cawidtu/notwireguard-go/conn" 21 ) 22 23 type QueueHandshakeElement struct { 24 msgType uint32 25 packet []byte 26 endpoint conn.Endpoint 27 buffer *[MaxMessageSize]byte 28 } 29 30 type QueueInboundElement struct { 31 sync.Mutex 32 buffer *[MaxMessageSize]byte 33 packet []byte 34 counter uint64 35 keypair *Keypair 36 endpoint conn.Endpoint 37 } 38 39 // clearPointers clears elem fields that contain pointers. 40 // This makes the garbage collector's life easier and 41 // avoids accidentally keeping other objects around unnecessarily. 42 // It also reduces the possible collateral damage from use-after-free bugs. 43 func (elem *QueueInboundElement) clearPointers() { 44 elem.buffer = nil 45 elem.packet = nil 46 elem.keypair = nil 47 elem.endpoint = nil 48 } 49 50 func wgDeobfuscatePacket(obfuscator [NoisePublicKeySize]byte, buf []byte, len int) { 51 52 decryptLength := min(len&0xFFFC, NoiseObfuscateLenMax) - 4 // sizeof(u32) 53 54 // Initialize the cipher with the obfuscator as the key and the nonce derived from the packet contents. 55 56 var nonce [8]byte; 57 58 copy(nonce[0:4], buf[decryptLength : decryptLength + 4]) 59 60 state, err := chacha20.NewCipher(nonce[:],obfuscator[:]) 61 if err != nil { 62 panic(err) 63 } 64 65 state.XORKeyStream(buf[:decryptLength], buf[:decryptLength]) 66 } 67 68 69 /* Called when a new authenticated message has been received 70 * 71 * NOTE: Not thread safe, but called by sequential receiver! 72 */ 73 func (peer *Peer) keepKeyFreshReceiving() { 74 if peer.timers.sentLastMinuteHandshake.Get() { 75 return 76 } 77 keypair := peer.keypairs.Current() 78 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 79 peer.timers.sentLastMinuteHandshake.Set(true) 80 peer.SendHandshakeInitiation(false) 81 } 82 } 83 84 /* Receives incoming datagrams for the device 85 * 86 * Every time the bind is updated a new routine is started for 87 * IPv4 and IPv6 (separately) 88 */ 89 func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { 90 recvName := recv.PrettyName() 91 defer func() { 92 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) 93 device.queue.decryption.wg.Done() 94 device.queue.handshake.wg.Done() 95 device.net.stopping.Done() 96 }() 97 98 device.log.Verbosef("Routine: receive incoming %s - started", recvName) 99 100 // receive datagrams until conn is closed 101 102 buffer := device.GetMessageBuffer() 103 104 var ( 105 err error 106 size int 107 endpoint conn.Endpoint 108 deathSpiral int 109 ) 110 111 for { 112 size, endpoint, err = recv(buffer[:]) 113 114 if err != nil { 115 device.PutMessageBuffer(buffer) 116 if errors.Is(err, net.ErrClosed) { 117 return 118 } 119 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) 120 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { 121 return 122 } 123 if deathSpiral < 10 { 124 deathSpiral++ 125 time.Sleep(time.Second / 3) 126 buffer = device.GetMessageBuffer() 127 continue 128 } 129 return 130 } 131 deathSpiral = 0 132 133 if size < MinMessageSize { 134 continue 135 } 136 137 // deobfuscate packet 138 139 wgDeobfuscatePacket(device.staticIdentity.obfuscator, buffer[:], size) 140 packet := buffer[:size] 141 msgType := binary.LittleEndian.Uint32(packet[:4]) 142 143 var okay bool 144 var trim bool 145 var headerLen int 146 147 switch msgType { 148 149 // check if transport 150 151 case MessageTransportType: 152 153 // check size 154 155 if len(packet) < MessageTransportSize { 156 continue 157 } 158 159 // lookup key pair 160 161 receiver := binary.LittleEndian.Uint32( 162 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 163 ) 164 value := device.indexTable.Lookup(receiver) 165 keypair := value.keypair 166 if keypair == nil { 167 continue 168 } 169 170 // check keypair expiry 171 172 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 173 continue 174 } 175 176 // create work element 177 peer := value.peer 178 elem := device.GetInboundElement() 179 elem.packet = packet 180 elem.buffer = buffer 181 elem.keypair = keypair 182 elem.endpoint = endpoint 183 elem.counter = 0 184 elem.Mutex = sync.Mutex{} 185 elem.Lock() 186 187 // add to decryption queues 188 if peer.isRunning.Get() { 189 peer.queue.inbound.c <- elem 190 device.queue.decryption.c <- elem 191 buffer = device.GetMessageBuffer() 192 } else { 193 device.PutInboundElement(elem) 194 } 195 continue 196 197 // otherwise it is a "fixed size" & handshake related packet 198 199 // check minimum length instead of fixed length for obfuscation 200 201 case MessageInitiationType: 202 okay = len(packet) >= MessageInitiationSize && len(packet) <= NoiseObfuscateLenMax 203 trim = len(packet) > MessageInitiationSize 204 headerLen = MessageInitiationSize 205 206 case MessageResponseType: 207 okay = len(packet) >= MessageResponseSize && len(packet) <= NoiseObfuscateLenMax 208 trim = len(packet) > MessageResponseSize 209 headerLen = MessageResponseSize 210 211 case MessageCookieReplyType: 212 okay = len(packet) >= MessageCookieReplySize && len(packet) <= NoiseObfuscateLenMax 213 trim = len(packet) > MessageCookieReplySize 214 headerLen = MessageCookieReplySize 215 216 default: 217 device.log.Verbosef("Received message with unknown type") 218 } 219 220 // remove the junk added in obfuscation, if necessary 221 if trim { 222 packet = packet[:headerLen] 223 } 224 225 226 227 if okay { 228 select { 229 case device.queue.handshake.c <- QueueHandshakeElement{ 230 msgType: msgType, 231 buffer: buffer, 232 packet: packet, 233 endpoint: endpoint, 234 }: 235 buffer = device.GetMessageBuffer() 236 default: 237 } 238 } 239 } 240 } 241 242 func (device *Device) RoutineDecryption(id int) { 243 var nonce [chacha20poly1305.NonceSize]byte 244 245 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) 246 device.log.Verbosef("Routine: decryption worker %d - started", id) 247 248 for elem := range device.queue.decryption.c { 249 // split message into fields 250 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 251 content := elem.packet[MessageTransportOffsetContent:] 252 253 // decrypt and release to consumer 254 var err error 255 elem.counter = binary.LittleEndian.Uint64(counter) 256 // copy counter to nonce 257 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 258 elem.packet, err = elem.keypair.receive.Open( 259 content[:0], 260 nonce[:], 261 content, 262 nil, 263 ) 264 if err != nil { 265 elem.packet = nil 266 } 267 elem.Unlock() 268 } 269 } 270 271 /* Handles incoming packets related to handshake 272 */ 273 func (device *Device) RoutineHandshake(id int) { 274 defer func() { 275 device.log.Verbosef("Routine: handshake worker %d - stopped", id) 276 device.queue.encryption.wg.Done() 277 }() 278 device.log.Verbosef("Routine: handshake worker %d - started", id) 279 280 for elem := range device.queue.handshake.c { 281 282 // handle cookie fields and ratelimiting 283 284 switch elem.msgType { 285 286 case MessageCookieReplyType: 287 288 // unmarshal packet 289 290 var reply MessageCookieReply 291 reader := bytes.NewReader(elem.packet) 292 err := binary.Read(reader, binary.LittleEndian, &reply) 293 if err != nil { 294 device.log.Verbosef("Failed to decode cookie reply") 295 goto skip 296 } 297 298 // lookup peer from index 299 300 entry := device.indexTable.Lookup(reply.Receiver) 301 302 if entry.peer == nil { 303 goto skip 304 } 305 306 // consume reply 307 308 if peer := entry.peer; peer.isRunning.Get() { 309 device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) 310 if !peer.cookieGenerator.ConsumeReply(&reply) { 311 device.log.Verbosef("Could not decrypt invalid cookie response") 312 } 313 } 314 315 goto skip 316 317 case MessageInitiationType, MessageResponseType: 318 319 // check mac fields and maybe ratelimit 320 321 if !device.cookieChecker.CheckMAC1(elem.packet) { 322 device.log.Verbosef("Received packet with invalid mac1") 323 goto skip 324 } 325 326 // endpoints destination address is the source of the datagram 327 328 if device.IsUnderLoad() { 329 330 // verify MAC2 field 331 332 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 333 // extract the obfuscator key from the packet received 334 var obfuscator [32]byte 335 copy(obfuscator[:], elem.packet[8:40]) 336 device.SendHandshakeCookie(&elem, obfuscator) 337 goto skip 338 } 339 340 // check ratelimiter 341 342 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { 343 goto skip 344 } 345 } 346 347 default: 348 device.log.Errorf("Invalid packet ended up in the handshake queue") 349 goto skip 350 } 351 352 // handle handshake initiation/response content 353 354 switch elem.msgType { 355 case MessageInitiationType: 356 357 // unmarshal 358 359 var msg MessageInitiation 360 reader := bytes.NewReader(elem.packet) 361 err := binary.Read(reader, binary.LittleEndian, &msg) 362 if err != nil { 363 device.log.Errorf("Failed to decode initiation message") 364 goto skip 365 } 366 367 // consume initiation 368 369 peer := device.ConsumeMessageInitiation(&msg) 370 if peer == nil { 371 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) 372 goto skip 373 } 374 375 // update timers 376 377 peer.timersAnyAuthenticatedPacketTraversal() 378 peer.timersAnyAuthenticatedPacketReceived() 379 380 // update endpoint 381 peer.SetEndpointFromPacket(elem.endpoint) 382 383 device.log.Verbosef("%v - Received handshake initiation", peer) 384 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) 385 386 peer.SendHandshakeResponse() 387 388 case MessageResponseType: 389 390 // unmarshal 391 392 var msg MessageResponse 393 reader := bytes.NewReader(elem.packet) 394 err := binary.Read(reader, binary.LittleEndian, &msg) 395 if err != nil { 396 device.log.Errorf("Failed to decode response message") 397 goto skip 398 } 399 400 // consume response 401 402 peer := device.ConsumeMessageResponse(&msg) 403 if peer == nil { 404 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) 405 goto skip 406 } 407 408 // update endpoint 409 peer.SetEndpointFromPacket(elem.endpoint) 410 411 device.log.Verbosef("%v - Received handshake response", peer) 412 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) 413 414 // update timers 415 416 peer.timersAnyAuthenticatedPacketTraversal() 417 peer.timersAnyAuthenticatedPacketReceived() 418 419 // derive keypair 420 421 err = peer.BeginSymmetricSession() 422 423 if err != nil { 424 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 425 goto skip 426 } 427 428 peer.timersSessionDerived() 429 peer.timersHandshakeComplete() 430 peer.SendKeepalive() 431 } 432 skip: 433 device.PutMessageBuffer(elem.buffer) 434 } 435 } 436 437 func (peer *Peer) RoutineSequentialReceiver() { 438 device := peer.device 439 defer func() { 440 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) 441 peer.stopping.Done() 442 }() 443 device.log.Verbosef("%v - Routine: sequential receiver - started", peer) 444 445 for elem := range peer.queue.inbound.c { 446 if elem == nil { 447 return 448 } 449 var err error 450 elem.Lock() 451 if elem.packet == nil { 452 // decryption failed 453 goto skip 454 } 455 456 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 457 goto skip 458 } 459 460 peer.SetEndpointFromPacket(elem.endpoint) 461 if peer.ReceivedWithKeypair(elem.keypair) { 462 peer.timersHandshakeComplete() 463 peer.SendStagedPackets() 464 } 465 466 peer.keepKeyFreshReceiving() 467 peer.timersAnyAuthenticatedPacketTraversal() 468 peer.timersAnyAuthenticatedPacketReceived() 469 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) 470 471 if len(elem.packet) == 0 { 472 device.log.Verbosef("%v - Receiving keepalive packet", peer) 473 goto skip 474 } 475 peer.timersDataReceived() 476 477 switch elem.packet[0] >> 4 { 478 case ipv4.Version: 479 if len(elem.packet) < ipv4.HeaderLen { 480 goto skip 481 } 482 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] 483 length := binary.BigEndian.Uint16(field) 484 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { 485 goto skip 486 } 487 elem.packet = elem.packet[:length] 488 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] 489 if device.allowedips.Lookup(src) != peer { 490 device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) 491 goto skip 492 } 493 494 case ipv6.Version: 495 if len(elem.packet) < ipv6.HeaderLen { 496 goto skip 497 } 498 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] 499 length := binary.BigEndian.Uint16(field) 500 length += ipv6.HeaderLen 501 if int(length) > len(elem.packet) { 502 goto skip 503 } 504 elem.packet = elem.packet[:length] 505 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] 506 if device.allowedips.Lookup(src) != peer { 507 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) 508 goto skip 509 } 510 511 default: 512 device.log.Verbosef("Packet with invalid IP version from %v", peer) 513 goto skip 514 } 515 516 _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) 517 if err != nil && !device.isClosed() { 518 device.log.Errorf("Failed to write packet to TUN device: %v", err) 519 } 520 if len(peer.queue.inbound.c) == 0 { 521 err = device.tun.device.Flush() 522 if err != nil { 523 peer.device.log.Errorf("Unable to flush packets: %v", err) 524 } 525 } 526 skip: 527 device.PutMessageBuffer(elem.buffer) 528 device.PutInboundElement(elem) 529 } 530 }