github.com/pwn-term/docker@v0.0.0-20210616085119-6e977cce2565/libnetwork/iptables/iptables.go (about)

     1  package iptables
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"os/exec"
     8  	"regexp"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/sirupsen/logrus"
    15  )
    16  
    17  // Action signifies the iptable action.
    18  type Action string
    19  
    20  // Policy is the default iptable policies
    21  type Policy string
    22  
    23  // Table refers to Nat, Filter or Mangle.
    24  type Table string
    25  
    26  // IPVersion refers to IP version, v4 or v6
    27  type IPVersion string
    28  
    29  const (
    30  	// Append appends the rule at the end of the chain.
    31  	Append Action = "-A"
    32  	// Delete deletes the rule from the chain.
    33  	Delete Action = "-D"
    34  	// Insert inserts the rule at the top of the chain.
    35  	Insert Action = "-I"
    36  	// Nat table is used for nat translation rules.
    37  	Nat Table = "nat"
    38  	// Filter table is used for filter rules.
    39  	Filter Table = "filter"
    40  	// Mangle table is used for mangling the packet.
    41  	Mangle Table = "mangle"
    42  	// Drop is the default iptables DROP policy
    43  	Drop Policy = "DROP"
    44  	// Accept is the default iptables ACCEPT policy
    45  	Accept Policy = "ACCEPT"
    46  	// IPv4 is version 4
    47  	IPv4 IPVersion = "IPV4"
    48  	// IPv6 is version 6
    49  	IPv6 IPVersion = "IPV6"
    50  )
    51  
    52  var (
    53  	iptablesPath  string
    54  	ip6tablesPath string
    55  	supportsXlock = false
    56  	supportsCOpt  = false
    57  	xLockWaitMsg  = "Another app is currently holding the xtables lock"
    58  	// used to lock iptables commands if xtables lock is not supported
    59  	bestEffortLock sync.Mutex
    60  	// ErrIptablesNotFound is returned when the rule is not found.
    61  	ErrIptablesNotFound = errors.New("Iptables not found")
    62  	initOnce            sync.Once
    63  )
    64  
    65  // IPTable defines struct with IPVersion
    66  type IPTable struct {
    67  	Version IPVersion
    68  }
    69  
    70  // ChainInfo defines the iptables chain.
    71  type ChainInfo struct {
    72  	Name        string
    73  	Table       Table
    74  	HairpinMode bool
    75  	IPTable     IPTable
    76  }
    77  
    78  // ChainError is returned to represent errors during ip table operation.
    79  type ChainError struct {
    80  	Chain  string
    81  	Output []byte
    82  }
    83  
    84  func (e ChainError) Error() string {
    85  	return fmt.Sprintf("Error iptables %s: %s", e.Chain, string(e.Output))
    86  }
    87  
    88  func probe() {
    89  	path, err := exec.LookPath("iptables")
    90  	if err != nil {
    91  		logrus.Warnf("Failed to find iptables: %v", err)
    92  		return
    93  	}
    94  	if out, err := exec.Command(path, "--wait", "-t", "nat", "-L", "-n").CombinedOutput(); err != nil {
    95  		logrus.Warnf("Running iptables --wait -t nat -L -n failed with message: `%s`, error: %v", strings.TrimSpace(string(out)), err)
    96  	}
    97  	_, err = exec.LookPath("ip6tables")
    98  	if err != nil {
    99  		logrus.Warnf("Failed to find ip6tables: %v", err)
   100  		return
   101  	}
   102  }
   103  
   104  func initFirewalld() {
   105  	if err := FirewalldInit(); err != nil {
   106  		logrus.Debugf("Fail to initialize firewalld: %v, using raw iptables instead", err)
   107  	}
   108  }
   109  
   110  func detectIptables() {
   111  	path, err := exec.LookPath("iptables")
   112  	if err != nil {
   113  		return
   114  	}
   115  	iptablesPath = path
   116  	path, err = exec.LookPath("ip6tables")
   117  	if err != nil {
   118  		return
   119  	}
   120  	ip6tablesPath = path
   121  	supportsXlock = exec.Command(iptablesPath, "--wait", "-L", "-n").Run() == nil
   122  	mj, mn, mc, err := GetVersion()
   123  	if err != nil {
   124  		logrus.Warnf("Failed to read iptables version: %v", err)
   125  		return
   126  	}
   127  	supportsCOpt = supportsCOption(mj, mn, mc)
   128  }
   129  
   130  func initDependencies() {
   131  	probe()
   132  	initFirewalld()
   133  	detectIptables()
   134  }
   135  
   136  func initCheck() error {
   137  	initOnce.Do(initDependencies)
   138  
   139  	if iptablesPath == "" {
   140  		return ErrIptablesNotFound
   141  	}
   142  	return nil
   143  }
   144  
   145  // GetIptable returns an instance of IPTable with specified version
   146  func GetIptable(version IPVersion) *IPTable {
   147  	return &IPTable{Version: version}
   148  }
   149  
   150  // NewChain adds a new chain to ip table.
   151  func (iptable IPTable) NewChain(name string, table Table, hairpinMode bool) (*ChainInfo, error) {
   152  	c := &ChainInfo{
   153  		Name:        name,
   154  		Table:       table,
   155  		HairpinMode: hairpinMode,
   156  		IPTable:     iptable,
   157  	}
   158  	if string(c.Table) == "" {
   159  		c.Table = Filter
   160  	}
   161  
   162  	// Add chain if it doesn't exist
   163  	if _, err := iptable.Raw("-t", string(c.Table), "-n", "-L", c.Name); err != nil {
   164  		if output, err := iptable.Raw("-t", string(c.Table), "-N", c.Name); err != nil {
   165  			return nil, err
   166  		} else if len(output) != 0 {
   167  			return nil, fmt.Errorf("Could not create %s/%s chain: %s", c.Table, c.Name, output)
   168  		}
   169  	}
   170  	return c, nil
   171  }
   172  
   173  // LoopbackByVersion returns loopback address by version
   174  func (iptable IPTable) LoopbackByVersion() string {
   175  	if iptable.Version == IPv6 {
   176  		return "::1/128"
   177  	}
   178  	return "127.0.0.0/8"
   179  }
   180  
   181  // ProgramChain is used to add rules to a chain
   182  func (iptable IPTable) ProgramChain(c *ChainInfo, bridgeName string, hairpinMode, enable bool) error {
   183  	if c.Name == "" {
   184  		return errors.New("Could not program chain, missing chain name")
   185  	}
   186  
   187  	// Either add or remove the interface from the firewalld zone
   188  	if firewalldRunning {
   189  		if enable {
   190  			if err := AddInterfaceFirewalld(bridgeName); err != nil {
   191  				return err
   192  			}
   193  		} else {
   194  			if err := DelInterfaceFirewalld(bridgeName); err != nil {
   195  				return err
   196  			}
   197  		}
   198  	}
   199  
   200  	switch c.Table {
   201  	case Nat:
   202  		preroute := []string{
   203  			"-m", "addrtype",
   204  			"--dst-type", "LOCAL",
   205  			"-j", c.Name}
   206  		if !iptable.Exists(Nat, "PREROUTING", preroute...) && enable {
   207  			if err := c.Prerouting(Append, preroute...); err != nil {
   208  				return fmt.Errorf("Failed to inject %s in PREROUTING chain: %s", c.Name, err)
   209  			}
   210  		} else if iptable.Exists(Nat, "PREROUTING", preroute...) && !enable {
   211  			if err := c.Prerouting(Delete, preroute...); err != nil {
   212  				return fmt.Errorf("Failed to remove %s in PREROUTING chain: %s", c.Name, err)
   213  			}
   214  		}
   215  		output := []string{
   216  			"-m", "addrtype",
   217  			"--dst-type", "LOCAL",
   218  			"-j", c.Name}
   219  		if !hairpinMode {
   220  			output = append(output, "!", "--dst", iptable.LoopbackByVersion())
   221  		}
   222  		if !iptable.Exists(Nat, "OUTPUT", output...) && enable {
   223  			if err := c.Output(Append, output...); err != nil {
   224  				return fmt.Errorf("Failed to inject %s in OUTPUT chain: %s", c.Name, err)
   225  			}
   226  		} else if iptable.Exists(Nat, "OUTPUT", output...) && !enable {
   227  			if err := c.Output(Delete, output...); err != nil {
   228  				return fmt.Errorf("Failed to inject %s in OUTPUT chain: %s", c.Name, err)
   229  			}
   230  		}
   231  	case Filter:
   232  		if bridgeName == "" {
   233  			return fmt.Errorf("Could not program chain %s/%s, missing bridge name",
   234  				c.Table, c.Name)
   235  		}
   236  		link := []string{
   237  			"-o", bridgeName,
   238  			"-j", c.Name}
   239  		if !iptable.Exists(Filter, "FORWARD", link...) && enable {
   240  			insert := append([]string{string(Insert), "FORWARD"}, link...)
   241  			if output, err := iptable.Raw(insert...); err != nil {
   242  				return err
   243  			} else if len(output) != 0 {
   244  				return fmt.Errorf("Could not create linking rule to %s/%s: %s", c.Table, c.Name, output)
   245  			}
   246  		} else if iptable.Exists(Filter, "FORWARD", link...) && !enable {
   247  			del := append([]string{string(Delete), "FORWARD"}, link...)
   248  			if output, err := iptable.Raw(del...); err != nil {
   249  				return err
   250  			} else if len(output) != 0 {
   251  				return fmt.Errorf("Could not delete linking rule from %s/%s: %s", c.Table, c.Name, output)
   252  			}
   253  
   254  		}
   255  		establish := []string{
   256  			"-o", bridgeName,
   257  			"-m", "conntrack",
   258  			"--ctstate", "RELATED,ESTABLISHED",
   259  			"-j", "ACCEPT"}
   260  		if !iptable.Exists(Filter, "FORWARD", establish...) && enable {
   261  			insert := append([]string{string(Insert), "FORWARD"}, establish...)
   262  			if output, err := iptable.Raw(insert...); err != nil {
   263  				return err
   264  			} else if len(output) != 0 {
   265  				return fmt.Errorf("Could not create establish rule to %s: %s", c.Table, output)
   266  			}
   267  		} else if iptable.Exists(Filter, "FORWARD", establish...) && !enable {
   268  			del := append([]string{string(Delete), "FORWARD"}, establish...)
   269  			if output, err := iptable.Raw(del...); err != nil {
   270  				return err
   271  			} else if len(output) != 0 {
   272  				return fmt.Errorf("Could not delete establish rule from %s: %s", c.Table, output)
   273  			}
   274  		}
   275  	}
   276  	return nil
   277  }
   278  
   279  // RemoveExistingChain removes existing chain from the table.
   280  func (iptable IPTable) RemoveExistingChain(name string, table Table) error {
   281  	c := &ChainInfo{
   282  		Name:    name,
   283  		Table:   table,
   284  		IPTable: iptable,
   285  	}
   286  	if string(c.Table) == "" {
   287  		c.Table = Filter
   288  	}
   289  	return c.Remove()
   290  }
   291  
   292  // Forward adds forwarding rule to 'filter' table and corresponding nat rule to 'nat' table.
   293  func (c *ChainInfo) Forward(action Action, ip net.IP, port int, proto, destAddr string, destPort int, bridgeName string) error {
   294  
   295  	iptable := GetIptable(c.IPTable.Version)
   296  	daddr := ip.String()
   297  	if ip.IsUnspecified() {
   298  		// iptables interprets "0.0.0.0" as "0.0.0.0/32", whereas we
   299  		// want "0.0.0.0/0". "0/0" is correctly interpreted as "any
   300  		// value" by both iptables and ip6tables.
   301  		daddr = "0/0"
   302  	}
   303  
   304  	args := []string{
   305  		"-p", proto,
   306  		"-d", daddr,
   307  		"--dport", strconv.Itoa(port),
   308  		"-j", "DNAT",
   309  		"--to-destination", net.JoinHostPort(destAddr, strconv.Itoa(destPort))}
   310  
   311  	if !c.HairpinMode {
   312  		args = append(args, "!", "-i", bridgeName)
   313  	}
   314  	if err := iptable.ProgramRule(Nat, c.Name, action, args); err != nil {
   315  		return err
   316  	}
   317  
   318  	args = []string{
   319  		"!", "-i", bridgeName,
   320  		"-o", bridgeName,
   321  		"-p", proto,
   322  		"-d", destAddr,
   323  		"--dport", strconv.Itoa(destPort),
   324  		"-j", "ACCEPT",
   325  	}
   326  	if err := iptable.ProgramRule(Filter, c.Name, action, args); err != nil {
   327  		return err
   328  	}
   329  
   330  	args = []string{
   331  		"-p", proto,
   332  		"-s", destAddr,
   333  		"-d", destAddr,
   334  		"--dport", strconv.Itoa(destPort),
   335  		"-j", "MASQUERADE",
   336  	}
   337  
   338  	if err := iptable.ProgramRule(Nat, "POSTROUTING", action, args); err != nil {
   339  		return err
   340  	}
   341  
   342  	if proto == "sctp" {
   343  		// Linux kernel v4.9 and below enables NETIF_F_SCTP_CRC for veth by
   344  		// the following commit.
   345  		// This introduces a problem when conbined with a physical NIC without
   346  		// NETIF_F_SCTP_CRC. As for a workaround, here we add an iptables entry
   347  		// to fill the checksum.
   348  		//
   349  		// https://github.com/torvalds/linux/commit/c80fafbbb59ef9924962f83aac85531039395b18
   350  		args = []string{
   351  			"-p", proto,
   352  			"--sport", strconv.Itoa(destPort),
   353  			"-j", "CHECKSUM",
   354  			"--checksum-fill",
   355  		}
   356  		if err := iptable.ProgramRule(Mangle, "POSTROUTING", action, args); err != nil {
   357  			return err
   358  		}
   359  	}
   360  
   361  	return nil
   362  }
   363  
   364  // Link adds reciprocal ACCEPT rule for two supplied IP addresses.
   365  // Traffic is allowed from ip1 to ip2 and vice-versa
   366  func (c *ChainInfo) Link(action Action, ip1, ip2 net.IP, port int, proto string, bridgeName string) error {
   367  	iptable := GetIptable(c.IPTable.Version)
   368  	// forward
   369  	args := []string{
   370  		"-i", bridgeName, "-o", bridgeName,
   371  		"-p", proto,
   372  		"-s", ip1.String(),
   373  		"-d", ip2.String(),
   374  		"--dport", strconv.Itoa(port),
   375  		"-j", "ACCEPT",
   376  	}
   377  
   378  	if err := iptable.ProgramRule(Filter, c.Name, action, args); err != nil {
   379  		return err
   380  	}
   381  	// reverse
   382  	args[7], args[9] = args[9], args[7]
   383  	args[10] = "--sport"
   384  	return iptable.ProgramRule(Filter, c.Name, action, args)
   385  }
   386  
   387  // ProgramRule adds the rule specified by args only if the
   388  // rule is not already present in the chain. Reciprocally,
   389  // it removes the rule only if present.
   390  func (iptable IPTable) ProgramRule(table Table, chain string, action Action, args []string) error {
   391  	if iptable.Exists(table, chain, args...) != (action == Delete) {
   392  		return nil
   393  	}
   394  	return iptable.RawCombinedOutput(append([]string{"-t", string(table), string(action), chain}, args...)...)
   395  }
   396  
   397  // Prerouting adds linking rule to nat/PREROUTING chain.
   398  func (c *ChainInfo) Prerouting(action Action, args ...string) error {
   399  	iptable := GetIptable(c.IPTable.Version)
   400  	a := []string{"-t", string(Nat), string(action), "PREROUTING"}
   401  	if len(args) > 0 {
   402  		a = append(a, args...)
   403  	}
   404  	if output, err := iptable.Raw(a...); err != nil {
   405  		return err
   406  	} else if len(output) != 0 {
   407  		return ChainError{Chain: "PREROUTING", Output: output}
   408  	}
   409  	return nil
   410  }
   411  
   412  // Output adds linking rule to an OUTPUT chain.
   413  func (c *ChainInfo) Output(action Action, args ...string) error {
   414  	iptable := GetIptable(c.IPTable.Version)
   415  	a := []string{"-t", string(c.Table), string(action), "OUTPUT"}
   416  	if len(args) > 0 {
   417  		a = append(a, args...)
   418  	}
   419  	if output, err := iptable.Raw(a...); err != nil {
   420  		return err
   421  	} else if len(output) != 0 {
   422  		return ChainError{Chain: "OUTPUT", Output: output}
   423  	}
   424  	return nil
   425  }
   426  
   427  // Remove removes the chain.
   428  func (c *ChainInfo) Remove() error {
   429  	iptable := GetIptable(c.IPTable.Version)
   430  	// Ignore errors - This could mean the chains were never set up
   431  	if c.Table == Nat {
   432  		c.Prerouting(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name)
   433  		c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", iptable.LoopbackByVersion(), "-j", c.Name)
   434  		c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name) // Created in versions <= 0.1.6
   435  
   436  		c.Prerouting(Delete)
   437  		c.Output(Delete)
   438  	}
   439  	iptable.Raw("-t", string(c.Table), "-F", c.Name)
   440  	iptable.Raw("-t", string(c.Table), "-X", c.Name)
   441  	return nil
   442  }
   443  
   444  // Exists checks if a rule exists
   445  func (iptable IPTable) Exists(table Table, chain string, rule ...string) bool {
   446  	return iptable.exists(false, table, chain, rule...)
   447  }
   448  
   449  // ExistsNative behaves as Exists with the difference it
   450  // will always invoke `iptables` binary.
   451  func (iptable IPTable) ExistsNative(table Table, chain string, rule ...string) bool {
   452  	return iptable.exists(true, table, chain, rule...)
   453  }
   454  
   455  func (iptable IPTable) exists(native bool, table Table, chain string, rule ...string) bool {
   456  	f := iptable.Raw
   457  	if native {
   458  		f = iptable.raw
   459  	}
   460  
   461  	if string(table) == "" {
   462  		table = Filter
   463  	}
   464  
   465  	if err := initCheck(); err != nil {
   466  		// The exists() signature does not allow us to return an error, but at least
   467  		// we can skip the (likely invalid) exec invocation.
   468  		return false
   469  	}
   470  
   471  	if supportsCOpt {
   472  		// if exit status is 0 then return true, the rule exists
   473  		_, err := f(append([]string{"-t", string(table), "-C", chain}, rule...)...)
   474  		return err == nil
   475  	}
   476  
   477  	// parse "iptables -S" for the rule (it checks rules in a specific chain
   478  	// in a specific table and it is very unreliable)
   479  	return iptable.existsRaw(table, chain, rule...)
   480  }
   481  
   482  func (iptable IPTable) existsRaw(table Table, chain string, rule ...string) bool {
   483  	path := iptablesPath
   484  	if iptable.Version == IPv6 {
   485  		path = ip6tablesPath
   486  	}
   487  	ruleString := fmt.Sprintf("%s %s\n", chain, strings.Join(rule, " "))
   488  	existingRules, _ := exec.Command(path, "-t", string(table), "-S", chain).Output()
   489  
   490  	return strings.Contains(string(existingRules), ruleString)
   491  }
   492  
   493  // Maximum duration that an iptables operation can take
   494  // before flagging a warning.
   495  const opWarnTime = 2 * time.Second
   496  
   497  func filterOutput(start time.Time, output []byte, args ...string) []byte {
   498  	// Flag operations that have taken a long time to complete
   499  	opTime := time.Since(start)
   500  	if opTime > opWarnTime {
   501  		logrus.Warnf("xtables contention detected while running [%s]: Waited for %.2f seconds and received %q", strings.Join(args, " "), float64(opTime)/float64(time.Second), string(output))
   502  	}
   503  	// ignore iptables' message about xtables lock:
   504  	// it is a warning, not an error.
   505  	if strings.Contains(string(output), xLockWaitMsg) {
   506  		output = []byte("")
   507  	}
   508  	// Put further filters here if desired
   509  	return output
   510  }
   511  
   512  // Raw calls 'iptables' system command, passing supplied arguments.
   513  func (iptable IPTable) Raw(args ...string) ([]byte, error) {
   514  	if firewalldRunning {
   515  		// select correct IP version for firewalld
   516  		ipv := Iptables
   517  		if iptable.Version == IPv6 {
   518  			ipv = IP6Tables
   519  		}
   520  
   521  		startTime := time.Now()
   522  		output, err := Passthrough(ipv, args...)
   523  		if err == nil || !strings.Contains(err.Error(), "was not provided by any .service files") {
   524  			return filterOutput(startTime, output, args...), err
   525  		}
   526  	}
   527  	return iptable.raw(args...)
   528  }
   529  
   530  func (iptable IPTable) raw(args ...string) ([]byte, error) {
   531  	if err := initCheck(); err != nil {
   532  		return nil, err
   533  	}
   534  	if supportsXlock {
   535  		args = append([]string{"--wait"}, args...)
   536  	} else {
   537  		bestEffortLock.Lock()
   538  		defer bestEffortLock.Unlock()
   539  	}
   540  
   541  	path := iptablesPath
   542  	commandName := "iptables"
   543  	if iptable.Version == IPv6 {
   544  		path = ip6tablesPath
   545  		commandName = "ip6tables"
   546  	}
   547  
   548  	logrus.Debugf("%s, %v", path, args)
   549  
   550  	startTime := time.Now()
   551  	output, err := exec.Command(path, args...).CombinedOutput()
   552  	if err != nil {
   553  		return nil, fmt.Errorf("iptables failed: %s %v: %s (%s)", commandName, strings.Join(args, " "), output, err)
   554  	}
   555  
   556  	return filterOutput(startTime, output, args...), err
   557  }
   558  
   559  // RawCombinedOutput internally calls the Raw function and returns a non nil
   560  // error if Raw returned a non nil error or a non empty output
   561  func (iptable IPTable) RawCombinedOutput(args ...string) error {
   562  	if output, err := iptable.Raw(args...); err != nil || len(output) != 0 {
   563  		return fmt.Errorf("%s (%v)", string(output), err)
   564  	}
   565  	return nil
   566  }
   567  
   568  // RawCombinedOutputNative behave as RawCombinedOutput with the difference it
   569  // will always invoke `iptables` binary
   570  func (iptable IPTable) RawCombinedOutputNative(args ...string) error {
   571  	if output, err := iptable.raw(args...); err != nil || len(output) != 0 {
   572  		return fmt.Errorf("%s (%v)", string(output), err)
   573  	}
   574  	return nil
   575  }
   576  
   577  // ExistChain checks if a chain exists
   578  func (iptable IPTable) ExistChain(chain string, table Table) bool {
   579  	if _, err := iptable.Raw("-t", string(table), "-nL", chain); err == nil {
   580  		return true
   581  	}
   582  	return false
   583  }
   584  
   585  // GetVersion reads the iptables version numbers during initialization
   586  func GetVersion() (major, minor, micro int, err error) {
   587  	out, err := exec.Command(iptablesPath, "--version").CombinedOutput()
   588  	if err == nil {
   589  		major, minor, micro = parseVersionNumbers(string(out))
   590  	}
   591  	return
   592  }
   593  
   594  // SetDefaultPolicy sets the passed default policy for the table/chain
   595  func (iptable IPTable) SetDefaultPolicy(table Table, chain string, policy Policy) error {
   596  	if err := iptable.RawCombinedOutput("-t", string(table), "-P", chain, string(policy)); err != nil {
   597  		return fmt.Errorf("setting default policy to %v in %v chain failed: %v", policy, chain, err)
   598  	}
   599  	return nil
   600  }
   601  
   602  func parseVersionNumbers(input string) (major, minor, micro int) {
   603  	re := regexp.MustCompile(`v\d*.\d*.\d*`)
   604  	line := re.FindString(input)
   605  	fmt.Sscanf(line, "v%d.%d.%d", &major, &minor, &micro)
   606  	return
   607  }
   608  
   609  // iptables -C, --check option was added in v.1.4.11
   610  // http://ftp.netfilter.org/pub/iptables/changes-iptables-1.4.11.txt
   611  func supportsCOption(mj, mn, mc int) bool {
   612  	return mj > 1 || (mj == 1 && (mn > 4 || (mn == 4 && mc >= 11)))
   613  }
   614  
   615  // AddReturnRule adds a return rule for the chain in the filter table
   616  func (iptable IPTable) AddReturnRule(chain string) error {
   617  	var (
   618  		table = Filter
   619  		args  = []string{"-j", "RETURN"}
   620  	)
   621  
   622  	if iptable.Exists(table, chain, args...) {
   623  		return nil
   624  	}
   625  
   626  	err := iptable.RawCombinedOutput(append([]string{"-A", chain}, args...)...)
   627  	if err != nil {
   628  		return fmt.Errorf("unable to add return rule in %s chain: %s", chain, err.Error())
   629  	}
   630  
   631  	return nil
   632  }
   633  
   634  // EnsureJumpRule ensures the jump rule is on top
   635  func (iptable IPTable) EnsureJumpRule(fromChain, toChain string) error {
   636  	var (
   637  		table = Filter
   638  		args  = []string{"-j", toChain}
   639  	)
   640  
   641  	if iptable.Exists(table, fromChain, args...) {
   642  		err := iptable.RawCombinedOutput(append([]string{"-D", fromChain}, args...)...)
   643  		if err != nil {
   644  			return fmt.Errorf("unable to remove jump to %s rule in %s chain: %s", toChain, fromChain, err.Error())
   645  		}
   646  	}
   647  
   648  	err := iptable.RawCombinedOutput(append([]string{"-I", fromChain}, args...)...)
   649  	if err != nil {
   650  		return fmt.Errorf("unable to insert jump to %s rule in %s chain: %s", toChain, fromChain, err.Error())
   651  	}
   652  
   653  	return nil
   654  }