github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/gvisor/iovec_wireguard.go (about)

     1  package tun
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"sync"
     9  
    10  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    11  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    12  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    13  	wun "github.com/tailscale/wireguard-go/tun"
    14  )
    15  
    16  type wgDevice struct {
    17  	wun.Device
    18  	mtu    int
    19  	offset int
    20  
    21  	rmu      sync.Mutex
    22  	wmu      sync.Mutex
    23  	wbuffers [][]byte
    24  	rbuffers [][]byte
    25  	rsize    []int
    26  }
    27  
    28  func NewDevice(device wun.Device, offset int) *wgDevice {
    29  	mtu, _ := device.MTU()
    30  	if mtu <= 0 {
    31  		mtu = nat.MaxSegmentSize
    32  	}
    33  	wrwc := &wgDevice{
    34  		Device: device,
    35  		mtu:    mtu,
    36  		offset: offset,
    37  
    38  		wbuffers: getBuffer(device.BatchSize(), offset+mtu+10),
    39  		rbuffers: getBuffer(device.BatchSize(), offset+mtu+10),
    40  		rsize:    buffPool(device.BatchSize(), true).Get().([]int),
    41  	}
    42  
    43  	return wrwc
    44  }
    45  
    46  func (t *wgDevice) Read(bufs [][]byte, sizes []int) (n int, err error) {
    47  	if t.offset == 0 && t.Device.BatchSize() == 1 {
    48  		return t.Device.Read(bufs, sizes, t.offset)
    49  	}
    50  
    51  	t.rmu.Lock()
    52  	defer t.rmu.Unlock()
    53  
    54  	count, err := t.Device.Read(t.rbuffers, t.rsize, t.offset)
    55  	if err != nil {
    56  		return 0, err
    57  	}
    58  
    59  	if count > len(bufs) {
    60  		return 0, fmt.Errorf("buffer %d is smaller than recevied: %d", len(bufs), count)
    61  	}
    62  
    63  	for i := range count {
    64  		copy(bufs[i], t.rbuffers[i][t.offset:t.rsize[i]+t.offset])
    65  		sizes[i] = t.rsize[i]
    66  	}
    67  
    68  	return count, err
    69  }
    70  
    71  func (t *wgDevice) Write(bufs [][]byte) (int, error) {
    72  	if t.offset == 0 && t.BatchSize() == 1 {
    73  		return t.Device.Write(bufs, t.offset)
    74  	}
    75  
    76  	if len(bufs) > len(t.wbuffers) {
    77  		return 0, fmt.Errorf("buffer %d is larger than recevied: %d", len(t.wbuffers), len(bufs))
    78  	}
    79  
    80  	t.wmu.Lock()
    81  	defer t.wmu.Unlock()
    82  
    83  	buffs := buffPool(len(bufs), false).Get().([][]byte)
    84  	defer buffPool(len(bufs), false).Put(buffs)
    85  
    86  	for i := range bufs {
    87  		n := copy(t.wbuffers[i][t.offset:], bufs[i])
    88  		buffs[i] = t.wbuffers[i][:n+t.offset]
    89  	}
    90  
    91  	_, err := t.Device.Write(buffs, t.offset)
    92  	if err != nil {
    93  		return 0, err
    94  	}
    95  
    96  	return len(bufs), nil
    97  }
    98  
    99  func (t *wgDevice) Tun() wun.Device { return t.Device }
   100  
   101  type poolType struct {
   102  	batch int
   103  	size  bool
   104  }
   105  
   106  var poolMap syncmap.SyncMap[poolType, *sync.Pool]
   107  
   108  func buffPool(batch int, size bool) *sync.Pool {
   109  	t := poolType{batch, size}
   110  	if v, ok := poolMap.Load(t); ok {
   111  		return v
   112  	}
   113  
   114  	var p *sync.Pool
   115  
   116  	if size {
   117  		p = &sync.Pool{
   118  			New: func() any {
   119  				return make([]int, batch)
   120  			},
   121  		}
   122  	} else {
   123  		p = &sync.Pool{New: func() any {
   124  			return make([][]byte, batch)
   125  		}}
   126  	}
   127  	poolMap.Store(t, p)
   128  	return p
   129  }
   130  
   131  func getBuffer(batch, size int) [][]byte {
   132  	bufs := buffPool(batch, false).Get().([][]byte)
   133  
   134  	for i := range bufs {
   135  		bufs[i] = pool.GetBytes(size)
   136  	}
   137  
   138  	return bufs
   139  }
   140  
   141  func putBuffer(bufs [][]byte) {
   142  	for i := range bufs {
   143  		pool.PutBytes(bufs[i])
   144  	}
   145  	buffPool(len(bufs), false).Put(bufs)
   146  }
   147  
   148  type ChannelTun struct {
   149  	mtu      int
   150  	inbound  chan *pool.Bytes
   151  	outbound chan *pool.Bytes
   152  	ctx      context.Context
   153  	cancel   context.CancelFunc
   154  	events   chan wun.Event
   155  }
   156  
   157  func NewChannelTun(ctx context.Context, mtu int) *ChannelTun {
   158  	if mtu <= 0 {
   159  		mtu = nat.MaxSegmentSize
   160  	}
   161  	ctx, cancel := context.WithCancel(ctx)
   162  	ct := &ChannelTun{
   163  		mtu:      mtu,
   164  		inbound:  make(chan *pool.Bytes, 10),
   165  		outbound: make(chan *pool.Bytes, 10),
   166  		ctx:      ctx,
   167  		cancel:   cancel,
   168  		events:   make(chan wun.Event, 1),
   169  	}
   170  
   171  	ct.events <- wun.EventUp
   172  
   173  	return ct
   174  }
   175  
   176  func (p *ChannelTun) Outbound(b []byte) error {
   177  	select {
   178  	case p.outbound <- pool.GetBytesBuffer(p.mtu).Copy(b):
   179  		return nil
   180  	case <-p.ctx.Done():
   181  		return io.ErrClosedPipe
   182  	}
   183  }
   184  
   185  func (p *ChannelTun) Read(b [][]byte, size []int, offset int) (int, error) {
   186  	if len(b) == 0 {
   187  		return 0, nil
   188  	}
   189  
   190  	select {
   191  	case <-p.ctx.Done():
   192  		return 0, io.EOF
   193  	case bb := <-p.outbound:
   194  		defer bb.Free()
   195  		size[0] = copy(b[0][offset:], bb.Bytes())
   196  		return 1, nil
   197  	}
   198  }
   199  
   200  func (p *ChannelTun) Inbound(b []byte) (int, error) {
   201  	select {
   202  	case <-p.ctx.Done():
   203  		return 0, io.EOF
   204  	case bb := <-p.inbound:
   205  		defer bb.Free()
   206  		return copy(b, bb.Bytes()), nil
   207  	}
   208  }
   209  
   210  func (p *ChannelTun) Write(b [][]byte, offset int) (int, error) {
   211  	for _, bb := range b {
   212  		select {
   213  		case p.inbound <- pool.GetBytesBuffer(p.mtu).Copy(bb[offset:]):
   214  			return len(b), nil
   215  		case <-p.ctx.Done():
   216  			return 0, io.ErrClosedPipe
   217  		}
   218  	}
   219  
   220  	return len(b), nil
   221  }
   222  
   223  func (p *ChannelTun) Close() error {
   224  	select {
   225  	case <-p.ctx.Done():
   226  		return nil
   227  	default:
   228  	}
   229  	close(p.events)
   230  	p.cancel()
   231  	return nil
   232  }
   233  
   234  func (p *ChannelTun) BatchSize() int        { return 1 }
   235  func (p *ChannelTun) Name() (string, error) { return "channelTun", nil }
   236  func (p *ChannelTun) MTU() (int, error)     { return p.mtu, nil }
   237  func (p *ChannelTun) File() *os.File        { return nil }
   238  
   239  func (p *ChannelTun) Events() <-chan wun.Event { return p.events }