github.com/sagernet/sing-box@v1.9.0-rc.20/transport/wireguard/device_stack.go (about) 1 //go:build with_gvisor 2 3 package wireguard 4 5 import ( 6 "context" 7 "net" 8 "net/netip" 9 "os" 10 11 "github.com/sagernet/gvisor/pkg/buffer" 12 "github.com/sagernet/gvisor/pkg/tcpip" 13 "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" 14 "github.com/sagernet/gvisor/pkg/tcpip/header" 15 "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" 16 "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" 17 "github.com/sagernet/gvisor/pkg/tcpip/stack" 18 "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" 19 "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" 20 "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" 21 "github.com/sagernet/sing-tun" 22 "github.com/sagernet/sing/common/buf" 23 E "github.com/sagernet/sing/common/exceptions" 24 M "github.com/sagernet/sing/common/metadata" 25 N "github.com/sagernet/sing/common/network" 26 wgTun "github.com/sagernet/wireguard-go/tun" 27 ) 28 29 var _ Device = (*StackDevice)(nil) 30 31 const defaultNIC tcpip.NICID = 1 32 33 type StackDevice struct { 34 stack *stack.Stack 35 mtu uint32 36 events chan wgTun.Event 37 outbound chan *stack.PacketBuffer 38 packetOutbound chan *buf.Buffer 39 done chan struct{} 40 dispatcher stack.NetworkDispatcher 41 addr4 tcpip.Address 42 addr6 tcpip.Address 43 } 44 45 func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { 46 ipStack := stack.New(stack.Options{ 47 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 48 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, 49 HandleLocal: true, 50 }) 51 tunDevice := &StackDevice{ 52 stack: ipStack, 53 mtu: mtu, 54 events: make(chan wgTun.Event, 1), 55 outbound: make(chan *stack.PacketBuffer, 256), 56 packetOutbound: make(chan *buf.Buffer, 256), 57 done: make(chan struct{}), 58 } 59 err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) 60 if err != nil { 61 return nil, E.New(err.String()) 62 } 63 for _, prefix := range localAddresses { 64 addr := tun.AddressFromAddr(prefix.Addr()) 65 protoAddr := tcpip.ProtocolAddress{ 66 AddressWithPrefix: tcpip.AddressWithPrefix{ 67 Address: addr, 68 PrefixLen: prefix.Bits(), 69 }, 70 } 71 if prefix.Addr().Is4() { 72 tunDevice.addr4 = addr 73 protoAddr.Protocol = ipv4.ProtocolNumber 74 } else { 75 tunDevice.addr6 = addr 76 protoAddr.Protocol = ipv6.ProtocolNumber 77 } 78 err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}) 79 if err != nil { 80 return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) 81 } 82 } 83 sOpt := tcpip.TCPSACKEnabled(true) 84 ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) 85 cOpt := tcpip.CongestionControlOption("cubic") 86 ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt) 87 ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC}) 88 ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC}) 89 return tunDevice, nil 90 } 91 92 func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) { 93 return (*wireEndpoint)(w), nil 94 } 95 96 func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 97 addr := tcpip.FullAddress{ 98 NIC: defaultNIC, 99 Port: destination.Port, 100 Addr: tun.AddressFromAddr(destination.Addr), 101 } 102 bind := tcpip.FullAddress{ 103 NIC: defaultNIC, 104 } 105 var networkProtocol tcpip.NetworkProtocolNumber 106 if destination.IsIPv4() { 107 networkProtocol = header.IPv4ProtocolNumber 108 bind.Addr = w.addr4 109 } else { 110 networkProtocol = header.IPv6ProtocolNumber 111 bind.Addr = w.addr6 112 } 113 switch N.NetworkName(network) { 114 case N.NetworkTCP: 115 tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol) 116 if err != nil { 117 return nil, err 118 } 119 return tcpConn, nil 120 case N.NetworkUDP: 121 udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol) 122 if err != nil { 123 return nil, err 124 } 125 return udpConn, nil 126 default: 127 return nil, E.Extend(N.ErrUnknownNetwork, network) 128 } 129 } 130 131 func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 132 bind := tcpip.FullAddress{ 133 NIC: defaultNIC, 134 } 135 var networkProtocol tcpip.NetworkProtocolNumber 136 if destination.IsIPv4() { 137 networkProtocol = header.IPv4ProtocolNumber 138 bind.Addr = w.addr4 139 } else { 140 networkProtocol = header.IPv6ProtocolNumber 141 bind.Addr = w.addr6 142 } 143 udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) 144 if err != nil { 145 return nil, err 146 } 147 return udpConn, nil 148 } 149 150 func (w *StackDevice) Inet4Address() netip.Addr { 151 return tun.AddrFromAddress(w.addr4) 152 } 153 154 func (w *StackDevice) Inet6Address() netip.Addr { 155 return tun.AddrFromAddress(w.addr6) 156 } 157 158 func (w *StackDevice) Start() error { 159 w.events <- wgTun.EventUp 160 return nil 161 } 162 163 func (w *StackDevice) File() *os.File { 164 return nil 165 } 166 167 func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { 168 select { 169 case packetBuffer, ok := <-w.outbound: 170 if !ok { 171 return 0, os.ErrClosed 172 } 173 defer packetBuffer.DecRef() 174 p := bufs[0] 175 p = p[offset:] 176 n := 0 177 for _, slice := range packetBuffer.AsSlices() { 178 n += copy(p[n:], slice) 179 } 180 sizes[0] = n 181 count = 1 182 return 183 case packet := <-w.packetOutbound: 184 defer packet.Release() 185 sizes[0] = copy(bufs[0][offset:], packet.Bytes()) 186 count = 1 187 return 188 case <-w.done: 189 return 0, os.ErrClosed 190 } 191 } 192 193 func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) { 194 for _, b := range bufs { 195 b = b[offset:] 196 if len(b) == 0 { 197 continue 198 } 199 var networkProtocol tcpip.NetworkProtocolNumber 200 switch header.IPVersion(b) { 201 case header.IPv4Version: 202 networkProtocol = header.IPv4ProtocolNumber 203 case header.IPv6Version: 204 networkProtocol = header.IPv6ProtocolNumber 205 } 206 packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ 207 Payload: buffer.MakeWithData(b), 208 }) 209 w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) 210 packetBuffer.DecRef() 211 count++ 212 } 213 return 214 } 215 216 func (w *StackDevice) Flush() error { 217 return nil 218 } 219 220 func (w *StackDevice) MTU() (int, error) { 221 return int(w.mtu), nil 222 } 223 224 func (w *StackDevice) Name() (string, error) { 225 return "sing-box", nil 226 } 227 228 func (w *StackDevice) Events() <-chan wgTun.Event { 229 return w.events 230 } 231 232 func (w *StackDevice) Close() error { 233 select { 234 case <-w.done: 235 return os.ErrClosed 236 default: 237 } 238 w.stack.Close() 239 for _, endpoint := range w.stack.CleanupEndpoints() { 240 endpoint.Abort() 241 } 242 w.stack.Wait() 243 close(w.done) 244 return nil 245 } 246 247 func (w *StackDevice) BatchSize() int { 248 return 1 249 } 250 251 var _ stack.LinkEndpoint = (*wireEndpoint)(nil) 252 253 type wireEndpoint StackDevice 254 255 func (ep *wireEndpoint) MTU() uint32 { 256 return ep.mtu 257 } 258 259 func (ep *wireEndpoint) MaxHeaderLength() uint16 { 260 return 0 261 } 262 263 func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress { 264 return "" 265 } 266 267 func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities { 268 return stack.CapabilityRXChecksumOffload 269 } 270 271 func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) { 272 ep.dispatcher = dispatcher 273 } 274 275 func (ep *wireEndpoint) IsAttached() bool { 276 return ep.dispatcher != nil 277 } 278 279 func (ep *wireEndpoint) Wait() { 280 } 281 282 func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType { 283 return header.ARPHardwareNone 284 } 285 286 func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) { 287 } 288 289 func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool { 290 return true 291 } 292 293 func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { 294 for _, packetBuffer := range list.AsSlice() { 295 packetBuffer.IncRef() 296 select { 297 case <-ep.done: 298 return 0, &tcpip.ErrClosedForSend{} 299 case ep.outbound <- packetBuffer: 300 } 301 } 302 return list.Len(), nil 303 }