github.com/forest33/wtun@v0.3.1/tun/tun_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 /* Implementation of the TUN device interface for linux 9 */ 10 11 import ( 12 "errors" 13 "fmt" 14 "os" 15 "sync" 16 "syscall" 17 "time" 18 "unsafe" 19 20 "golang.org/x/sys/unix" 21 22 "github.com/forest33/wtun/rwcancel" 23 24 "github.com/forest33/wtun/conn" 25 ) 26 27 const ( 28 cloneDevicePath = "/dev/net/tun" 29 ifReqSize = unix.IFNAMSIZ + 64 30 ) 31 32 type NativeTun struct { 33 tunFile *os.File 34 index int32 // if index 35 errors chan error // async error handling 36 events chan Event // device related events 37 netlinkSock int 38 netlinkCancel *rwcancel.RWCancel 39 hackListenerClosed sync.Mutex 40 statusListenersShutdown chan struct{} 41 batchSize int 42 vnetHdr bool 43 44 closeOnce sync.Once 45 46 nameOnce sync.Once // guards calling initNameCache, which sets following fields 47 nameCache string // name of interface 48 nameErr error 49 50 readOpMu sync.Mutex // readOpMu guards readBuff 51 readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr 52 53 writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable 54 toWrite []int 55 tcp4GROTable, tcp6GROTable *tcpGROTable 56 } 57 58 func (tun *NativeTun) File() *os.File { 59 return tun.tunFile 60 } 61 62 func (tun *NativeTun) routineHackListener() { 63 defer tun.hackListenerClosed.Unlock() 64 /* This is needed for the detection to work across network namespaces 65 * If you are reading this and know a better method, please get in touch. 66 */ 67 last := 0 68 const ( 69 up = 1 70 down = 2 71 ) 72 for { 73 sysconn, err := tun.tunFile.SyscallConn() 74 if err != nil { 75 return 76 } 77 err2 := sysconn.Control(func(fd uintptr) { 78 _, err = unix.Write(int(fd), nil) 79 }) 80 if err2 != nil { 81 return 82 } 83 switch err { 84 case unix.EINVAL: 85 if last != up { 86 // If the tunnel is up, it reports that write() is 87 // allowed but we provided invalid data. 88 tun.events <- EventUp 89 last = up 90 } 91 case unix.EIO: 92 if last != down { 93 // If the tunnel is down, it reports that no I/O 94 // is possible, without checking our provided data. 95 tun.events <- EventDown 96 last = down 97 } 98 default: 99 return 100 } 101 select { 102 case <-time.After(time.Second): 103 // nothing 104 case <-tun.statusListenersShutdown: 105 return 106 } 107 } 108 } 109 110 func createNetlinkSocket() (int, error) { 111 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 112 if err != nil { 113 return -1, err 114 } 115 saddr := &unix.SockaddrNetlink{ 116 Family: unix.AF_NETLINK, 117 Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR, 118 } 119 err = unix.Bind(sock, saddr) 120 if err != nil { 121 return -1, err 122 } 123 return sock, nil 124 } 125 126 func (tun *NativeTun) routineNetlinkListener() { 127 defer func() { 128 unix.Close(tun.netlinkSock) 129 tun.hackListenerClosed.Lock() 130 close(tun.events) 131 tun.netlinkCancel.Close() 132 }() 133 134 for msg := make([]byte, 1<<16); ; { 135 var err error 136 var msgn int 137 for { 138 msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) 139 if err == nil || !rwcancel.RetryAfterError(err) { 140 break 141 } 142 if !tun.netlinkCancel.ReadyRead() { 143 tun.errors <- fmt.Errorf("netlink socket closed: %w", err) 144 return 145 } 146 } 147 if err != nil { 148 tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err) 149 return 150 } 151 152 select { 153 case <-tun.statusListenersShutdown: 154 return 155 default: 156 } 157 158 wasEverUp := false 159 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 160 161 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 162 163 if int(hdr.Len) > len(remain) { 164 break 165 } 166 167 switch hdr.Type { 168 case unix.NLMSG_DONE: 169 remain = []byte{} 170 171 case unix.RTM_NEWLINK: 172 info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr])) 173 remain = remain[hdr.Len:] 174 175 if info.Index != tun.index { 176 // not our interface 177 continue 178 } 179 180 if info.Flags&unix.IFF_RUNNING != 0 { 181 tun.events <- EventUp 182 wasEverUp = true 183 } 184 185 if info.Flags&unix.IFF_RUNNING == 0 { 186 // Don't emit EventDown before we've ever emitted EventUp. 187 // This avoids a startup race with HackListener, which 188 // might detect Up before we have finished reporting Down. 189 if wasEverUp { 190 tun.events <- EventDown 191 } 192 } 193 194 tun.events <- EventMTUUpdate 195 196 default: 197 remain = remain[hdr.Len:] 198 } 199 } 200 } 201 } 202 203 func getIFIndex(name string) (int32, error) { 204 fd, err := unix.Socket( 205 unix.AF_INET, 206 unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 207 0, 208 ) 209 if err != nil { 210 return 0, err 211 } 212 213 defer unix.Close(fd) 214 215 var ifr [ifReqSize]byte 216 copy(ifr[:], name) 217 _, _, errno := unix.Syscall( 218 unix.SYS_IOCTL, 219 uintptr(fd), 220 uintptr(unix.SIOCGIFINDEX), 221 uintptr(unsafe.Pointer(&ifr[0])), 222 ) 223 224 if errno != 0 { 225 return 0, errno 226 } 227 228 return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil 229 } 230 231 func (tun *NativeTun) setMTU(n int) error { 232 name, err := tun.Name() 233 if err != nil { 234 return err 235 } 236 237 // open datagram socket 238 fd, err := unix.Socket( 239 unix.AF_INET, 240 unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 241 0, 242 ) 243 if err != nil { 244 return err 245 } 246 247 defer unix.Close(fd) 248 249 // do ioctl call 250 var ifr [ifReqSize]byte 251 copy(ifr[:], name) 252 *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) 253 _, _, errno := unix.Syscall( 254 unix.SYS_IOCTL, 255 uintptr(fd), 256 uintptr(unix.SIOCSIFMTU), 257 uintptr(unsafe.Pointer(&ifr[0])), 258 ) 259 260 if errno != 0 { 261 return fmt.Errorf("failed to set MTU of TUN device: %w", errno) 262 } 263 264 return nil 265 } 266 267 func (tun *NativeTun) MTU() (int, error) { 268 name, err := tun.Name() 269 if err != nil { 270 return 0, err 271 } 272 273 // open datagram socket 274 fd, err := unix.Socket( 275 unix.AF_INET, 276 unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 277 0, 278 ) 279 if err != nil { 280 return 0, err 281 } 282 283 defer unix.Close(fd) 284 285 // do ioctl call 286 287 var ifr [ifReqSize]byte 288 copy(ifr[:], name) 289 _, _, errno := unix.Syscall( 290 unix.SYS_IOCTL, 291 uintptr(fd), 292 uintptr(unix.SIOCGIFMTU), 293 uintptr(unsafe.Pointer(&ifr[0])), 294 ) 295 if errno != 0 { 296 return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno) 297 } 298 299 return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil 300 } 301 302 func (tun *NativeTun) Name() (string, error) { 303 tun.nameOnce.Do(tun.initNameCache) 304 return tun.nameCache, tun.nameErr 305 } 306 307 func (tun *NativeTun) initNameCache() { 308 tun.nameCache, tun.nameErr = tun.nameSlow() 309 } 310 311 func (tun *NativeTun) nameSlow() (string, error) { 312 sysconn, err := tun.tunFile.SyscallConn() 313 if err != nil { 314 return "", err 315 } 316 var ifr [ifReqSize]byte 317 var errno syscall.Errno 318 err = sysconn.Control(func(fd uintptr) { 319 _, _, errno = unix.Syscall( 320 unix.SYS_IOCTL, 321 fd, 322 uintptr(unix.TUNGETIFF), 323 uintptr(unsafe.Pointer(&ifr[0])), 324 ) 325 }) 326 if err != nil { 327 return "", fmt.Errorf("failed to get name of TUN device: %w", err) 328 } 329 if errno != 0 { 330 return "", fmt.Errorf("failed to get name of TUN device: %w", errno) 331 } 332 return unix.ByteSliceToString(ifr[:]), nil 333 } 334 335 func (tun *NativeTun) WritePackets(bufs [][]byte, offset int) (int, error) { 336 tun.writeOpMu.Lock() 337 defer func() { 338 tun.tcp4GROTable.reset() 339 tun.tcp6GROTable.reset() 340 tun.writeOpMu.Unlock() 341 }() 342 var ( 343 errs error 344 total int 345 ) 346 tun.toWrite = tun.toWrite[:0] 347 if tun.vnetHdr { 348 err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) 349 if err != nil { 350 return 0, err 351 } 352 offset -= virtioNetHdrLen 353 } else { 354 for i := range bufs { 355 tun.toWrite = append(tun.toWrite, i) 356 } 357 } 358 for _, bufsI := range tun.toWrite { 359 n, err := tun.tunFile.Write(bufs[bufsI][offset:]) 360 if errors.Is(err, syscall.EBADFD) { 361 return total, os.ErrClosed 362 } 363 if err != nil { 364 errs = errors.Join(errs, err) 365 } else { 366 total += n 367 } 368 } 369 return total, errs 370 } 371 372 // handleVirtioRead splits in into bufs, leaving offset bytes at the front of 373 // each buffer. It mutates sizes to reflect the size of each element of bufs, 374 // and returns the number of packets read. 375 func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { 376 var hdr virtioNetHdr 377 err := hdr.decode(in) 378 if err != nil { 379 return 0, err 380 } 381 in = in[virtioNetHdrLen:] 382 if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { 383 if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { 384 // This means CHECKSUM_PARTIAL in skb context. We are responsible 385 // for computing the checksum starting at hdr.csumStart and placing 386 // at hdr.csumOffset. 387 err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) 388 if err != nil { 389 return 0, err 390 } 391 } 392 if len(in) > len(bufs[0][offset:]) { 393 return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) 394 } 395 n := copy(bufs[0][offset:], in) 396 sizes[0] = n 397 return 1, nil 398 } 399 if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { 400 return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) 401 } 402 403 ipVersion := in[0] >> 4 404 switch ipVersion { 405 case 4: 406 if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { 407 return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) 408 } 409 case 6: 410 if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { 411 return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) 412 } 413 default: 414 return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) 415 } 416 417 if len(in) <= int(hdr.csumStart+12) { 418 return 0, errors.New("packet is too short") 419 } 420 // Don't trust hdr.hdrLen from the kernel as it can be equal to the length 421 // of the entire first packet when the kernel is handling it as part of a 422 // FORWARD path. Instead, parse the TCP header length and add it onto 423 // csumStart, which is synonymous for IP header length. 424 tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) 425 if tcpHLen < 20 || tcpHLen > 60 { 426 // A TCP header must be between 20 and 60 bytes in length. 427 return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) 428 } 429 hdr.hdrLen = hdr.csumStart + tcpHLen 430 431 if len(in) < int(hdr.hdrLen) { 432 return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) 433 } 434 435 if hdr.hdrLen < hdr.csumStart { 436 return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) 437 } 438 cSumAt := int(hdr.csumStart + hdr.csumOffset) 439 if cSumAt+1 >= len(in) { 440 return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) 441 } 442 443 return tcpTSO(in, hdr, bufs, sizes, offset) 444 } 445 446 func (tun *NativeTun) ReadPackets(bufs [][]byte, sizes []int, offset int) (int, error) { 447 tun.readOpMu.Lock() 448 defer tun.readOpMu.Unlock() 449 select { 450 case err := <-tun.errors: 451 return 0, err 452 default: 453 readInto := bufs[0][offset:] 454 if tun.vnetHdr { 455 readInto = tun.readBuff[:] 456 } 457 n, err := tun.tunFile.Read(readInto) 458 if errors.Is(err, syscall.EBADFD) { 459 err = os.ErrClosed 460 } 461 if err != nil { 462 return 0, err 463 } 464 if tun.vnetHdr { 465 return handleVirtioRead(readInto[:n], bufs, sizes, offset) 466 } else { 467 sizes[0] = n 468 return 1, nil 469 } 470 } 471 } 472 473 func (tun *NativeTun) Events() <-chan Event { 474 return tun.events 475 } 476 477 func (tun *NativeTun) Close() error { 478 var err1, err2 error 479 tun.closeOnce.Do(func() { 480 if tun.statusListenersShutdown != nil { 481 close(tun.statusListenersShutdown) 482 if tun.netlinkCancel != nil { 483 err1 = tun.netlinkCancel.Cancel() 484 } 485 } else if tun.events != nil { 486 close(tun.events) 487 } 488 err2 = tun.tunFile.Close() 489 }) 490 if err1 != nil { 491 return err1 492 } 493 return err2 494 } 495 496 func (tun *NativeTun) BatchSize() int { 497 return tun.batchSize 498 } 499 500 const ( 501 // TODO: support TSO with ECN bits 502 tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 503 ) 504 505 func (tun *NativeTun) initFromFlags(name string) error { 506 sc, err := tun.tunFile.SyscallConn() 507 if err != nil { 508 return err 509 } 510 if e := sc.Control(func(fd uintptr) { 511 var ( 512 ifr *unix.Ifreq 513 ) 514 ifr, err = unix.NewIfreq(name) 515 if err != nil { 516 return 517 } 518 err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) 519 if err != nil { 520 return 521 } 522 got := ifr.Uint16() 523 if got&unix.IFF_VNET_HDR != 0 { 524 err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) 525 if err != nil { 526 return 527 } 528 tun.vnetHdr = true 529 tun.batchSize = conn.IdealBatchSize 530 } else { 531 tun.batchSize = 1 532 } 533 }); e != nil { 534 return e 535 } 536 return err 537 } 538 539 // CreateTUN creates a Device with the provided name and MTU. 540 func CreateTUN(name string, mtu int, flags uint16) (Device, error) { 541 nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) 542 if err != nil { 543 if os.IsNotExist(err) { 544 return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) 545 } 546 return nil, err 547 } 548 549 ifr, err := unix.NewIfreq(name) 550 if err != nil { 551 return nil, err 552 } 553 // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() 554 // where a null write will return EINVAL indicating the TUN is up. 555 ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | flags) 556 err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) 557 if err != nil { 558 return nil, err 559 } 560 561 err = unix.SetNonblock(nfd, true) 562 if err != nil { 563 unix.Close(nfd) 564 return nil, err 565 } 566 567 // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. 568 569 fd := os.NewFile(uintptr(nfd), cloneDevicePath) 570 return CreateTUNFromFile(fd, mtu) 571 } 572 573 // CreateTUNFromFile creates a Device from an os.File with the provided MTU. 574 func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 575 tun := &NativeTun{ 576 tunFile: file, 577 events: make(chan Event, 5), 578 errors: make(chan error, 5), 579 statusListenersShutdown: make(chan struct{}), 580 tcp4GROTable: newTCPGROTable(), 581 tcp6GROTable: newTCPGROTable(), 582 toWrite: make([]int, 0, conn.IdealBatchSize), 583 } 584 585 name, err := tun.Name() 586 if err != nil { 587 return nil, err 588 } 589 590 err = tun.initFromFlags(name) 591 if err != nil { 592 return nil, err 593 } 594 595 // start event listener 596 tun.index, err = getIFIndex(name) 597 if err != nil { 598 return nil, err 599 } 600 601 tun.netlinkSock, err = createNetlinkSocket() 602 if err != nil { 603 return nil, err 604 } 605 tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock) 606 if err != nil { 607 unix.Close(tun.netlinkSock) 608 return nil, err 609 } 610 611 tun.hackListenerClosed.Lock() 612 go tun.routineNetlinkListener() 613 go tun.routineHackListener() // cross namespace 614 615 err = tun.setMTU(mtu) 616 if err != nil { 617 unix.Close(tun.netlinkSock) 618 return nil, err 619 } 620 621 return tun, nil 622 } 623 624 // CreateUnmonitoredTUNFromFD creates a Device from the provided file 625 // descriptor. 626 func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { 627 err := unix.SetNonblock(fd, true) 628 if err != nil { 629 return nil, "", err 630 } 631 file := os.NewFile(uintptr(fd), "/dev/tun") 632 tun := &NativeTun{ 633 tunFile: file, 634 events: make(chan Event, 5), 635 errors: make(chan error, 5), 636 tcp4GROTable: newTCPGROTable(), 637 tcp6GROTable: newTCPGROTable(), 638 toWrite: make([]int, 0, conn.IdealBatchSize), 639 } 640 name, err := tun.Name() 641 if err != nil { 642 return nil, "", err 643 } 644 err = tun.initFromFlags(name) 645 if err != nil { 646 return nil, "", err 647 } 648 return tun, name, err 649 } 650 651 func (tun *NativeTun) Read(p []byte) (n int, err error) { 652 var ( 653 bufs = make([][]byte, 1) 654 sizes = make([]int, 1) 655 ) 656 657 bufs[0] = make([]byte, len(p)) 658 n, err = tun.ReadPackets(bufs, sizes, 0) 659 if err != nil { 660 return 0, err 661 } 662 if sizes[0] < 1 { 663 return 0, nil 664 } 665 666 copy(p, bufs[0][:sizes[0]]) 667 668 return sizes[0], nil 669 } 670 671 func (tun *NativeTun) Write(p []byte) (n int, err error) { 672 return tun.WritePackets([][]byte{p}, 0) 673 }