github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/device/receive.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "sync" 14 "sync/atomic" 15 "time" 16 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 21 "github.com/bugfan/wireguard-go/conn" 22 ) 23 24 type QueueHandshakeElement struct { 25 msgType uint32 26 packet []byte 27 endpoint conn.Endpoint 28 buffer *[MaxMessageSize]byte 29 } 30 31 type QueueInboundElement struct { 32 sync.Mutex 33 buffer *[MaxMessageSize]byte 34 packet []byte 35 counter uint64 36 keypair *Keypair 37 endpoint conn.Endpoint 38 } 39 40 // clearPointers clears elem fields that contain pointers. 41 // This makes the garbage collector's life easier and 42 // avoids accidentally keeping other objects around unnecessarily. 43 // It also reduces the possible collateral damage from use-after-free bugs. 44 func (elem *QueueInboundElement) clearPointers() { 45 elem.buffer = nil 46 elem.packet = nil 47 elem.keypair = nil 48 elem.endpoint = nil 49 } 50 51 /* Called when a new authenticated message has been received 52 * 53 * NOTE: Not thread safe, but called by sequential receiver! 54 */ 55 func (peer *Peer) keepKeyFreshReceiving() { 56 if peer.timers.sentLastMinuteHandshake.Get() { 57 return 58 } 59 keypair := peer.keypairs.Current() 60 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 61 peer.timers.sentLastMinuteHandshake.Set(true) 62 peer.SendHandshakeInitiation(false) 63 } 64 } 65 66 /* Receives incoming datagrams for the device 67 * 68 * Every time the bind is updated a new routine is started for 69 * IPv4 and IPv6 (separately) 70 */ 71 func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { 72 recvName := recv.PrettyName() 73 defer func() { 74 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) 75 device.queue.decryption.wg.Done() 76 device.queue.handshake.wg.Done() 77 device.net.stopping.Done() 78 }() 79 80 device.log.Verbosef("Routine: receive incoming %s - started", recvName) 81 82 // receive datagrams until conn is closed 83 84 buffer := device.GetMessageBuffer() 85 86 var ( 87 err error 88 size int 89 endpoint conn.Endpoint 90 deathSpiral int 91 ) 92 93 for { 94 size, endpoint, err = recv(buffer[:]) 95 96 if err != nil { 97 device.PutMessageBuffer(buffer) 98 if errors.Is(err, net.ErrClosed) { 99 return 100 } 101 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) 102 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { 103 return 104 } 105 if deathSpiral < 10 { 106 deathSpiral++ 107 time.Sleep(time.Second / 3) 108 buffer = device.GetMessageBuffer() 109 continue 110 } 111 return 112 } 113 deathSpiral = 0 114 115 if size < MinMessageSize { 116 continue 117 } 118 119 // check size of packet 120 121 packet := buffer[:size] 122 msgType := binary.LittleEndian.Uint32(packet[:4]) 123 124 var okay bool 125 126 switch msgType { 127 128 // check if transport 129 130 case MessageTransportType: 131 132 // check size 133 134 if len(packet) < MessageTransportSize { 135 continue 136 } 137 138 // lookup key pair 139 140 receiver := binary.LittleEndian.Uint32( 141 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 142 ) 143 value := device.indexTable.Lookup(receiver) 144 keypair := value.keypair 145 if keypair == nil { 146 continue 147 } 148 149 // check keypair expiry 150 151 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 152 continue 153 } 154 155 // create work element 156 peer := value.peer 157 elem := device.GetInboundElement() 158 elem.packet = packet 159 elem.buffer = buffer 160 elem.keypair = keypair 161 elem.endpoint = endpoint 162 elem.counter = 0 163 elem.Mutex = sync.Mutex{} 164 elem.Lock() 165 166 // add to decryption queues 167 if peer.isRunning.Get() { 168 peer.queue.inbound.c <- elem 169 device.queue.decryption.c <- elem 170 buffer = device.GetMessageBuffer() 171 } else { 172 device.PutInboundElement(elem) 173 } 174 continue 175 176 // otherwise it is a fixed size & handshake related packet 177 178 case MessageInitiationType: 179 okay = len(packet) == MessageInitiationSize 180 181 case MessageResponseType: 182 okay = len(packet) == MessageResponseSize 183 184 case MessageCookieReplyType: 185 okay = len(packet) == MessageCookieReplySize 186 187 default: 188 device.log.Verbosef("Received message with unknown type") 189 } 190 191 if okay { 192 select { 193 case device.queue.handshake.c <- QueueHandshakeElement{ 194 msgType: msgType, 195 buffer: buffer, 196 packet: packet, 197 endpoint: endpoint, 198 }: 199 buffer = device.GetMessageBuffer() 200 default: 201 } 202 } 203 } 204 } 205 206 func (device *Device) RoutineDecryption(id int) { 207 var nonce [chacha20poly1305.NonceSize]byte 208 209 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) 210 device.log.Verbosef("Routine: decryption worker %d - started", id) 211 212 for elem := range device.queue.decryption.c { 213 // split message into fields 214 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 215 content := elem.packet[MessageTransportOffsetContent:] 216 217 // decrypt and release to consumer 218 var err error 219 elem.counter = binary.LittleEndian.Uint64(counter) 220 // copy counter to nonce 221 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 222 elem.packet, err = elem.keypair.receive.Open( 223 content[:0], 224 nonce[:], 225 content, 226 nil, 227 ) 228 if err != nil { 229 elem.packet = nil 230 } 231 elem.Unlock() 232 } 233 } 234 235 /* Handles incoming packets related to handshake 236 */ 237 func (device *Device) RoutineHandshake(id int) { 238 defer func() { 239 device.log.Verbosef("Routine: handshake worker %d - stopped", id) 240 device.queue.encryption.wg.Done() 241 }() 242 device.log.Verbosef("Routine: handshake worker %d - started", id) 243 244 for elem := range device.queue.handshake.c { 245 246 // handle cookie fields and ratelimiting 247 248 switch elem.msgType { 249 250 case MessageCookieReplyType: 251 252 // unmarshal packet 253 254 var reply MessageCookieReply 255 reader := bytes.NewReader(elem.packet) 256 err := binary.Read(reader, binary.LittleEndian, &reply) 257 if err != nil { 258 device.log.Verbosef("Failed to decode cookie reply") 259 goto skip 260 } 261 262 // lookup peer from index 263 264 entry := device.indexTable.Lookup(reply.Receiver) 265 266 if entry.peer == nil { 267 goto skip 268 } 269 270 // consume reply 271 272 if peer := entry.peer; peer.isRunning.Get() { 273 device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) 274 if !peer.cookieGenerator.ConsumeReply(&reply) { 275 device.log.Verbosef("Could not decrypt invalid cookie response") 276 } 277 } 278 279 goto skip 280 281 case MessageInitiationType, MessageResponseType: 282 283 // check mac fields and maybe ratelimit 284 285 if !device.cookieChecker.CheckMAC1(elem.packet) { 286 device.log.Verbosef("Received packet with invalid mac1") 287 goto skip 288 } 289 290 // endpoints destination address is the source of the datagram 291 292 if device.IsUnderLoad() { 293 294 // verify MAC2 field 295 296 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 297 device.SendHandshakeCookie(&elem) 298 goto skip 299 } 300 301 // check ratelimiter 302 303 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { 304 goto skip 305 } 306 } 307 308 default: 309 device.log.Errorf("Invalid packet ended up in the handshake queue") 310 goto skip 311 } 312 313 // handle handshake initiation/response content 314 315 switch elem.msgType { 316 case MessageInitiationType: 317 318 // unmarshal 319 320 var msg MessageInitiation 321 reader := bytes.NewReader(elem.packet) 322 err := binary.Read(reader, binary.LittleEndian, &msg) 323 if err != nil { 324 device.log.Errorf("Failed to decode initiation message") 325 goto skip 326 } 327 328 // consume initiation 329 330 peer := device.ConsumeMessageInitiation(&msg) 331 if peer == nil { 332 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) 333 goto skip 334 } 335 336 // update timers 337 338 peer.timersAnyAuthenticatedPacketTraversal() 339 peer.timersAnyAuthenticatedPacketReceived() 340 341 // update endpoint 342 peer.SetEndpointFromPacket(elem.endpoint) 343 344 device.log.Verbosef("%v - Received handshake initiation", peer) 345 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) 346 347 peer.SendHandshakeResponse() 348 349 case MessageResponseType: 350 351 // unmarshal 352 353 var msg MessageResponse 354 reader := bytes.NewReader(elem.packet) 355 err := binary.Read(reader, binary.LittleEndian, &msg) 356 if err != nil { 357 device.log.Errorf("Failed to decode response message") 358 goto skip 359 } 360 361 // consume response 362 363 peer := device.ConsumeMessageResponse(&msg) 364 if peer == nil { 365 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) 366 goto skip 367 } 368 369 // update endpoint 370 peer.SetEndpointFromPacket(elem.endpoint) 371 372 device.log.Verbosef("%v - Received handshake response", peer) 373 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) 374 375 // update timers 376 377 peer.timersAnyAuthenticatedPacketTraversal() 378 peer.timersAnyAuthenticatedPacketReceived() 379 380 // derive keypair 381 382 err = peer.BeginSymmetricSession() 383 384 if err != nil { 385 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 386 goto skip 387 } 388 389 peer.timersSessionDerived() 390 peer.timersHandshakeComplete() 391 peer.SendKeepalive() 392 } 393 skip: 394 device.PutMessageBuffer(elem.buffer) 395 } 396 } 397 398 func (peer *Peer) RoutineSequentialReceiver() { 399 device := peer.device 400 defer func() { 401 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) 402 peer.stopping.Done() 403 }() 404 device.log.Verbosef("%v - Routine: sequential receiver - started", peer) 405 406 for elem := range peer.queue.inbound.c { 407 if elem == nil { 408 return 409 } 410 var err error 411 elem.Lock() 412 if elem.packet == nil { 413 // decryption failed 414 goto skip 415 } 416 417 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 418 goto skip 419 } 420 421 peer.SetEndpointFromPacket(elem.endpoint) 422 if peer.ReceivedWithKeypair(elem.keypair) { 423 peer.timersHandshakeComplete() 424 peer.SendStagedPackets() 425 } 426 427 peer.keepKeyFreshReceiving() 428 peer.timersAnyAuthenticatedPacketTraversal() 429 peer.timersAnyAuthenticatedPacketReceived() 430 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) 431 432 if len(elem.packet) == 0 { 433 device.log.Verbosef("%v - Receiving keepalive packet", peer) 434 goto skip 435 } 436 peer.timersDataReceived() 437 438 switch elem.packet[0] >> 4 { 439 case ipv4.Version: 440 if len(elem.packet) < ipv4.HeaderLen { 441 goto skip 442 } 443 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] 444 length := binary.BigEndian.Uint16(field) 445 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { 446 goto skip 447 } 448 elem.packet = elem.packet[:length] 449 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] 450 if device.allowedips.Lookup(src) != peer { 451 device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) 452 goto skip 453 } 454 455 case ipv6.Version: 456 if len(elem.packet) < ipv6.HeaderLen { 457 goto skip 458 } 459 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] 460 length := binary.BigEndian.Uint16(field) 461 length += ipv6.HeaderLen 462 if int(length) > len(elem.packet) { 463 goto skip 464 } 465 elem.packet = elem.packet[:length] 466 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] 467 if device.allowedips.Lookup(src) != peer { 468 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) 469 goto skip 470 } 471 472 default: 473 device.log.Verbosef("Packet with invalid IP version from %v", peer) 474 goto skip 475 } 476 477 _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) 478 if err != nil && !device.isClosed() { 479 device.log.Errorf("Failed to write packet to TUN device: %v", err) 480 } 481 if len(peer.queue.inbound.c) == 0 { 482 err = device.tun.device.Flush() 483 if err != nil { 484 peer.device.log.Errorf("Unable to flush packets: %v", err) 485 } 486 } 487 skip: 488 device.PutMessageBuffer(elem.buffer) 489 device.PutInboundElement(elem) 490 } 491 }