github.com/sagernet/sing-box@v1.2.7/transport/wireguard/device_stack.go (about) 1 //go:build with_gvisor 2 3 package wireguard 4 5 import ( 6 "context" 7 "net" 8 "net/netip" 9 "os" 10 11 E "github.com/sagernet/sing/common/exceptions" 12 M "github.com/sagernet/sing/common/metadata" 13 N "github.com/sagernet/sing/common/network" 14 "github.com/sagernet/wireguard-go/tun" 15 16 "gvisor.dev/gvisor/pkg/bufferv2" 17 "gvisor.dev/gvisor/pkg/tcpip" 18 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 19 "gvisor.dev/gvisor/pkg/tcpip/header" 20 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 21 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 22 "gvisor.dev/gvisor/pkg/tcpip/stack" 23 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 24 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 25 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 26 ) 27 28 var _ Device = (*StackDevice)(nil) 29 30 const defaultNIC tcpip.NICID = 1 31 32 type StackDevice struct { 33 stack *stack.Stack 34 mtu uint32 35 events chan tun.Event 36 outbound chan *stack.PacketBuffer 37 done chan struct{} 38 dispatcher stack.NetworkDispatcher 39 addr4 tcpip.Address 40 addr6 tcpip.Address 41 } 42 43 func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { 44 ipStack := stack.New(stack.Options{ 45 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 46 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, 47 HandleLocal: true, 48 }) 49 tunDevice := &StackDevice{ 50 stack: ipStack, 51 mtu: mtu, 52 events: make(chan tun.Event, 1), 53 outbound: make(chan *stack.PacketBuffer, 256), 54 done: make(chan struct{}), 55 } 56 err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) 57 if err != nil { 58 return nil, E.New(err.String()) 59 } 60 for _, prefix := range localAddresses { 61 addr := tcpip.Address(prefix.Addr().AsSlice()) 62 protoAddr := tcpip.ProtocolAddress{ 63 AddressWithPrefix: tcpip.AddressWithPrefix{ 64 Address: addr, 65 PrefixLen: prefix.Bits(), 66 }, 67 } 68 if prefix.Addr().Is4() { 69 tunDevice.addr4 = addr 70 protoAddr.Protocol = ipv4.ProtocolNumber 71 } else { 72 tunDevice.addr6 = addr 73 protoAddr.Protocol = ipv6.ProtocolNumber 74 } 75 err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}) 76 if err != nil { 77 return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) 78 } 79 } 80 sOpt := tcpip.TCPSACKEnabled(true) 81 ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) 82 cOpt := tcpip.CongestionControlOption("cubic") 83 ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt) 84 ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC}) 85 ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC}) 86 return tunDevice, nil 87 } 88 89 func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) { 90 return (*wireEndpoint)(w), nil 91 } 92 93 func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 94 addr := tcpip.FullAddress{ 95 NIC: defaultNIC, 96 Port: destination.Port, 97 Addr: tcpip.Address(destination.Addr.AsSlice()), 98 } 99 bind := tcpip.FullAddress{ 100 NIC: defaultNIC, 101 } 102 var networkProtocol tcpip.NetworkProtocolNumber 103 if destination.IsIPv4() { 104 networkProtocol = header.IPv4ProtocolNumber 105 bind.Addr = w.addr4 106 } else { 107 networkProtocol = header.IPv6ProtocolNumber 108 bind.Addr = w.addr6 109 } 110 switch N.NetworkName(network) { 111 case N.NetworkTCP: 112 tcpConn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol) 113 if err != nil { 114 return nil, err 115 } 116 return tcpConn, nil 117 case N.NetworkUDP: 118 udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol) 119 if err != nil { 120 return nil, err 121 } 122 return udpConn, nil 123 default: 124 return nil, E.Extend(N.ErrUnknownNetwork, network) 125 } 126 } 127 128 func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 129 bind := tcpip.FullAddress{ 130 NIC: defaultNIC, 131 } 132 var networkProtocol tcpip.NetworkProtocolNumber 133 if destination.IsIPv4() || w.addr6 == "" { 134 networkProtocol = header.IPv4ProtocolNumber 135 bind.Addr = w.addr4 136 } else { 137 networkProtocol = header.IPv6ProtocolNumber 138 bind.Addr = w.addr6 139 } 140 udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) 141 if err != nil { 142 return nil, err 143 } 144 return udpConn, nil 145 } 146 147 func (w *StackDevice) Start() error { 148 w.events <- tun.EventUp 149 return nil 150 } 151 152 func (w *StackDevice) File() *os.File { 153 return nil 154 } 155 156 func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { 157 select { 158 case packetBuffer, ok := <-w.outbound: 159 if !ok { 160 return 0, os.ErrClosed 161 } 162 defer packetBuffer.DecRef() 163 p = p[offset:] 164 for _, slice := range packetBuffer.AsSlices() { 165 n += copy(p[n:], slice) 166 } 167 return 168 case <-w.done: 169 return 0, os.ErrClosed 170 } 171 } 172 173 func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { 174 p = p[offset:] 175 if len(p) == 0 { 176 return 177 } 178 var networkProtocol tcpip.NetworkProtocolNumber 179 switch header.IPVersion(p) { 180 case header.IPv4Version: 181 networkProtocol = header.IPv4ProtocolNumber 182 case header.IPv6Version: 183 networkProtocol = header.IPv6ProtocolNumber 184 } 185 packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ 186 Payload: bufferv2.MakeWithData(p), 187 }) 188 defer packetBuffer.DecRef() 189 w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) 190 n = len(p) 191 return 192 } 193 194 func (w *StackDevice) Flush() error { 195 return nil 196 } 197 198 func (w *StackDevice) MTU() (int, error) { 199 return int(w.mtu), nil 200 } 201 202 func (w *StackDevice) Name() (string, error) { 203 return "sing-box", nil 204 } 205 206 func (w *StackDevice) Events() chan tun.Event { 207 return w.events 208 } 209 210 func (w *StackDevice) Close() error { 211 select { 212 case <-w.done: 213 return os.ErrClosed 214 default: 215 } 216 w.stack.Close() 217 for _, endpoint := range w.stack.CleanupEndpoints() { 218 endpoint.Abort() 219 } 220 w.stack.Wait() 221 close(w.done) 222 return nil 223 } 224 225 var _ stack.LinkEndpoint = (*wireEndpoint)(nil) 226 227 type wireEndpoint StackDevice 228 229 func (ep *wireEndpoint) MTU() uint32 { 230 return ep.mtu 231 } 232 233 func (ep *wireEndpoint) MaxHeaderLength() uint16 { 234 return 0 235 } 236 237 func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress { 238 return "" 239 } 240 241 func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities { 242 return stack.CapabilityNone 243 } 244 245 func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) { 246 ep.dispatcher = dispatcher 247 } 248 249 func (ep *wireEndpoint) IsAttached() bool { 250 return ep.dispatcher != nil 251 } 252 253 func (ep *wireEndpoint) Wait() { 254 } 255 256 func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType { 257 return header.ARPHardwareNone 258 } 259 260 func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) { 261 } 262 263 func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { 264 for _, packetBuffer := range list.AsSlice() { 265 packetBuffer.IncRef() 266 select { 267 case <-ep.done: 268 return 0, &tcpip.ErrClosedForSend{} 269 case ep.outbound <- packetBuffer: 270 } 271 } 272 return list.Len(), nil 273 }