github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/iptables/iptables_util.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package iptables
    16  
    17  import (
    18  	"context"
    19  	"encoding/binary"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"os/exec"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/SagerNet/gvisor/pkg/test/testutil"
    28  )
    29  
    30  // filterTable calls `ip{6}tables -t filter` with the given args.
    31  func filterTable(ipv6 bool, args ...string) error {
    32  	return tableCmd(ipv6, "filter", args)
    33  }
    34  
    35  // natTable calls `ip{6}tables -t nat` with the given args.
    36  func natTable(ipv6 bool, args ...string) error {
    37  	return tableCmd(ipv6, "nat", args)
    38  }
    39  
    40  func tableCmd(ipv6 bool, table string, args []string) error {
    41  	args = append([]string{"-t", table}, args...)
    42  	binary := "iptables"
    43  	if ipv6 {
    44  		binary = "ip6tables"
    45  	}
    46  	cmd := exec.Command(binary, args...)
    47  	if out, err := cmd.CombinedOutput(); err != nil {
    48  		return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out))
    49  	}
    50  	return nil
    51  }
    52  
    53  // filterTableRules is like filterTable, but runs multiple iptables commands.
    54  func filterTableRules(ipv6 bool, argsList [][]string) error {
    55  	return tableRules(ipv6, "filter", argsList)
    56  }
    57  
    58  // natTableRules is like natTable, but runs multiple iptables commands.
    59  func natTableRules(ipv6 bool, argsList [][]string) error {
    60  	return tableRules(ipv6, "nat", argsList)
    61  }
    62  
    63  func tableRules(ipv6 bool, table string, argsList [][]string) error {
    64  	for _, args := range argsList {
    65  		if err := tableCmd(ipv6, table, args); err != nil {
    66  			return err
    67  		}
    68  	}
    69  	return nil
    70  }
    71  
    72  // listenUDP listens on a UDP port and returns nil if the first read from that
    73  // port is successful.
    74  func listenUDP(ctx context.Context, port int, ipv6 bool) error {
    75  	_, err := listenUDPFrom(ctx, port, ipv6)
    76  	return err
    77  }
    78  
    79  // listenUDPFrom listens on a UDP port and returns the sender's UDP address if
    80  // the first read from that port is successful.
    81  func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) {
    82  	localAddr := net.UDPAddr{
    83  		Port: port,
    84  	}
    85  	conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	defer conn.Close()
    90  
    91  	type result struct {
    92  		remoteAddr *net.UDPAddr
    93  		err        error
    94  	}
    95  
    96  	ch := make(chan result)
    97  	go func() {
    98  		_, remoteAddr, err := conn.ReadFromUDP([]byte{0})
    99  		ch <- result{remoteAddr, err}
   100  	}()
   101  
   102  	select {
   103  	case res := <-ch:
   104  		return res.remoteAddr, res.err
   105  	case <-ctx.Done():
   106  		return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err())
   107  	}
   108  }
   109  
   110  // sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified
   111  // over a duration.
   112  func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
   113  	remote := net.UDPAddr{
   114  		IP:   ip,
   115  		Port: port,
   116  	}
   117  	conn, err := net.DialUDP(udpNetwork(ipv6), nil, &remote)
   118  	if err != nil {
   119  		return err
   120  	}
   121  	defer conn.Close()
   122  
   123  	for {
   124  		// This may return an error (connection refused) if the remote
   125  		// hasn't started listening yet or they're dropping our
   126  		// packets. So we ignore Write errors and depend on the remote
   127  		// to report a failure if it doesn't get a packet it needs.
   128  		conn.Write([]byte{0})
   129  		select {
   130  		case <-ctx.Done():
   131  			// Being cancelled or timing out isn't an error, as we
   132  			// cannot tell with UDP whether we succeeded.
   133  			return nil
   134  		// Continue looping.
   135  		case <-time.After(200 * time.Millisecond):
   136  		}
   137  	}
   138  }
   139  
   140  // listenTCP listens for connections on a TCP port, and returns nil if a
   141  // connection is established.
   142  func listenTCP(ctx context.Context, port int, ipv6 bool) error {
   143  	_, err := listenTCPFrom(ctx, port, ipv6)
   144  	return err
   145  }
   146  
   147  // listenTCP listens for connections on a TCP port, and returns the remote
   148  // TCP address if a connection is established.
   149  func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) {
   150  	localAddr := net.TCPAddr{
   151  		Port: port,
   152  	}
   153  
   154  	// Starts listening on port.
   155  	lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	defer lConn.Close()
   160  
   161  	type result struct {
   162  		remoteAddr net.Addr
   163  		err        error
   164  	}
   165  
   166  	// Accept connections on port.
   167  	ch := make(chan result)
   168  	go func() {
   169  		conn, err := lConn.AcceptTCP()
   170  		var remoteAddr net.Addr
   171  		if err == nil {
   172  			remoteAddr = conn.RemoteAddr()
   173  		}
   174  		ch <- result{remoteAddr, err}
   175  		conn.Close()
   176  	}()
   177  
   178  	select {
   179  	case res := <-ch:
   180  		return res.remoteAddr, res.err
   181  	case <-ctx.Done():
   182  		return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err())
   183  	}
   184  }
   185  
   186  // connectTCP connects to the given IP and port from an ephemeral local address.
   187  func connectTCP(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
   188  	contAddr := net.TCPAddr{
   189  		IP:   ip,
   190  		Port: port,
   191  	}
   192  	// The container may not be listening when we first connect, so retry
   193  	// upon error.
   194  	callback := func() error {
   195  		var d net.Dialer
   196  		conn, err := d.DialContext(ctx, tcpNetwork(ipv6), contAddr.String())
   197  		if conn != nil {
   198  			conn.Close()
   199  		}
   200  		return err
   201  	}
   202  	if err := testutil.PollContext(ctx, callback); err != nil {
   203  		return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %w", port, err)
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  // localAddrs returns a list of local network interface addresses. When ipv6 is
   210  // true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are
   211  // returned.
   212  func localAddrs(ipv6 bool) ([]string, error) {
   213  	addrs, err := net.InterfaceAddrs()
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	addrStrs := make([]string, 0, len(addrs))
   218  	for _, addr := range addrs {
   219  		// Add only IPv4 or only IPv6 addresses.
   220  		parts := strings.Split(addr.String(), "/")
   221  		if len(parts) != 2 {
   222  			return nil, fmt.Errorf("bad interface address: %q", addr.String())
   223  		}
   224  		if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 {
   225  			addrStrs = append(addrStrs, addr.String())
   226  		}
   227  	}
   228  	return filterAddrs(addrStrs, ipv6), nil
   229  }
   230  
   231  func filterAddrs(addrs []string, ipv6 bool) []string {
   232  	addrStrs := make([]string, 0, len(addrs))
   233  	for _, addr := range addrs {
   234  		// Add only IPv4 or only IPv6 addresses.
   235  		parts := strings.Split(addr, "/")
   236  		if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 {
   237  			addrStrs = append(addrStrs, parts[0])
   238  		}
   239  	}
   240  	return addrStrs
   241  }
   242  
   243  // getInterfaceName returns the name of the interface other than loopback.
   244  func getInterfaceName() (string, bool) {
   245  	iface, ok := getNonLoopbackInterface()
   246  	if !ok {
   247  		return "", false
   248  	}
   249  	return iface.Name, true
   250  }
   251  
   252  func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) {
   253  	iface, ok := getNonLoopbackInterface()
   254  	if !ok {
   255  		return nil, errors.New("no non-loopback interface found")
   256  	}
   257  	addrs, err := iface.Addrs()
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  
   262  	// Get only IPv4 or IPv6 addresses.
   263  	ips := make([]net.IP, 0, len(addrs))
   264  	for _, addr := range addrs {
   265  		parts := strings.Split(addr.String(), "/")
   266  		var ip net.IP
   267  		// To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses.
   268  		// So we check whether To4() returns nil to test whether the
   269  		// address is v4 or v6.
   270  		if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil {
   271  			ip = net.ParseIP(parts[0]).To16()
   272  		} else {
   273  			ip = v4
   274  		}
   275  		if ip != nil {
   276  			ips = append(ips, ip)
   277  		}
   278  	}
   279  	return ips, nil
   280  }
   281  
   282  func getNonLoopbackInterface() (net.Interface, bool) {
   283  	if interfaces, err := net.Interfaces(); err == nil {
   284  		for _, intf := range interfaces {
   285  			if intf.Name != "lo" {
   286  				return intf, true
   287  			}
   288  		}
   289  	}
   290  	return net.Interface{}, false
   291  }
   292  
   293  func htons(x uint16) uint16 {
   294  	buf := make([]byte, 2)
   295  	binary.BigEndian.PutUint16(buf, x)
   296  	return binary.LittleEndian.Uint16(buf)
   297  }
   298  
   299  func localIP(ipv6 bool) string {
   300  	if ipv6 {
   301  		return "::1"
   302  	}
   303  	return "127.0.0.1"
   304  }
   305  
   306  func nowhereIP(ipv6 bool) string {
   307  	if ipv6 {
   308  		return "2001:db8::1"
   309  	}
   310  	return "192.0.2.1"
   311  }
   312  
   313  // udpNetwork returns an IPv6 or IPv6 UDP network argument to net.Dial.
   314  func udpNetwork(ipv6 bool) string {
   315  	if ipv6 {
   316  		return "udp6"
   317  	}
   318  	return "udp4"
   319  }
   320  
   321  // tcpNetwork returns an IPv6 or IPv6 TCP network argument to net.Dial.
   322  func tcpNetwork(ipv6 bool) string {
   323  	if ipv6 {
   324  		return "tcp6"
   325  	}
   326  	return "tcp4"
   327  }