github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/iptables/iptables_util.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package iptables 16 17 import ( 18 "context" 19 "encoding/binary" 20 "errors" 21 "fmt" 22 "net" 23 "os/exec" 24 "strings" 25 "time" 26 27 "github.com/SagerNet/gvisor/pkg/test/testutil" 28 ) 29 30 // filterTable calls `ip{6}tables -t filter` with the given args. 31 func filterTable(ipv6 bool, args ...string) error { 32 return tableCmd(ipv6, "filter", args) 33 } 34 35 // natTable calls `ip{6}tables -t nat` with the given args. 36 func natTable(ipv6 bool, args ...string) error { 37 return tableCmd(ipv6, "nat", args) 38 } 39 40 func tableCmd(ipv6 bool, table string, args []string) error { 41 args = append([]string{"-t", table}, args...) 42 binary := "iptables" 43 if ipv6 { 44 binary = "ip6tables" 45 } 46 cmd := exec.Command(binary, args...) 47 if out, err := cmd.CombinedOutput(); err != nil { 48 return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) 49 } 50 return nil 51 } 52 53 // filterTableRules is like filterTable, but runs multiple iptables commands. 54 func filterTableRules(ipv6 bool, argsList [][]string) error { 55 return tableRules(ipv6, "filter", argsList) 56 } 57 58 // natTableRules is like natTable, but runs multiple iptables commands. 
59 func natTableRules(ipv6 bool, argsList [][]string) error { 60 return tableRules(ipv6, "nat", argsList) 61 } 62 63 func tableRules(ipv6 bool, table string, argsList [][]string) error { 64 for _, args := range argsList { 65 if err := tableCmd(ipv6, table, args); err != nil { 66 return err 67 } 68 } 69 return nil 70 } 71 72 // listenUDP listens on a UDP port and returns nil if the first read from that 73 // port is successful. 74 func listenUDP(ctx context.Context, port int, ipv6 bool) error { 75 _, err := listenUDPFrom(ctx, port, ipv6) 76 return err 77 } 78 79 // listenUDPFrom listens on a UDP port and returns the sender's UDP address if 80 // the first read from that port is successful. 81 func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) { 82 localAddr := net.UDPAddr{ 83 Port: port, 84 } 85 conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr) 86 if err != nil { 87 return nil, err 88 } 89 defer conn.Close() 90 91 type result struct { 92 remoteAddr *net.UDPAddr 93 err error 94 } 95 96 ch := make(chan result) 97 go func() { 98 _, remoteAddr, err := conn.ReadFromUDP([]byte{0}) 99 ch <- result{remoteAddr, err} 100 }() 101 102 select { 103 case res := <-ch: 104 return res.remoteAddr, res.err 105 case <-ctx.Done(): 106 return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err()) 107 } 108 } 109 110 // sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified 111 // over a duration. 112 func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error { 113 remote := net.UDPAddr{ 114 IP: ip, 115 Port: port, 116 } 117 conn, err := net.DialUDP(udpNetwork(ipv6), nil, &remote) 118 if err != nil { 119 return err 120 } 121 defer conn.Close() 122 123 for { 124 // This may return an error (connection refused) if the remote 125 // hasn't started listening yet or they're dropping our 126 // packets. 
So we ignore Write errors and depend on the remote 127 // to report a failure if it doesn't get a packet it needs. 128 conn.Write([]byte{0}) 129 select { 130 case <-ctx.Done(): 131 // Being cancelled or timing out isn't an error, as we 132 // cannot tell with UDP whether we succeeded. 133 return nil 134 // Continue looping. 135 case <-time.After(200 * time.Millisecond): 136 } 137 } 138 } 139 140 // listenTCP listens for connections on a TCP port, and returns nil if a 141 // connection is established. 142 func listenTCP(ctx context.Context, port int, ipv6 bool) error { 143 _, err := listenTCPFrom(ctx, port, ipv6) 144 return err 145 } 146 147 // listenTCP listens for connections on a TCP port, and returns the remote 148 // TCP address if a connection is established. 149 func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) { 150 localAddr := net.TCPAddr{ 151 Port: port, 152 } 153 154 // Starts listening on port. 155 lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr) 156 if err != nil { 157 return nil, err 158 } 159 defer lConn.Close() 160 161 type result struct { 162 remoteAddr net.Addr 163 err error 164 } 165 166 // Accept connections on port. 167 ch := make(chan result) 168 go func() { 169 conn, err := lConn.AcceptTCP() 170 var remoteAddr net.Addr 171 if err == nil { 172 remoteAddr = conn.RemoteAddr() 173 } 174 ch <- result{remoteAddr, err} 175 conn.Close() 176 }() 177 178 select { 179 case res := <-ch: 180 return res.remoteAddr, res.err 181 case <-ctx.Done(): 182 return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err()) 183 } 184 } 185 186 // connectTCP connects to the given IP and port from an ephemeral local address. 187 func connectTCP(ctx context.Context, ip net.IP, port int, ipv6 bool) error { 188 contAddr := net.TCPAddr{ 189 IP: ip, 190 Port: port, 191 } 192 // The container may not be listening when we first connect, so retry 193 // upon error. 
194 callback := func() error { 195 var d net.Dialer 196 conn, err := d.DialContext(ctx, tcpNetwork(ipv6), contAddr.String()) 197 if conn != nil { 198 conn.Close() 199 } 200 return err 201 } 202 if err := testutil.PollContext(ctx, callback); err != nil { 203 return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %w", port, err) 204 } 205 206 return nil 207 } 208 209 // localAddrs returns a list of local network interface addresses. When ipv6 is 210 // true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are 211 // returned. 212 func localAddrs(ipv6 bool) ([]string, error) { 213 addrs, err := net.InterfaceAddrs() 214 if err != nil { 215 return nil, err 216 } 217 addrStrs := make([]string, 0, len(addrs)) 218 for _, addr := range addrs { 219 // Add only IPv4 or only IPv6 addresses. 220 parts := strings.Split(addr.String(), "/") 221 if len(parts) != 2 { 222 return nil, fmt.Errorf("bad interface address: %q", addr.String()) 223 } 224 if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { 225 addrStrs = append(addrStrs, addr.String()) 226 } 227 } 228 return filterAddrs(addrStrs, ipv6), nil 229 } 230 231 func filterAddrs(addrs []string, ipv6 bool) []string { 232 addrStrs := make([]string, 0, len(addrs)) 233 for _, addr := range addrs { 234 // Add only IPv4 or only IPv6 addresses. 235 parts := strings.Split(addr, "/") 236 if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { 237 addrStrs = append(addrStrs, parts[0]) 238 } 239 } 240 return addrStrs 241 } 242 243 // getInterfaceName returns the name of the interface other than loopback. 
func getInterfaceName() (string, bool) {
	iface, ok := getNonLoopbackInterface()
	if !ok {
		return "", false
	}
	return iface.Name, true
}

// getInterfaceAddrs returns the addresses of the first non-loopback interface.
// When ipv6 is true it returns only the IPv6 addresses, as 16-byte slices.
// Otherwise it returns only the IPv4 addresses, as 4-byte slices.
func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) {
	iface, ok := getNonLoopbackInterface()
	if !ok {
		return nil, errors.New("no non-loopback interface found")
	}
	addrs, err := iface.Addrs()
	if err != nil {
		return nil, err
	}

	ips := make([]net.IP, 0, len(addrs))
	for _, addr := range addrs {
		// Addresses are expected to stringify as "ip/prefixlen"; keep the IP.
		ip := net.ParseIP(strings.Split(addr.String(), "/")[0])
		if ip == nil {
			continue
		}
		// To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses,
		// so To4() is what tells v4 and v6 apart. Each address goes
		// into exactly one family, so an IPv4 address is never returned
		// for an IPv6 request (and vice versa).
		v4 := ip.To4()
		switch {
		case ipv6 && v4 == nil:
			ips = append(ips, ip.To16())
		case !ipv6 && v4 != nil:
			ips = append(ips, v4)
		}
	}
	return ips, nil
}

