github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/gvisor/iovec_wireguard.go

package tun

import (
	"context"
	"fmt"
	"io"
	"os"
	"sync"

	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
	wun "github.com/tailscale/wireguard-go/tun"
)

// wgDevice wraps a wireguard-go tun device so that callers can exchange
// packets without handling the platform-specific packet offset themselves:
// Read strips the offset prefix, Write re-adds it.
type wgDevice struct {
	wun.Device
	mtu    int
	offset int

	rmu      sync.Mutex
	wmu      sync.Mutex
	wbuffers [][]byte
	rbuffers [][]byte
	rsize    []int
}

func NewDevice(device wun.Device, offset int) *wgDevice {
	mtu, _ := device.MTU()
	if mtu <= 0 {
		mtu = nat.MaxSegmentSize
	}
	wrwc := &wgDevice{
		Device: device,
		mtu:    mtu,
		offset: offset,

		// scratch buffers sized with a little slack beyond offset+mtu
		wbuffers: getBuffer(device.BatchSize(), offset+mtu+10),
		rbuffers: getBuffer(device.BatchSize(), offset+mtu+10),
		rsize:    buffPool(device.BatchSize(), true).Get().([]int),
	}

	return wrwc
}

func (t *wgDevice) Read(bufs [][]byte, sizes []int) (n int, err error) {
	// Fast path: no offset to strip and no batching, so read directly
	// into the caller's buffers.
	if t.offset == 0 && t.Device.BatchSize() == 1 {
		return t.Device.Read(bufs, sizes, t.offset)
	}

	t.rmu.Lock()
	defer t.rmu.Unlock()

	count, err := t.Device.Read(t.rbuffers, t.rsize, t.offset)
	if err != nil {
		return 0, err
	}

	if count > len(bufs) {
		return 0, fmt.Errorf("buffer %d is smaller than received: %d", len(bufs), count)
	}

	// Copy each packet, minus its offset prefix, into the caller's buffers.
	for i := range count {
		copy(bufs[i], t.rbuffers[i][t.offset:t.rsize[i]+t.offset])
		sizes[i] = t.rsize[i]
	}

	return count, err
}

func (t *wgDevice) Write(bufs [][]byte) (int, error) {
	// Fast path: no offset to prepend and no batching.
	if t.offset == 0 && t.BatchSize() == 1 {
		return t.Device.Write(bufs, t.offset)
	}

	if len(bufs) > len(t.wbuffers) {
		return 0, fmt.Errorf("buffer %d is smaller than received: %d", len(t.wbuffers), len(bufs))
	}

	t.wmu.Lock()
	defer t.wmu.Unlock()

	buffs := buffPool(len(bufs), false).Get().([][]byte)
	defer buffPool(len(bufs), false).Put(buffs)

	// Prepend the offset to every packet before handing it to the device.
	for i := range bufs {
		n := copy(t.wbuffers[i][t.offset:], bufs[i])
		buffs[i] = t.wbuffers[i][:n+t.offset]
	}

	_, err := t.Device.Write(buffs, t.offset)
	if err != nil {
		return 0, err
	}

	return len(bufs), nil
}

func (t *wgDevice) Tun() wun.Device { return t.Device }

// poolType keys the buffer pools by batch size and element kind:
// size selects the []int pool, otherwise the [][]byte pool.
type poolType struct {
	batch int
	size  bool
}

var poolMap syncmap.SyncMap[poolType, *sync.Pool]

func buffPool(batch int, size bool) *sync.Pool {
	t := poolType{batch, size}
	if v, ok := poolMap.Load(t); ok {
		return v
	}

	var p *sync.Pool

	if size {
		p = &sync.Pool{
			New: func() any {
				return make([]int, batch)
			},
		}
	} else {
		p = &sync.Pool{New: func() any {
			return make([][]byte, batch)
		}}
	}
	poolMap.Store(t, p)
	return p
}

func getBuffer(batch, size int) [][]byte {
	bufs := buffPool(batch, false).Get().([][]byte)

	for i := range bufs {
		bufs[i] = pool.GetBytes(size)
	}

	return bufs
}

func putBuffer(bufs [][]byte) {
	for i := range bufs {
		pool.PutBytes(bufs[i])
	}
	buffPool(len(bufs), false).Put(bufs)
}
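// exampleWrapDevice is a minimal usage sketch of the wrapper above: it
// assumes a caller that wants offset-free packets from a platform device.
// The function name, the 4-byte offset (the utun packet-family header on
// darwin), and the echo of each packet are illustrative assumptions, not
// an API of this package.
func exampleWrapDevice(dev wun.Device) error {
	wrapped := NewDevice(dev, 4)

	// Caller-owned buffers; NewDevice normalized wrapped.mtu already.
	bufs := [][]byte{make([]byte, wrapped.mtu)}
	sizes := make([]int, 1)

	// Read fills bufs with packets that no longer carry the offset prefix.
	n, err := wrapped.Read(bufs, sizes)
	if err != nil {
		return err
	}

	for i := 0; i < n; i++ {
		pkt := bufs[i][:sizes[i]] // one offset-free IP packet
		// Write also takes offset-free packets; the wrapper re-adds the
		// offset internally. Echoing the packet back is purely illustrative.
		if _, err := wrapped.Write([][]byte{pkt}); err != nil {
			return err
		}
	}
	return nil
}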
// ChannelTun is an in-memory wun.Device backed by channels: Outbound
// injects packets for the stack to Read, and Inbound drains packets the
// stack has Written.
type ChannelTun struct {
	mtu      int
	inbound  chan *pool.Bytes
	outbound chan *pool.Bytes
	ctx      context.Context
	cancel   context.CancelFunc
	events   chan wun.Event
}

func NewChannelTun(ctx context.Context, mtu int) *ChannelTun {
	if mtu <= 0 {
		mtu = nat.MaxSegmentSize
	}
	ctx, cancel := context.WithCancel(ctx)
	ct := &ChannelTun{
		mtu:      mtu,
		inbound:  make(chan *pool.Bytes, 10),
		outbound: make(chan *pool.Bytes, 10),
		ctx:      ctx,
		cancel:   cancel,
		events:   make(chan wun.Event, 1),
	}

	ct.events <- wun.EventUp

	return ct
}

func (p *ChannelTun) Outbound(b []byte) error {
	select {
	case p.outbound <- pool.GetBytesBuffer(p.mtu).Copy(b):
		return nil
	case <-p.ctx.Done():
		return io.ErrClosedPipe
	}
}

func (p *ChannelTun) Read(b [][]byte, size []int, offset int) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}

	select {
	case <-p.ctx.Done():
		return 0, io.EOF
	case bb := <-p.outbound:
		defer bb.Free()
		size[0] = copy(b[0][offset:], bb.Bytes())
		return 1, nil
	}
}

func (p *ChannelTun) Inbound(b []byte) (int, error) {
	select {
	case <-p.ctx.Done():
		return 0, io.EOF
	case bb := <-p.inbound:
		defer bb.Free()
		return copy(b, bb.Bytes()), nil
	}
}

func (p *ChannelTun) Write(b [][]byte, offset int) (int, error) {
	// Queue every packet, blocking until there is room or the tun closes.
	for _, bb := range b {
		select {
		case p.inbound <- pool.GetBytesBuffer(p.mtu).Copy(bb[offset:]):
		case <-p.ctx.Done():
			return 0, io.ErrClosedPipe
		}
	}

	return len(b), nil
}

func (p *ChannelTun) Close() error {
	select {
	case <-p.ctx.Done():
		return nil
	default:
	}
	close(p.events)
	p.cancel()
	return nil
}

func (p *ChannelTun) BatchSize() int        { return 1 }
func (p *ChannelTun) Name() (string, error) { return "channelTun", nil }
func (p *ChannelTun) MTU() (int, error)     { return p.mtu, nil }
func (p *ChannelTun) File() *os.File        { return nil }

func (p *ChannelTun) Events() <-chan wun.Event { return p.events }
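// exampleChannelTunPump is a minimal usage sketch, assuming a caller that
// injects raw packets with Outbound and drains packets the stack wrote with
// Inbound. The function name, the 1500-byte MTU and scratch buffer, and the
// background goroutine are illustrative assumptions.
func exampleChannelTunPump(ctx context.Context, packets <-chan []byte) {
	ct := NewChannelTun(ctx, 1500)
	defer ct.Close()

	// Feed packets into the tun; the stack picks them up via Read.
	go func() {
		for pkt := range packets {
			if err := ct.Outbound(pkt); err != nil {
				return // tun closed
			}
		}
	}()

	// Drain packets the stack queued via Write.
	buf := make([]byte, 1500)
	for {
		n, err := ct.Inbound(buf)
		if err != nil {
			return // io.EOF once the context is cancelled or Close is called
		}
		_ = buf[:n] // hand the packet off to the caller in a real program
	}
}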