github.com/yaling888/clash@v1.53.0/listener/tun/device/iobased/endpoint.go (about)

     1  //go:build !nogvisor
     2  
     3  package iobased
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"os"
     9  	"sync"
    10  
    11  	"gvisor.dev/gvisor/pkg/buffer"
    12  	"gvisor.dev/gvisor/pkg/tcpip"
    13  	"gvisor.dev/gvisor/pkg/tcpip/header"
    14  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    15  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    16  
    17  	dev "github.com/yaling888/clash/listener/tun/device"
    18  )
    19  
    20  const (
    21  	// Queue length for outbound packet, arriving for read. Overflow
    22  	// causes packet drops.
    23  	defaultOutQueueLen = 1 << 10
    24  )
    25  
    26  // Endpoint implements the interface of stack.LinkEndpoint from io.ReadWriter.
    27  type Endpoint struct {
    28  	*channel.Endpoint
    29  
    30  	// rw is the io.ReadWriter for reading and writing packets.
    31  	rw dev.Device
    32  
    33  	// mtu (maximum transmission unit) is the maximum size of a packet.
    34  	mtu uint32
    35  
    36  	// offset can be useful when perform TUN device I/O with TUN_PI enabled.
    37  	offset int
    38  
    39  	// once is used to perform the init action once when attaching.
    40  	once sync.Once
    41  
    42  	// wg keeps track of running goroutines.
    43  	wg sync.WaitGroup
    44  }
    45  
    46  // New returns stack.LinkEndpoint(.*Endpoint) and error.
    47  func New(rw dev.Device, mtu uint32, offset int) (*Endpoint, error) {
    48  	if mtu == 0 {
    49  		return nil, errors.New("MTU size is zero")
    50  	}
    51  
    52  	if rw == nil {
    53  		return nil, errors.New("RW interface is nil")
    54  	}
    55  
    56  	if offset < 0 {
    57  		return nil, errors.New("offset must be non-negative")
    58  	}
    59  
    60  	return &Endpoint{
    61  		Endpoint: channel.New(defaultOutQueueLen, mtu, ""),
    62  		rw:       rw,
    63  		mtu:      mtu,
    64  		offset:   offset,
    65  	}, nil
    66  }
    67  
    68  func (e *Endpoint) Wait() {
    69  	e.wg.Wait()
    70  }
    71  
    72  // Attach launches the goroutine that reads packets from io.Reader and
    73  // dispatches them via the provided dispatcher.
    74  func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
    75  	e.Endpoint.Attach(dispatcher)
    76  	e.once.Do(func() {
    77  		ctx, cancel := context.WithCancel(context.Background())
    78  		e.wg.Add(2)
    79  		go func() {
    80  			e.outboundLoop(ctx)
    81  			e.wg.Done()
    82  		}()
    83  		go func() {
    84  			e.dispatchLoop(cancel)
    85  			e.wg.Done()
    86  		}()
    87  	})
    88  }
    89  
    90  // dispatchLoop dispatches packets to upper layer.
    91  func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) {
    92  	// Call cancel() to ensure (*Endpoint).outboundLoop(context.Context) exits
    93  	// gracefully after (*Endpoint).dispatchLoop(context.CancelFunc) returns.
    94  	defer cancel()
    95  
    96  	var (
    97  		readErr    error
    98  		device     = e.rw
    99  		offset     = e.offset
   100  		batchSize  = device.BatchSize()
   101  		bufferSize = 65535 + offset
   102  		count      int
   103  		buffs      = make([][]byte, batchSize)
   104  		sizes      = make([]int, batchSize)
   105  	)
   106  
   107  	bufferSize += 7 - ((bufferSize + 7) % 8)
   108  	for i := range buffs {
   109  		buffs[i] = make([]byte, bufferSize)
   110  	}
   111  
   112  	for {
   113  		count, readErr = device.Read(buffs, sizes, offset)
   114  		for i := 0; i < count; i++ {
   115  			if sizes[i] < 1 || !e.IsAttached() {
   116  				continue
   117  			}
   118  
   119  			data := buffs[i][offset : offset+sizes[i]]
   120  
   121  			var p tcpip.NetworkProtocolNumber
   122  			switch header.IPVersion(data) {
   123  			case header.IPv4Version:
   124  				p = header.IPv4ProtocolNumber
   125  			case header.IPv6Version:
   126  				p = header.IPv6ProtocolNumber
   127  			default:
   128  				continue
   129  			}
   130  
   131  			pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   132  				Payload: buffer.MakeWithData(data),
   133  			})
   134  
   135  			e.InjectInbound(p, pkt)
   136  
   137  			pkt.DecRef()
   138  		}
   139  
   140  		if readErr != nil {
   141  			if errors.Is(readErr, os.ErrClosed) {
   142  				return
   143  			}
   144  			continue
   145  		}
   146  	}
   147  }
   148  
   149  // outboundLoop reads outbound packets from channel, and then it calls
   150  // writePacket to send those packets back to lower layer.
   151  func (e *Endpoint) outboundLoop(ctx context.Context) {
   152  	buffs := make([][]byte, 0, 1)
   153  	for {
   154  		pkt := e.ReadContext(ctx)
   155  		if pkt == nil {
   156  			break
   157  		}
   158  		buffs = buffs[:0]
   159  		e.writePacket(buffs, pkt)
   160  	}
   161  }
   162  
   163  // writePacket writes outbound packets to the io.Writer.
   164  func (e *Endpoint) writePacket(buffs [][]byte, pkt *stack.PacketBuffer) tcpip.Error {
   165  	var (
   166  		pktView *buffer.View
   167  		offset  = e.offset
   168  	)
   169  
   170  	defer func() {
   171  		pktView.Release()
   172  		pkt.DecRef()
   173  	}()
   174  
   175  	if offset > 0 {
   176  		v := pkt.ToView()
   177  		pktView = buffer.NewViewSize(offset + pkt.Size())
   178  		_, _ = pktView.WriteAt(v.AsSlice(), offset)
   179  		v.Release()
   180  	} else {
   181  		pktView = pkt.ToView()
   182  	}
   183  
   184  	buffs = append(buffs, pktView.AsSlice())
   185  	if _, err := e.rw.Write(buffs, offset); err != nil {
   186  		return &tcpip.ErrInvalidEndpointState{}
   187  	}
   188  	return nil
   189  }