github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/gvisor/open_linux.go (about)

     1  package tun
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
     8  	"github.com/Asutorufa/yuhaiin/pkg/net/netlink"
     9  	"github.com/tailscale/wireguard-go/conn"
    10  	wun "github.com/tailscale/wireguard-go/tun"
    11  )
    12  
    13  const (
    14  	offset = 0
    15  )
    16  
    17  func OpenWriter(sc netlink.TunScheme, mtu int) (netlink.Tun, error) {
    18  	var err error
    19  	var device wun.Device
    20  	switch sc.Scheme {
    21  	case "tun":
    22  		wd, err := wun.CreateTUN(sc.Name, mtu)
    23  		if err != nil {
    24  			return nil, fmt.Errorf("create tun failed: %w", err)
    25  		}
    26  
    27  		if wd.BatchSize() == conn.IdealBatchSize {
    28  			wd = newWrapGsoDevice(wd)
    29  			// gso enabled
    30  		}
    31  		device = wd
    32  	case "fd":
    33  		device, _, err = wun.CreateUnmonitoredTUNFromFD(sc.Fd)
    34  	default:
    35  		return nil, fmt.Errorf("invalid tun: %v", sc)
    36  	}
    37  	if err != nil {
    38  		return nil, fmt.Errorf("create tun failed: %w", err)
    39  	}
    40  
    41  	return NewDevice(device, offset), nil
    42  }
    43  
    44  type wrapGsoDevice struct {
    45  	wun.Device
    46  	mtu int
    47  
    48  	w        sync.Mutex
    49  	wbuffers [][]byte
    50  }
    51  
    52  func newWrapGsoDevice(device wun.Device) *wrapGsoDevice {
    53  	mtu, _ := device.MTU()
    54  	if mtu <= 0 {
    55  		mtu = nat.MaxSegmentSize
    56  	}
    57  	w := &wrapGsoDevice{
    58  		Device: device,
    59  		mtu:    mtu,
    60  
    61  		wbuffers: getBuffer(device.BatchSize(), mtu+offset+10),
    62  	}
    63  
    64  	return w
    65  }
    66  
    67  func (w *wrapGsoDevice) Write(bufs [][]byte, offset int) (int, error) {
    68  	// https://github.com/WireGuard/wireguard-go/blob/12269c2761734b15625017d8565745096325392f/tun/offload_linux.go#L867
    69  	//
    70  	// virtioNetHdrLen = 10
    71  
    72  	if len(bufs) > len(w.wbuffers) {
    73  		return 0, fmt.Errorf("buffer %d is larger than recevied: %d", len(w.wbuffers), len(bufs))
    74  	}
    75  
    76  	w.w.Lock()
    77  	defer w.w.Unlock()
    78  
    79  	buffs := buffPool(len(bufs), false).Get().([][]byte)
    80  	defer buffPool(len(bufs), false).Put(buffs)
    81  
    82  	for i := range bufs {
    83  		n := copy(w.wbuffers[i][10:], bufs[i])
    84  		buffs[i] = w.wbuffers[i][:n+10]
    85  	}
    86  
    87  	return w.Device.Write(buffs, 10)
    88  }
    89  
    90  func (w *wrapGsoDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
    91  	// https://github.com/WireGuard/wireguard-go/blob/12269c2761734b15625017d8565745096325392f/tun/offload_linux.go#L867
    92  	//
    93  	// virtioNetHdrLen = 10
    94  	n, err = w.Device.Read(bufs, sizes, 10)
    95  	if err != nil {
    96  		return
    97  	}
    98  
    99  	for x := range n {
   100  		if sizes[x] < 10 {
   101  			return n, fmt.Errorf("invalid packet size small than virtioHdr 10: %d", sizes[x])
   102  		}
   103  
   104  		copy(bufs[x], bufs[x][10:])
   105  	}
   106  
   107  	return
   108  }