github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/receive.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/binary" 11 "errors" 12 "net" 13 "sync" 14 "sync/atomic" 15 "time" 16 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 "github.com/liloew/wireguard-go/conn" 21 ) 22 23 type QueueHandshakeElement struct { 24 msgType uint32 25 packet []byte 26 endpoint conn.Endpoint 27 buffer *[MaxMessageSize]byte 28 } 29 30 type QueueInboundElement struct { 31 sync.Mutex 32 buffer *[MaxMessageSize]byte 33 packet []byte 34 counter uint64 35 keypair *Keypair 36 endpoint conn.Endpoint 37 } 38 39 // clearPointers clears elem fields that contain pointers. 40 // This makes the garbage collector's life easier and 41 // avoids accidentally keeping other objects around unnecessarily. 42 // It also reduces the possible collateral damage from use-after-free bugs. 43 func (elem *QueueInboundElement) clearPointers() { 44 elem.buffer = nil 45 elem.packet = nil 46 elem.keypair = nil 47 elem.endpoint = nil 48 } 49 50 /* Called when a new authenticated message has been received 51 * 52 * NOTE: Not thread safe, but called by sequential receiver! 53 */ 54 func (peer *Peer) keepKeyFreshReceiving() { 55 if peer.timers.sentLastMinuteHandshake.Get() { 56 return 57 } 58 keypair := peer.keypairs.Current() 59 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 60 peer.timers.sentLastMinuteHandshake.Set(true) 61 peer.SendHandshakeInitiation(false) 62 } 63 } 64 65 /* Receives incoming datagrams for the device 66 * 67 * Every time the bind is updated a new routine is started for 68 * IPv4 and IPv6 (separately) 69 */ 70 func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { 71 recvName := recv.PrettyName() 72 defer func() { 73 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) 74 device.queue.decryption.wg.Done() 75 device.queue.handshake.wg.Done() 76 device.net.stopping.Done() 77 }() 78 79 device.log.Verbosef("Routine: receive incoming %s - started", recvName) 80 81 // receive datagrams until conn is closed 82 83 buffer := device.GetMessageBuffer() 84 85 var ( 86 err error 87 size int 88 endpoint conn.Endpoint 89 deathSpiral int 90 ) 91 92 for { 93 size, endpoint, err = recv(buffer[:]) 94 95 if err != nil { 96 device.PutMessageBuffer(buffer) 97 if errors.Is(err, net.ErrClosed) { 98 return 99 } 100 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) 101 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { 102 return 103 } 104 if deathSpiral < 10 { 105 deathSpiral++ 106 time.Sleep(time.Second / 3) 107 buffer = device.GetMessageBuffer() 108 continue 109 } 110 return 111 } 112 deathSpiral = 0 113 114 if size < MinMessageSize { 115 continue 116 } 117 118 // check size of packet 119 120 packet := buffer[:size] 121 msgType := binary.LittleEndian.Uint32(packet[:4]) 122 123 var okay bool 124 125 switch msgType { 126 127 // check if transport 128 129 case MessageTransportType: 130 131 // check size 132 133 if len(packet) < MessageTransportSize { 134 continue 135 } 136 137 // lookup key pair 138 139 receiver := binary.LittleEndian.Uint32( 140 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 141 ) 142 value := device.indexTable.Lookup(receiver) 143 keypair := value.keypair 144 if keypair == nil { 145 continue 146 } 147 148 // check keypair expiry 149 150 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 151 continue 152 } 153 154 // create work element 155 peer := value.peer 156 elem := device.GetInboundElement() 157 elem.packet = packet 158 elem.buffer = buffer 159 elem.keypair = keypair 160 elem.endpoint = endpoint 161 elem.counter = 0 162 elem.Mutex = sync.Mutex{} 163 elem.Lock() 164 165 // add to decryption queues 166 if peer.isRunning.Get() { 167 peer.queue.inbound.c <- elem 168 device.queue.decryption.c <- elem 169 buffer = device.GetMessageBuffer() 170 } else { 171 device.PutInboundElement(elem) 172 } 173 continue 174 175 // otherwise it is a fixed size & handshake related packet 176 177 case MessageInitiationType: 178 okay = len(packet) == MessageInitiationSize 179 180 case MessageResponseType: 181 okay = len(packet) == MessageResponseSize 182 183 case MessageCookieReplyType: 184 okay = len(packet) == MessageCookieReplySize 185 186 default: 187 device.log.Verbosef("Received message with unknown type") 188 } 189 190 if okay { 191 select { 192 case device.queue.handshake.c <- QueueHandshakeElement{ 193 msgType: msgType, 194 buffer: buffer, 195 packet: packet, 196 endpoint: endpoint, 197 }: 198 buffer = device.GetMessageBuffer() 199 default: 200 } 201 } 202 } 203 } 204 205 func (device *Device) RoutineDecryption(id int) { 206 var nonce [chacha20poly1305.NonceSize]byte 207 208 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) 209 device.log.Verbosef("Routine: decryption worker %d - started", id) 210 211 for elem := range device.queue.decryption.c { 212 // split message into fields 213 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 214 content := elem.packet[MessageTransportOffsetContent:] 215 216 // decrypt and release to consumer 217 var err error 218 elem.counter = binary.LittleEndian.Uint64(counter) 219 // copy counter to nonce 220 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 221 elem.packet, err = elem.keypair.receive.Open( 222 content[:0], 223 nonce[:], 224 content, 225 nil, 226 ) 227 if err != nil { 228 elem.packet = nil 229 } 230 elem.Unlock() 231 } 232 } 233 234 /* Handles incoming packets related to handshake 235 */ 236 func (device *Device) RoutineHandshake(id int) { 237 defer func() { 238 device.log.Verbosef("Routine: handshake worker %d - stopped", id) 239 device.queue.encryption.wg.Done() 240 }() 241 device.log.Verbosef("Routine: handshake worker %d - started", id) 242 243 for elem := range device.queue.handshake.c { 244 245 // handle cookie fields and ratelimiting 246 247 switch elem.msgType { 248 249 case MessageCookieReplyType: 250 251 // unmarshal packet 252 253 var reply MessageCookieReply 254 reader := bytes.NewReader(elem.packet) 255 err := binary.Read(reader, binary.LittleEndian, &reply) 256 if err != nil { 257 device.log.Verbosef("Failed to decode cookie reply") 258 goto skip 259 } 260 261 // lookup peer from index 262 263 entry := device.indexTable.Lookup(reply.Receiver) 264 265 if entry.peer == nil { 266 goto skip 267 } 268 269 // consume reply 270 271 if peer := entry.peer; peer.isRunning.Get() { 272 device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) 273 if !peer.cookieGenerator.ConsumeReply(&reply) { 274 device.log.Verbosef("Could not decrypt invalid cookie response") 275 } 276 } 277 278 goto skip 279 280 case MessageInitiationType, MessageResponseType: 281 282 // check mac fields and maybe ratelimit 283 284 if !device.cookieChecker.CheckMAC1(elem.packet) { 285 device.log.Verbosef("Received packet with invalid mac1") 286 goto skip 287 } 288 289 // endpoints destination address is the source of the datagram 290 291 if device.IsUnderLoad() { 292 293 // verify MAC2 field 294 295 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 296 device.SendHandshakeCookie(&elem) 297 goto skip 298 } 299 300 // check ratelimiter 301 302 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { 303 goto skip 304 } 305 } 306 307 default: 308 device.log.Errorf("Invalid packet ended up in the handshake queue") 309 goto skip 310 } 311 312 // handle handshake initiation/response content 313 314 switch elem.msgType { 315 case MessageInitiationType: 316 317 // unmarshal 318 319 var msg MessageInitiation 320 reader := bytes.NewReader(elem.packet) 321 err := binary.Read(reader, binary.LittleEndian, &msg) 322 if err != nil { 323 device.log.Errorf("Failed to decode initiation message") 324 goto skip 325 } 326 327 // consume initiation 328 329 peer := device.ConsumeMessageInitiation(&msg) 330 if peer == nil { 331 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) 332 goto skip 333 } 334 335 // update timers 336 337 peer.timersAnyAuthenticatedPacketTraversal() 338 peer.timersAnyAuthenticatedPacketReceived() 339 340 // update endpoint 341 peer.SetEndpointFromPacket(elem.endpoint) 342 343 device.log.Verbosef("%v - Received handshake initiation", peer) 344 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) 345 346 peer.SendHandshakeResponse() 347 348 case MessageResponseType: 349 350 // unmarshal 351 352 var msg MessageResponse 353 reader := bytes.NewReader(elem.packet) 354 err := binary.Read(reader, binary.LittleEndian, &msg) 355 if err != nil { 356 device.log.Errorf("Failed to decode response message") 357 goto skip 358 } 359 360 // consume response 361 362 peer := device.ConsumeMessageResponse(&msg) 363 if peer == nil { 364 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) 365 goto skip 366 } 367 368 // update endpoint 369 peer.SetEndpointFromPacket(elem.endpoint) 370 371 device.log.Verbosef("%v - Received handshake response", peer) 372 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) 373 374 // update timers 375 376 peer.timersAnyAuthenticatedPacketTraversal() 377 peer.timersAnyAuthenticatedPacketReceived() 378 379 // derive keypair 380 381 err = peer.BeginSymmetricSession() 382 383 if err != nil { 384 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) 385 goto skip 386 } 387 388 peer.timersSessionDerived() 389 peer.timersHandshakeComplete() 390 peer.SendKeepalive() 391 } 392 skip: 393 device.PutMessageBuffer(elem.buffer) 394 } 395 } 396 397 func (peer *Peer) RoutineSequentialReceiver() { 398 device := peer.device 399 defer func() { 400 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) 401 peer.stopping.Done() 402 }() 403 device.log.Verbosef("%v - Routine: sequential receiver - started", peer) 404 405 for elem := range peer.queue.inbound.c { 406 if elem == nil { 407 return 408 } 409 var err error 410 elem.Lock() 411 if elem.packet == nil { 412 // decryption failed 413 goto skip 414 } 415 416 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 417 goto skip 418 } 419 420 peer.SetEndpointFromPacket(elem.endpoint) 421 if peer.ReceivedWithKeypair(elem.keypair) { 422 peer.timersHandshakeComplete() 423 peer.SendStagedPackets() 424 } 425 426 peer.keepKeyFreshReceiving() 427 peer.timersAnyAuthenticatedPacketTraversal() 428 peer.timersAnyAuthenticatedPacketReceived() 429 atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) 430 431 if len(elem.packet) == 0 { 432 device.log.Verbosef("%v - Receiving keepalive packet", peer) 433 goto skip 434 } 435 peer.timersDataReceived() 436 437 switch elem.packet[0] >> 4 { 438 case ipv4.Version: 439 if len(elem.packet) < ipv4.HeaderLen { 440 goto skip 441 } 442 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] 443 length := binary.BigEndian.Uint16(field) 444 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { 445 goto skip 446 } 447 elem.packet = elem.packet[:length] 448 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] 449 if device.allowedips.Lookup(src) != peer { 450 device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) 451 goto skip 452 } 453 454 case ipv6.Version: 455 if len(elem.packet) < ipv6.HeaderLen { 456 goto skip 457 } 458 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] 459 length := binary.BigEndian.Uint16(field) 460 length += ipv6.HeaderLen 461 if int(length) > len(elem.packet) { 462 goto skip 463 } 464 elem.packet = elem.packet[:length] 465 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] 466 if device.allowedips.Lookup(src) != peer { 467 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) 468 goto skip 469 } 470 471 default: 472 device.log.Verbosef("Packet with invalid IP version from %v", peer) 473 goto skip 474 } 475 476 _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) 477 if err != nil && !device.isClosed() { 478 device.log.Errorf("Failed to write packet to TUN device: %v", err) 479 } 480 if len(peer.queue.inbound.c) == 0 { 481 err = device.tun.device.Flush() 482 if err != nil { 483 peer.device.log.Errorf("Unable to flush packets: %v", err) 484 } 485 } 486 skip: 487 device.PutMessageBuffer(elem.buffer) 488 device.PutInboundElement(elem) 489 } 490 }