github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/tun_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 /* Implementation of the TUN device interface for linux 9 */ 10 11 import ( 12 "errors" 13 "fmt" 14 "os" 15 "sync" 16 "syscall" 17 "time" 18 "unsafe" 19 20 "github.com/amnezia-vpn/amneziawg-go/conn" 21 "github.com/amnezia-vpn/amneziawg-go/rwcancel" 22 "golang.org/x/sys/unix" 23 ) 24 25 const ( 26 cloneDevicePath = "/dev/net/tun" 27 ifReqSize = unix.IFNAMSIZ + 64 28 ) 29 30 type NativeTun struct { 31 tunFile *os.File 32 index int32 // if index 33 errors chan error // async error handling 34 events chan Event // device related events 35 netlinkSock int 36 netlinkCancel *rwcancel.RWCancel 37 hackListenerClosed sync.Mutex 38 statusListenersShutdown chan struct{} 39 batchSize int 40 vnetHdr bool 41 udpGSO bool 42 43 closeOnce sync.Once 44 45 nameOnce sync.Once // guards calling initNameCache, which sets following fields 46 nameCache string // name of interface 47 nameErr error 48 49 readOpMu sync.Mutex // readOpMu guards readBuff 50 readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr 51 52 writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable 53 toWrite []int 54 tcpGROTable *tcpGROTable 55 udpGROTable *udpGROTable 56 } 57 58 func (tun *NativeTun) File() *os.File { 59 return tun.tunFile 60 } 61 62 func (tun *NativeTun) routineHackListener() { 63 defer tun.hackListenerClosed.Unlock() 64 /* This is needed for the detection to work across network namespaces 65 * If you are reading this and know a better method, please get in touch. 66 */ 67 last := 0 68 const ( 69 up = 1 70 down = 2 71 ) 72 for { 73 sysconn, err := tun.tunFile.SyscallConn() 74 if err != nil { 75 return 76 } 77 err2 := sysconn.Control(func(fd uintptr) { 78 _, err = unix.Write(int(fd), nil) 79 }) 80 if err2 != nil { 81 return 82 } 83 switch err { 84 case unix.EINVAL: 85 if last != up { 86 // If the tunnel is up, it reports that write() is 87 // allowed but we provided invalid data. 88 tun.events <- EventUp 89 last = up 90 } 91 case unix.EIO: 92 if last != down { 93 // If the tunnel is down, it reports that no I/O 94 // is possible, without checking our provided data. 95 tun.events <- EventDown 96 last = down 97 } 98 default: 99 return 100 } 101 select { 102 case <-time.After(time.Second): 103 // nothing 104 case <-tun.statusListenersShutdown: 105 return 106 } 107 } 108 } 109 110 func createNetlinkSocket() (int, error) { 111 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 112 if err != nil { 113 return -1, err 114 } 115 saddr := &unix.SockaddrNetlink{ 116 Family: unix.AF_NETLINK, 117 Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR, 118 } 119 err = unix.Bind(sock, saddr) 120 if err != nil { 121 return -1, err 122 } 123 return sock, nil 124 } 125 126 func (tun *NativeTun) routineNetlinkListener() { 127 defer func() { 128 unix.Close(tun.netlinkSock) 129 tun.hackListenerClosed.Lock() 130 close(tun.events) 131 tun.netlinkCancel.Close() 132 }() 133 134 for msg := make([]byte, 1<<16); ; { 135 var err error 136 var msgn int 137 for { 138 msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) 139 if err == nil || !rwcancel.RetryAfterError(err) { 140 break 141 } 142 if !tun.netlinkCancel.ReadyRead() { 143 tun.errors <- fmt.Errorf("netlink socket closed: %w", err) 144 return 145 } 146 } 147 if err != nil { 148 tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err) 149 return 150 } 151 152 select { 153 case <-tun.statusListenersShutdown: 154 return 155 default: 156 } 157 158 wasEverUp := false 159 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 160 161 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 162 163 if int(hdr.Len) > len(remain) { 164 break 165 } 166 167 switch hdr.Type { 168 case unix.NLMSG_DONE: 169 remain = []byte{} 170 171 case unix.RTM_NEWLINK: 172 info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr])) 173 remain = remain[hdr.Len:] 174 175 if info.Index != tun.index { 176 // not our interface 177 continue 178 } 179 180 if info.Flags&unix.IFF_RUNNING != 0 { 181 tun.events <- EventUp 182 wasEverUp = true 183 } 184 185 if info.Flags&unix.IFF_RUNNING == 0 { 186 // Don't emit EventDown before we've ever emitted EventUp. 187 // This avoids a startup race with HackListener, which 188 // might detect Up before we have finished reporting Down. 189 if wasEverUp { 190 tun.events <- EventDown 191 } 192 } 193 194 tun.events <- EventMTUUpdate 195 196 default: 197 remain = remain[hdr.Len:] 198 } 199 } 200 } 201 } 202 203 func getIFIndex(name string) (int32, error) { 204 fd, err := unix.Socket( 205 unix.AF_INET, 206 unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 207 0, 208 ) 209 if err != nil { 210 return 0, err 211 } 212 213 defer unix.Close(fd) 214 215 var ifr [ifReqSize]byte 216 copy(ifr[:], name) 217 _, _, errno := unix.Syscall( 218 unix.SYS_IOCTL, 219 uintptr(fd), 220 uintptr(unix.SIOCGIFINDEX), 221 uintptr(unsafe.Pointer(&ifr[0])), 222 ) 223 224 if errno != 0 { 225 return 0, errno 226 } 227 228 return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil 229 } 230 231 func (tun *NativeTun) setMTU(n int) error { 232 name, err := tun.Name() 233 if err != nil { 234 return err 235 } 236 237 // open datagram socket 238 fd, err := unix.Socket( 239 unix.AF_INET, 240 unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 241 0, 242 ) 243 if err != nil { 244 return err 245 } 246 247 defer unix.Close(fd) 248 249 // do ioctl call 250 var ifr [ifReqSize]byte 251 copy(ifr[:], name) 252 *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) 253 _, _, errno := unix.Syscall( 254 unix.SYS_IOCTL, 255 uintptr(fd), 256 uintptr(unix.SIOCSIFMTU), 257 uintptr(unsafe.Pointer(&ifr[0])), 258 ) 259 260 if errno != 0 { 261 return fmt.Errorf("failed to set MTU of TUN device: %w", errno) 262 } 263 264 return nil 265 } 266 267 func (tun *NativeTun) MTU() (int, error) { 268 name, err := tun.Name() 269 if err != nil { 270 return 0, err 271 } 272 273 // open datagram socket 274 fd, err := unix.Socket( 275 unix.AF_INET, 276 unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 277 0, 278 ) 279 if err != nil { 280 return 0, err 281 } 282 283 defer unix.Close(fd) 284 285 // do ioctl call 286 287 var ifr [ifReqSize]byte 288 copy(ifr[:], name) 289 _, _, errno := unix.Syscall( 290 unix.SYS_IOCTL, 291 uintptr(fd), 292 uintptr(unix.SIOCGIFMTU), 293 uintptr(unsafe.Pointer(&ifr[0])), 294 ) 295 if errno != 0 { 296 return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno) 297 } 298 299 return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil 300 } 301 302 func (tun *NativeTun) Name() (string, error) { 303 tun.nameOnce.Do(tun.initNameCache) 304 return tun.nameCache, tun.nameErr 305 } 306 307 func (tun *NativeTun) initNameCache() { 308 tun.nameCache, tun.nameErr = tun.nameSlow() 309 } 310 311 func (tun *NativeTun) nameSlow() (string, error) { 312 sysconn, err := tun.tunFile.SyscallConn() 313 if err != nil { 314 return "", err 315 } 316 var ifr [ifReqSize]byte 317 var errno syscall.Errno 318 err = sysconn.Control(func(fd uintptr) { 319 _, _, errno = unix.Syscall( 320 unix.SYS_IOCTL, 321 fd, 322 uintptr(unix.TUNGETIFF), 323 uintptr(unsafe.Pointer(&ifr[0])), 324 ) 325 }) 326 if err != nil { 327 return "", fmt.Errorf("failed to get name of TUN device: %w", err) 328 } 329 if errno != 0 { 330 return "", fmt.Errorf("failed to get name of TUN device: %w", errno) 331 } 332 return unix.ByteSliceToString(ifr[:]), nil 333 } 334 335 func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { 336 tun.writeOpMu.Lock() 337 defer func() { 338 tun.tcpGROTable.reset() 339 tun.udpGROTable.reset() 340 tun.writeOpMu.Unlock() 341 }() 342 var ( 343 errs error 344 total int 345 ) 346 tun.toWrite = tun.toWrite[:0] 347 if tun.vnetHdr { 348 err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) 349 if err != nil { 350 return 0, err 351 } 352 offset -= virtioNetHdrLen 353 } else { 354 for i := range bufs { 355 tun.toWrite = append(tun.toWrite, i) 356 } 357 } 358 for _, bufsI := range tun.toWrite { 359 n, err := tun.tunFile.Write(bufs[bufsI][offset:]) 360 if errors.Is(err, syscall.EBADFD) { 361 return total, os.ErrClosed 362 } 363 if err != nil { 364 errs = errors.Join(errs, err) 365 } else { 366 total += n 367 } 368 } 369 return total, errs 370 } 371 372 // handleVirtioRead splits in into bufs, leaving offset bytes at the front of 373 // each buffer. It mutates sizes to reflect the size of each element of bufs, 374 // and returns the number of packets read. 375 func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { 376 var hdr virtioNetHdr 377 err := hdr.decode(in) 378 if err != nil { 379 return 0, err 380 } 381 in = in[virtioNetHdrLen:] 382 if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { 383 if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { 384 // This means CHECKSUM_PARTIAL in skb context. We are responsible 385 // for computing the checksum starting at hdr.csumStart and placing 386 // at hdr.csumOffset. 387 err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) 388 if err != nil { 389 return 0, err 390 } 391 } 392 if len(in) > len(bufs[0][offset:]) { 393 return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) 394 } 395 n := copy(bufs[0][offset:], in) 396 sizes[0] = n 397 return 1, nil 398 } 399 if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { 400 return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) 401 } 402 403 ipVersion := in[0] >> 4 404 switch ipVersion { 405 case 4: 406 if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { 407 return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) 408 } 409 case 6: 410 if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { 411 return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) 412 } 413 default: 414 return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) 415 } 416 417 // Don't trust hdr.hdrLen from the kernel as it can be equal to the length 418 // of the entire first packet when the kernel is handling it as part of a 419 // FORWARD path. Instead, parse the transport header length and add it onto 420 // csumStart, which is synonymous for IP header length. 421 if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { 422 hdr.hdrLen = hdr.csumStart + 8 423 } else { 424 if len(in) <= int(hdr.csumStart+12) { 425 return 0, errors.New("packet is too short") 426 } 427 428 tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) 429 if tcpHLen < 20 || tcpHLen > 60 { 430 // A TCP header must be between 20 and 60 bytes in length. 431 return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) 432 } 433 hdr.hdrLen = hdr.csumStart + tcpHLen 434 } 435 436 if len(in) < int(hdr.hdrLen) { 437 return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) 438 } 439 440 if hdr.hdrLen < hdr.csumStart { 441 return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) 442 } 443 cSumAt := int(hdr.csumStart + hdr.csumOffset) 444 if cSumAt+1 >= len(in) { 445 return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) 446 } 447 448 return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) 449 } 450 451 func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 452 tun.readOpMu.Lock() 453 defer tun.readOpMu.Unlock() 454 select { 455 case err := <-tun.errors: 456 return 0, err 457 default: 458 readInto := bufs[0][offset:] 459 if tun.vnetHdr { 460 readInto = tun.readBuff[:] 461 } 462 n, err := tun.tunFile.Read(readInto) 463 if errors.Is(err, syscall.EBADFD) { 464 err = os.ErrClosed 465 } 466 if err != nil { 467 return 0, err 468 } 469 if tun.vnetHdr { 470 return handleVirtioRead(readInto[:n], bufs, sizes, offset) 471 } else { 472 sizes[0] = n 473 return 1, nil 474 } 475 } 476 } 477 478 func (tun *NativeTun) Events() <-chan Event { 479 return tun.events 480 } 481 482 func (tun *NativeTun) Close() error { 483 var err1, err2 error 484 tun.closeOnce.Do(func() { 485 if tun.statusListenersShutdown != nil { 486 close(tun.statusListenersShutdown) 487 if tun.netlinkCancel != nil { 488 err1 = tun.netlinkCancel.Cancel() 489 } 490 } else if tun.events != nil { 491 close(tun.events) 492 } 493 err2 = tun.tunFile.Close() 494 }) 495 if err1 != nil { 496 return err1 497 } 498 return err2 499 } 500 501 func (tun *NativeTun) BatchSize() int { 502 return tun.batchSize 503 } 504 505 const ( 506 // TODO: support TSO with ECN bits 507 tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 508 tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 509 ) 510 511 func (tun *NativeTun) initFromFlags(name string) error { 512 sc, err := tun.tunFile.SyscallConn() 513 if err != nil { 514 return err 515 } 516 if e := sc.Control(func(fd uintptr) { 517 var ( 518 ifr *unix.Ifreq 519 ) 520 ifr, err = unix.NewIfreq(name) 521 if err != nil { 522 return 523 } 524 err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) 525 if err != nil { 526 return 527 } 528 got := ifr.Uint16() 529 if got&unix.IFF_VNET_HDR != 0 { 530 // tunTCPOffloads were added in Linux v2.6. We require their support 531 // if IFF_VNET_HDR is set. 532 err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) 533 if err != nil { 534 return 535 } 536 tun.vnetHdr = true 537 tun.batchSize = conn.IdealBatchSize 538 // tunUDPOffloads were added in Linux v6.2. We do not return an 539 // error if they are unsupported at runtime. 540 tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil 541 } else { 542 tun.batchSize = 1 543 } 544 }); e != nil { 545 return e 546 } 547 return err 548 } 549 550 // CreateTUN creates a Device with the provided name and MTU. 551 func CreateTUN(name string, mtu int) (Device, error) { 552 nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) 553 if err != nil { 554 if os.IsNotExist(err) { 555 return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) 556 } 557 return nil, err 558 } 559 560 ifr, err := unix.NewIfreq(name) 561 if err != nil { 562 return nil, err 563 } 564 // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() 565 // where a null write will return EINVAL indicating the TUN is up. 566 ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) 567 err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) 568 if err != nil { 569 return nil, err 570 } 571 572 err = unix.SetNonblock(nfd, true) 573 if err != nil { 574 unix.Close(nfd) 575 return nil, err 576 } 577 578 // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. 579 580 fd := os.NewFile(uintptr(nfd), cloneDevicePath) 581 return CreateTUNFromFile(fd, mtu) 582 } 583 584 // CreateTUNFromFile creates a Device from an os.File with the provided MTU. 585 func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 586 tun := &NativeTun{ 587 tunFile: file, 588 events: make(chan Event, 5), 589 errors: make(chan error, 5), 590 statusListenersShutdown: make(chan struct{}), 591 tcpGROTable: newTCPGROTable(), 592 udpGROTable: newUDPGROTable(), 593 toWrite: make([]int, 0, conn.IdealBatchSize), 594 } 595 596 name, err := tun.Name() 597 if err != nil { 598 return nil, err 599 } 600 601 err = tun.initFromFlags(name) 602 if err != nil { 603 return nil, err 604 } 605 606 // start event listener 607 tun.index, err = getIFIndex(name) 608 if err != nil { 609 return nil, err 610 } 611 612 tun.netlinkSock, err = createNetlinkSocket() 613 if err != nil { 614 return nil, err 615 } 616 tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock) 617 if err != nil { 618 unix.Close(tun.netlinkSock) 619 return nil, err 620 } 621 622 tun.hackListenerClosed.Lock() 623 go tun.routineNetlinkListener() 624 go tun.routineHackListener() // cross namespace 625 626 err = tun.setMTU(mtu) 627 if err != nil { 628 unix.Close(tun.netlinkSock) 629 return nil, err 630 } 631 632 return tun, nil 633 } 634 635 // CreateUnmonitoredTUNFromFD creates a Device from the provided file 636 // descriptor. 637 func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { 638 err := unix.SetNonblock(fd, true) 639 if err != nil { 640 return nil, "", err 641 } 642 file := os.NewFile(uintptr(fd), "/dev/tun") 643 tun := &NativeTun{ 644 tunFile: file, 645 events: make(chan Event, 5), 646 errors: make(chan error, 5), 647 tcpGROTable: newTCPGROTable(), 648 udpGROTable: newUDPGROTable(), 649 toWrite: make([]int, 0, conn.IdealBatchSize), 650 } 651 name, err := tun.Name() 652 if err != nil { 653 return nil, "", err 654 } 655 err = tun.initFromFlags(name) 656 if err != nil { 657 return nil, "", err 658 } 659 return tun, name, err 660 }