github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/receive.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "sync" 14 "time" 15 16 "golang.org/x/crypto/chacha20poly1305" 17 "golang.org/x/net/ipv4" 18 "golang.org/x/net/ipv6" 19 "github.com/koomox/wireguard-go/conn" 20 ) 21 22 type QueueHandshakeElement struct { 23 msgType uint32 24 packet []byte 25 endpoint conn.Endpoint 26 buffer *[MaxMessageSize]byte 27 } 28 29 type QueueInboundElement struct { 30 sync.Mutex 31 buffer *[MaxMessageSize]byte 32 packet []byte 33 counter uint64 34 keypair *Keypair 35 endpoint conn.Endpoint 36 } 37 38 // clearPointers clears elem fields that contain pointers. 39 // This makes the garbage collector's life easier and 40 // avoids accidentally keeping other objects around unnecessarily. 41 // It also reduces the possible collateral damage from use-after-free bugs. 42 func (elem *QueueInboundElement) clearPointers() { 43 elem.buffer = nil 44 elem.packet = nil 45 elem.keypair = nil 46 elem.endpoint = nil 47 } 48 49 /* Called when a new authenticated message has been received 50 * 51 * NOTE: Not thread safe, but called by sequential receiver! 52 */ 53 func (peer *Peer) keepKeyFreshReceiving() { 54 if peer.timers.sentLastMinuteHandshake.Load() { 55 return 56 } 57 keypair := peer.keypairs.Current() 58 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 59 peer.timers.sentLastMinuteHandshake.Store(true) 60 peer.SendHandshakeInitiation(false) 61 } 62 } 63 64 /* Receives incoming datagrams for the device 65 * 66 * Every time the bind is updated a new routine is started for 67 * IPv4 and IPv6 (separately) 68 */ 69 func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { 70 recvName := recv.PrettyName() 71 defer func() { 72 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) 73 device.queue.decryption.wg.Done() 74 device.queue.handshake.wg.Done() 75 device.net.stopping.Done() 76 }() 77 78 device.log.Verbosef("Routine: receive incoming %s - started", recvName) 79 80 // receive datagrams until conn is closed 81 82 var ( 83 bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) 84 bufs = make([][]byte, maxBatchSize) 85 err error 86 sizes = make([]int, maxBatchSize) 87 count int 88 endpoints = make([]conn.Endpoint, maxBatchSize) 89 deathSpiral int 90 elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) 91 ) 92 93 for i := range bufsArrs { 94 bufsArrs[i] = device.GetMessageBuffer() 95 bufs[i] = bufsArrs[i][:] 96 } 97 98 defer func() { 99 for i := 0; i < maxBatchSize; i++ { 100 if bufsArrs[i] != nil { 101 device.PutMessageBuffer(bufsArrs[i]) 102 } 103 } 104 }() 105 106 for { 107 count, err = recv(bufs, sizes, endpoints) 108 if err != nil { 109 if errors.Is(err, net.ErrClosed) { 110 return 111 } 112 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) 113 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { 114 return 115 } 116 if deathSpiral < 10 { 117 deathSpiral++ 118 time.Sleep(time.Second / 3) 119 continue 120 } 121 return 122 } 123 deathSpiral = 0 124 125 // handle each packet in the batch 126 for i, size := range sizes[:count] { 127 if size < MinMessageSize { 128 continue 129 } 130 131 // check size of packet 132 133 packet := bufsArrs[i][:size] 134 msgType := binary.LittleEndian.Uint32(packet[:4]) 135 136 switch msgType { 137 138 // check if transport 139 140 case MessageTransportType: 141 142 // check size 143 144 if len(packet) < MessageTransportSize { 145 continue 146 } 147 148 // lookup key pair 149 150 receiver := binary.LittleEndian.Uint32( 151 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 152 ) 153 value := device.indexTable.Lookup(receiver) 154 keypair := value.keypair 155 if keypair == nil { 156 continue 157 } 158 159 // check keypair expiry 160 161 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 162 continue 163 } 164 165 // create work element 166 peer := value.peer 167 elem := device.GetInboundElement() 168 elem.packet = packet 169 elem.buffer = bufsArrs[i] 170 elem.keypair = keypair 171 elem.endpoint = endpoints[i] 172 elem.counter = 0 173 elem.Mutex = sync.Mutex{} 174 elem.Lock() 175 176 elemsForPeer, ok := elemsByPeer[peer] 177 if !ok { 178 elemsForPeer = device.GetInboundElementsSlice() 179 elemsByPeer[peer] = elemsForPeer 180 } 181 *elemsForPeer = append(*elemsForPeer, elem) 182 bufsArrs[i] = device.GetMessageBuffer() 183 bufs[i] = bufsArrs[i][:] 184 continue 185 186 // otherwise it is a fixed size & handshake related packet 187 188 case MessageInitiationType: 189 if len(packet) != MessageInitiationSize { 190 continue 191 } 192 193 case MessageResponseType: 194 if len(packet) != MessageResponseSize { 195 continue 196 } 197 198 case MessageCookieReplyType: 199 if len(packet) != MessageCookieReplySize { 200 continue 201 } 202 203 default: 204 device.log.Verbosef("Received message with unknown type") 205 continue 206 } 207 208 select { 209 case device.queue.handshake.c <- QueueHandshakeElement{ 210 msgType: msgType, 211 buffer: bufsArrs[i], 212 packet: packet, 213 endpoint: endpoints[i], 214 }: 215 bufsArrs[i] = device.GetMessageBuffer() 216 bufs[i] = bufsArrs[i][:] 217 default: 218 } 219 } 220 for peer, elems := range elemsByPeer { 221 if peer.isRunning.Load() { 222 peer.queue.inbound.c <- elems 223 for _, elem := range *elems { 224 device.queue.decryption.c <- elem 225 } 226 } else { 227 for _, elem := range *elems { 228 device.PutMessageBuffer(elem.buffer) 229 device.PutInboundElement(elem) 230 } 231 device.PutInboundElementsSlice(elems) 232 } 233 delete(elemsByPeer, peer) 234 } 235 } 236 } 237 238 func (device *Device) RoutineDecryption(id int) { 239 var nonce [chacha20poly1305.NonceSize]byte 240 241 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) 242 device.log.Verbosef("Routine: decryption worker %d - started", id) 243 244 for elem := range device.queue.decryption.c { 245 // split message into fields 246 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 247 content := elem.packet[MessageTransportOffsetContent:] 248 249 // decrypt and release to consumer 250 var err error 251 elem.counter = binary.LittleEndian.Uint64(counter) 252 // copy counter to nonce 253 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 254 elem.packet, err = elem.keypair.receive.Open( 255 content[:0], 256 nonce[:], 257 content, 258 nil, 259 ) 260 if err != nil { 261 elem.packet = nil 262 } 263 elem.Unlock() 264 } 265 } 266 267 /* Handles incoming packets related to handshake 268 */ 269 func (device *Device) RoutineHandshake(id int) { 270 defer func() { 271 device.log.Verbosef("Routine: handshake worker %d - stopped", id) 272 device.queue.encryption.wg.Done() 273 }() 274 device.log.Verbosef("Routine: handshake worker %d - started", id) 275 276 for elem := range device.queue.handshake.c { 277 278 // handle cookie fields and ratelimiting 279 280 switch elem.msgType { 281 282 case MessageCookieReplyType: 283 284 // unmarshal packet 285 286 var reply MessageCookieReply 287 reader := bytes.NewReader(elem.packet) 288 err := binary.Read(reader, binary.LittleEndian, &reply) 289 if err != nil { 290 device.log.Verbosef("Failed to decode cookie reply") 291 goto skip 292 } 293 294 // lookup peer from index 295 296 entry := device.indexTable.Lookup(reply.Receiver) 297 298 if entry.peer == nil { 299 goto skip 300 } 301 302 // consume reply 303 304 if peer := entry.peer; peer.isRunning.Load() { 305 device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) 306 if !peer.cookieGenerator.ConsumeReply(&reply) { 307 device.log.Verbosef("Could not decrypt invalid cookie response") 308 } 309 } 310 311 goto skip 312 313 case MessageInitiationType, MessageResponseType: 314 315 // check mac fields and maybe ratelimit 316 317 if !device.cookieChecker.CheckMAC1(elem.packet) { 318 device.log.Verbosef("Received packet with invalid mac1") 319 goto skip 320 } 321 322 // endpoints destination address is the source of the datagram 323 324 if device.IsUnderLoad() { 325 326 // verify MAC2 field 327 328 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 329 device.SendHandshakeCookie(&elem) 330 goto skip 331 } 332 333 // check ratelimiter 334 335 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { 336 goto skip 337 } 338 } 339 340 default: 341 device.log.Errorf("Invalid packet ended up in the handshake queue") 342 goto skip 343 } 344 345 // handle handshake initiation/response content 346 347 switch elem.msgType { 348 case MessageInitiationType: 349 350 // unmarshal 351 352 var msg MessageInitiation 353 reader := bytes.NewReader(elem.packet) 354 err := binary.Read(reader, binary.LittleEndian, &msg) 355 if err != nil { 356 device.log.Errorf("Failed to decode initiation message") 357 goto skip 358 } 359 360 // consume initiation 361 362 peer := device.ConsumeMessageInitiation(&msg) 363 if peer == nil { 364 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) 365 goto skip 366 } 367 368 // update timers 369 370 peer.timersAnyAuthenticatedPacketTraversal() 371 peer.timersAnyAuthenticatedPacketReceived() 372 373 // update endpoint 374 peer.SetEndpointFromPacket(elem.endpoint) 375 376 device.log.Verbosef("%v - Received handshake initiation", peer) 377 peer.rxBytes.Add(uint64(len(elem.packet))) 378 379 peer.SendHandshakeResponse() 380 381 case MessageResponseType: 382 383 // unmarshal 384 385 var msg MessageResponse 386 reader := bytes.NewReader(elem.packet) 387 err := binary.Read(reader, binary.LittleEndian, &msg) 388 if err != nil { 389 device.log.Errorf("Failed to decode response message") 390 goto skip 391 } 392 393 // consume response 394 395 peer := device.ConsumeMessageResponse(&msg) 396 if peer == nil { 397 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) 398 goto skip 399 } 400 401 // update endpoint 402 peer.SetEndpointFromPacket(elem.endpoint) 403 404 device.log.Verbosef("%v - Received handshake response", peer) 405 peer.rxBytes.Add(uint64(len(elem.packet))) 406 407 // update timers 408 409 peer.timersAnyAuthenticatedPacketTraversal() 410 peer.timersAnyAuthenticatedPacketReceived() 411 412 // derive keypair 413 414 err = peer.BeginSymmetricSession() 415 416 if err != nil { 417 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 418 goto skip 419 } 420 421 peer.timersSessionDerived() 422 peer.timersHandshakeComplete() 423 peer.SendKeepalive() 424 } 425 skip: 426 device.PutMessageBuffer(elem.buffer) 427 } 428 } 429 430 func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { 431 device := peer.device 432 defer func() { 433 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) 434 peer.stopping.Done() 435 }() 436 device.log.Verbosef("%v - Routine: sequential receiver - started", peer) 437 438 bufs := make([][]byte, 0, maxBatchSize) 439 440 for elems := range peer.queue.inbound.c { 441 if elems == nil { 442 return 443 } 444 for _, elem := range *elems { 445 elem.Lock() 446 if elem.packet == nil { 447 // decryption failed 448 continue 449 } 450 451 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 452 continue 453 } 454 455 peer.SetEndpointFromPacket(elem.endpoint) 456 if peer.ReceivedWithKeypair(elem.keypair) { 457 peer.timersHandshakeComplete() 458 peer.SendStagedPackets() 459 } 460 peer.keepKeyFreshReceiving() 461 peer.timersAnyAuthenticatedPacketTraversal() 462 peer.timersAnyAuthenticatedPacketReceived() 463 peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) 464 465 if len(elem.packet) == 0 { 466 device.log.Verbosef("%v - Receiving keepalive packet", peer) 467 continue 468 } 469 peer.timersDataReceived() 470 471 switch elem.packet[0] >> 4 { 472 case 4: 473 if len(elem.packet) < ipv4.HeaderLen { 474 continue 475 } 476 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] 477 length := binary.BigEndian.Uint16(field) 478 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { 479 continue 480 } 481 elem.packet = elem.packet[:length] 482 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] 483 if device.allowedips.Lookup(src) != peer { 484 device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) 485 continue 486 } 487 488 case 6: 489 if len(elem.packet) < ipv6.HeaderLen { 490 continue 491 } 492 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] 493 length := binary.BigEndian.Uint16(field) 494 length += ipv6.HeaderLen 495 if int(length) > len(elem.packet) { 496 continue 497 } 498 elem.packet = elem.packet[:length] 499 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] 500 if device.allowedips.Lookup(src) != peer { 501 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) 502 continue 503 } 504 505 default: 506 device.log.Verbosef("Packet with invalid IP version from %v", peer) 507 continue 508 } 509 510 bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) 511 } 512 if len(bufs) > 0 { 513 _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) 514 if err != nil && !device.isClosed() { 515 device.log.Errorf("Failed to write packets to TUN device: %v", err) 516 } 517 } 518 for _, elem := range *elems { 519 device.PutMessageBuffer(elem.buffer) 520 device.PutInboundElement(elem) 521 } 522 bufs = bufs[:0] 523 device.PutInboundElementsSlice(elems) 524 } 525 }