github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/tun2socket/nat/nat.go (about) 1 package nat 2 3 import ( 4 "context" 5 "errors" 6 "math" 7 "net" 8 9 "github.com/Asutorufa/yuhaiin/pkg/log" 10 "github.com/Asutorufa/yuhaiin/pkg/net/dialer" 11 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 12 "github.com/Asutorufa/yuhaiin/pkg/net/netlink" 13 tun "github.com/Asutorufa/yuhaiin/pkg/net/proxy/tun/gvisor" 14 "gvisor.dev/gvisor/pkg/tcpip" 15 "gvisor.dev/gvisor/pkg/tcpip/checksum" 16 "gvisor.dev/gvisor/pkg/tcpip/header" 17 ) 18 19 type Nat struct { 20 *TCP 21 *UDP 22 23 address tcpip.Address 24 portal tcpip.Address 25 addressV6 tcpip.Address 26 portalV6 tcpip.Address 27 gatewayPort uint16 28 mtu int32 29 30 tab *tableSplit 31 } 32 33 func Start(opt *tun.Opt) (*Nat, error) { 34 listener, err := dialer.ListenContextWithOptions(context.Background(), "tcp", "", &dialer.Options{}) 35 if err != nil { 36 return nil, err 37 } 38 39 log.Info("new tun2socket tcp server", "host", listener.Addr(), 40 "gateway", opt.V4Address(), "portal", opt.V4Address().Addr().Next(), 41 "gatewayv6", opt.V6Address(), "portalv6", opt.V6Address().Addr().Next(), 42 ) 43 44 err = netlink.Route(opt.Options) 45 if err != nil { 46 log.Warn("set route failed", "err", err) 47 } 48 49 if opt.MTU <= 0 { 50 opt.MTU = nat.MaxSegmentSize 51 } 52 53 tab := newTable() 54 55 nat := &Nat{ 56 address: tcpip.AddrFromSlice(opt.V4Address().Addr().AsSlice()), 57 portal: tcpip.AddrFromSlice(opt.V4Address().Addr().Next().AsSlice()), 58 addressV6: tcpip.AddrFromSlice(opt.V6Address().Addr().AsSlice()), 59 portalV6: tcpip.AddrFromSlice(opt.V6Address().Addr().Next().AsSlice()), 60 gatewayPort: uint16(listener.Addr().(*net.TCPAddr).Port), 61 mtu: int32(opt.MTU), 62 tab: tab, 63 TCP: &TCP{ 64 listener: listener.(*net.TCPListener), 65 portal: opt.V4Address().Addr().Next().AsSlice(), 66 portalv6: opt.V6Address().Addr().Next().AsSlice(), 67 table: tab, 68 }, 69 UDP: NewUDPv2(int32(opt.MTU), opt.Writer), 70 } 71 72 subnet := tcpip.AddressWithPrefix{Address: nat.address, PrefixLen: opt.V4Address().Bits()}.Subnet() 73 broadcast := subnet.Broadcast() 74 if broadcast.Equal(nat.address) || broadcast.Equal(nat.portal) { 75 broadcast = tcpip.AddrFrom4([4]byte{255, 255, 255, 255}) 76 } 77 78 go func() { 79 defer nat.Close() 80 81 sizes := make([]int, opt.Writer.Tun().BatchSize()) 82 bufs := make([][]byte, opt.Writer.Tun().BatchSize()) 83 for i := range bufs { 84 bufs[i] = make([]byte, opt.MTU) 85 } 86 87 wbufs := make([][]byte, opt.Writer.Tun().BatchSize()) 88 89 for { 90 n, err := opt.Writer.Read(bufs, sizes) 91 if err != nil { 92 log.Error("tun device read failed", "err", err) 93 return 94 } 95 96 wbufs = wbufs[:0] 97 98 for i := range n { 99 if sizes[i] < header.IPv4MinimumSize { 100 continue 101 } 102 103 raw := bufs[i][:sizes[i]] 104 105 ip := nat.processIP(raw) 106 if ip == nil { 107 continue 108 } 109 110 if len(ip.Payload()) > len(raw) { 111 continue 112 } 113 114 dst, src := ip.DestinationAddress(), ip.SourceAddress() 115 116 if !net.IP(dst.AsSlice()).IsGlobalUnicast() || dst.Equal(broadcast) { 117 continue 118 } 119 120 var tp header.Transport 121 var pseudoHeaderSum uint16 122 var ok bool 123 124 switch ip.TransportProtocol() { 125 case header.TCPProtocolNumber: 126 tp, pseudoHeaderSum, ok = nat.processTCP(ip, src, dst) 127 128 case header.ICMPv4ProtocolNumber: 129 tp, pseudoHeaderSum, ok = processICMP(ip) 130 131 case header.ICMPv6ProtocolNumber: 132 tp, pseudoHeaderSum, ok = processICMPv6(ip) 133 134 case header.UDPProtocolNumber: 135 u := header.UDP(ip.Payload()) 136 if u.Length() == 0 { 137 continue 138 } 139 140 nat.UDP.handleUDPPacket( 141 Tuple{ 142 SourceAddr: src, 143 SourcePort: u.SourcePort(), 144 DestinationAddr: dst, 145 DestinationPort: u.DestinationPort(), 146 }, u.Payload()) 147 148 continue 149 150 default: 151 continue 152 } 153 154 if !ok { 155 continue 156 } 157 158 resetCheckSum(ip, tp, pseudoHeaderSum) 159 160 wbufs = append(wbufs, raw) 161 } 162 163 if len(wbufs) == 0 { 164 continue 165 } 166 167 if _, err = opt.Writer.Write(wbufs); err != nil { 168 log.Error("write tcp raw to tun device failed", "err", err) 169 } 170 171 } 172 }() 173 174 return nat, nil 175 } 176 177 func (n *Nat) processIP(raw []byte) header.Network { 178 switch header.IPVersion(raw) { 179 case header.IPv4Version: 180 ipv4 := header.IPv4(raw) 181 182 if !ipv4.IsValid(int(ipv4.TotalLength())) { 183 return nil 184 } 185 186 if ipv4.More() { 187 return nil 188 } 189 190 if ipv4.FragmentOffset() != 0 { 191 return nil 192 } 193 194 return ipv4 195 196 case header.IPv6Version: 197 ipv6 := header.IPv6(raw) 198 199 if ipv6.HopLimit() == 0x00 { 200 return nil 201 } 202 203 return ipv6 204 } 205 206 return nil 207 } 208 209 func (n *Nat) processTCP(ip header.Network, src, dst tcpip.Address) (_ header.Transport, pseudoHeaderSum uint16, _ bool) { 210 t := header.TCP(ip.Payload()) 211 212 sourcePort := t.SourcePort() 213 destinationPort := t.DestinationPort() 214 215 var address, portal tcpip.Address 216 if _, ok := ip.(header.IPv4); ok { 217 address, portal = n.address, n.portal 218 } else { 219 address, portal = n.addressV6, n.portalV6 220 } 221 222 if address.Unspecified() || portal.Unspecified() { 223 return nil, 0, false 224 } 225 if src == address && sourcePort == n.gatewayPort { 226 tup := n.tab.tupleOf(destinationPort, dst.Len() == 16) 227 if tup == zeroTuple { 228 return nil, 0, false 229 } 230 231 ip.SetDestinationAddress(tup.SourceAddr) 232 t.SetDestinationPort(tup.SourcePort) 233 ip.SetSourceAddress(tup.DestinationAddr) 234 t.SetSourcePort(tup.DestinationPort) 235 } else { 236 tup := Tuple{ 237 SourceAddr: src, 238 SourcePort: sourcePort, 239 DestinationAddr: dst, 240 DestinationPort: destinationPort, 241 } 242 243 port := n.tab.portOf(tup) 244 ip.SetDestinationAddress(address) 245 t.SetDestinationPort(n.gatewayPort) 246 ip.SetSourceAddress(portal) 247 t.SetSourcePort(port) 248 } 249 250 pseudoHeaderSum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, 251 ip.SourceAddress(), 252 ip.DestinationAddress(), 253 uint16(len(ip.Payload())), 254 ) 255 256 return t, pseudoHeaderSum, true 257 } 258 259 func (n *Nat) Close() error { 260 var err error 261 262 if n.UDP != nil { 263 if er := n.UDP.Close(); er != nil { 264 err = errors.Join(err, er) 265 } 266 } 267 268 if n.TCP != nil { 269 if er := n.TCP.Close(); er != nil { 270 err = errors.Join(err, er) 271 } 272 } 273 274 return err 275 } 276 277 func processICMP(ip header.Network) (_ header.Transport, pseudoHeaderSum uint16, _ bool) { 278 i := header.ICMPv4(ip.Payload()) 279 280 if i.Type() != header.ICMPv4Echo || i.Code() != 0 { 281 return nil, 0, false 282 } 283 284 i.SetType(header.ICMPv4EchoReply) 285 286 destination := ip.DestinationAddress() 287 ip.SetDestinationAddress(ip.SourceAddress()) 288 ip.SetSourceAddress(destination) 289 290 pseudoHeaderSum = 0 291 292 return i, pseudoHeaderSum, true 293 } 294 295 func processICMPv6(ip header.Network) (_ header.Transport, pseudoHeaderSum uint16, _ bool) { 296 i := header.ICMPv6(ip.Payload()) 297 298 if i.Type() != header.ICMPv6EchoRequest || i.Code() != 0 { 299 return nil, 0, false 300 } 301 302 i.SetType(header.ICMPv6EchoReply) 303 304 destination := ip.DestinationAddress() 305 ip.SetDestinationAddress(ip.SourceAddress()) 306 ip.SetSourceAddress(destination) 307 308 pseudoHeaderSum = header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, 309 ip.SourceAddress(), ip.DestinationAddress(), 310 uint16(len(i)), 311 ) 312 313 return i, pseudoHeaderSum, true 314 } 315 316 func resetCheckSum(ip header.Network, tp header.Transport, pseudoHeaderSum uint16) { 317 resetIPCheckSum(ip) 318 resetTransportCheckSum(ip, tp, pseudoHeaderSum) 319 } 320 321 func resetIPCheckSum(ip header.Network) { 322 if ip, ok := ip.(header.IPv4); ok { 323 ip.SetChecksum(0) 324 sum := ip.CalculateChecksum() 325 ip.SetChecksum(^sum) 326 } 327 } 328 329 func resetTransportCheckSum(ip header.Network, tp header.Transport, pseudoHeaderSum uint16) { 330 tp.SetChecksum(0) 331 sum := checksum.Checksum(ip.Payload(), pseudoHeaderSum) 332 333 //https://datatracker.ietf.org/doc/html/rfc768 334 // 335 // If the computed checksum is zero, it is transmitted as all ones (the 336 // equivalent in one's complement arithmetic). An all zero transmitted 337 // checksum value means that the transmitter generated no checksum (for 338 // debugging or for higher level protocols that don't care). 339 // 340 // https://datatracker.ietf.org/doc/html/rfc8200 341 // Unlike IPv4, the default behavior when UDP packets are 342 // originated by an IPv6 node is that the UDP checksum is not 343 // optional. That is, whenever originating a UDP packet, an IPv6 344 // node must compute a UDP checksum over the packet and the 345 // pseudo-header, and, if that computation yields a result of 346 // zero, it must be changed to hex FFFF for placement in the UDP 347 // header. IPv6 receivers must discard UDP packets containing a 348 // zero checksum and should log the error. 349 if ip.TransportProtocol() != header.UDPProtocolNumber || sum != math.MaxUint16 { 350 sum = ^sum 351 } 352 tp.SetChecksum(sum) 353 }