// getNonLoopbackInterface returns the first network interface that is neither
// named "lo" nor flagged as loopback. The boolean result is false if there is
// no such interface or the interface list cannot be read.
func getNonLoopbackInterface() (net.Interface, bool) {
	interfaces, err := net.Interfaces()
	if err != nil {
		return net.Interface{}, false
	}
	for _, intf := range interfaces {
		if intf.Name != "lo" && intf.Flags&net.FlagLoopback == 0 {
			return intf, true
		}
	}
	return net.Interface{}, false
}

// htons converts x from host to network (big-endian) byte order. It always
// swaps the two bytes, so it is only correct on little-endian hosts.
func htons(x uint16) uint16 {
	var buf [2]byte
	binary.BigEndian.PutUint16(buf[:], x)
	return binary.LittleEndian.Uint16(buf[:])
}

// localIP returns the loopback address for the requested IP family.
func localIP(ipv6 bool) string {
	if ipv6 {
		return "::1"
	}
	return "127.0.0.1"
}

// nowhereIP returns an address from the blocks reserved for documentation
// (RFC 3849 for IPv6, RFC 5737 TEST-NET-1 for IPv4). Presumably it is used as
// a destination that no real host answers.
func nowhereIP(ipv6 bool) string {
	if ipv6 {
		return "2001:db8::1"
	}
	return "192.0.2.1"
}

// udpNetwork returns an IPv4 or IPv6 UDP network argument to net.Dial.
func udpNetwork(ipv6 bool) string {
	if ipv6 {
		return "udp6"
	}
	return "udp4"
}

// tcpNetwork returns an IPv4 or IPv6 TCP network argument to net.Dial.
func tcpNetwork(ipv6 bool) string {
	if ipv6 {
		return "tcp6"
	}
	return "tcp4"
}