github.com/rootless-containers/rootlesskit/v2@v2.3.4/pkg/port/builtin/msg/msg.go (about)

     1  package msg
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"time"
     8  
     9  	"golang.org/x/sys/unix"
    10  
    11  	"github.com/rootless-containers/rootlesskit/v2/pkg/lowlevelmsgutil"
    12  	"github.com/rootless-containers/rootlesskit/v2/pkg/port"
    13  )
    14  
    15  const (
    16  	RequestTypeInit    = "init"
    17  	RequestTypeConnect = "connect"
    18  )
    19  
// Request is the message the parent writes to the child UNIX socket
// (see Initiate and ConnectToChild).
//
// Request and Reply are encoded as JSON with uint32le length header.
// NOTE(review): the fields carry no json tags, so if the encoder is
// encoding/json the Go field names are the wire keys; keep them in sync with
// the child side when renaming.
type Request struct {
	Type          string // "init" or "connect"
	Proto         string // "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6"
	IP            string // ConnectToChild fills this from port.Spec.ChildIP
	Port          int    // ConnectToChild fills this from port.Spec.ChildPort
	ParentIP      string // ConnectToChild fills this from port.Spec.ParentIP
	HostGatewayIP string // ConnectToChild fills this from hostGatewayIP(); may be ""
}
    29  
// Reply is the JSON message read back from the child after a Request.
// Initiate decodes a Reply; ConnectToChild does not, and instead reads only
// the FD that the child passes as OOB (SCM_RIGHTS) data.
type Reply struct {
	// Error is a failure message from the child.
	// NOTE(review): presumably empty on success; confirm against the child.
	Error string
}
    34  
    35  // Initiate sends "init" request to the child UNIX socket.
    36  func Initiate(c *net.UnixConn) error {
    37  	req := Request{
    38  		Type: RequestTypeInit,
    39  	}
    40  	if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil {
    41  		return err
    42  	}
    43  	if err := c.CloseWrite(); err != nil {
    44  		return err
    45  	}
    46  	var rep Reply
    47  	if _, err := lowlevelmsgutil.UnmarshalFromReader(c, &rep); err != nil {
    48  		return err
    49  	}
    50  	return c.CloseRead()
    51  }
    52  
    53  func hostGatewayIP() string {
    54  	addrs, err := net.InterfaceAddrs()
    55  	if err != nil {
    56  		return ""
    57  	}
    58  
    59  	for _, addr := range addrs {
    60  		if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
    61  			if ipnet.IP.To4() != nil {
    62  				return ipnet.IP.String()
    63  			}
    64  		}
    65  	}
    66  
    67  	return ""
    68  }
    69  
    70  // ConnectToChild connects to the child UNIX socket, and obtains TCP or UDP socket FD
    71  // that corresponds to the port spec.
    72  func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) {
    73  	req := Request{
    74  		Type:          RequestTypeConnect,
    75  		Proto:         spec.Proto,
    76  		Port:          spec.ChildPort,
    77  		IP:            spec.ChildIP,
    78  		ParentIP:      spec.ParentIP,
    79  		HostGatewayIP: hostGatewayIP(),
    80  	}
    81  	if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil {
    82  		return 0, err
    83  	}
    84  	if err := c.CloseWrite(); err != nil {
    85  		return 0, err
    86  	}
    87  	oobSpace := unix.CmsgSpace(4)
    88  	oob := make([]byte, oobSpace)
    89  	var (
    90  		oobN int
    91  		err  error
    92  	)
    93  	for {
    94  		_, oobN, _, _, err = c.ReadMsgUnix(nil, oob)
    95  		if err != unix.EINTR {
    96  			break
    97  		}
    98  	}
    99  	if err != nil {
   100  		return 0, err
   101  	}
   102  	if oobN != oobSpace {
   103  		return 0, fmt.Errorf("expected OOB space %d, got %d", oobSpace, oobN)
   104  	}
   105  	oob = oob[:oobN]
   106  	fd, err := parseFDFromOOB(oob)
   107  	if err != nil {
   108  		return 0, err
   109  	}
   110  	if err := c.CloseRead(); err != nil {
   111  		return 0, err
   112  	}
   113  	return fd, nil
   114  }
   115  
   116  // ConnectToChildWithSocketPath wraps ConnectToChild
   117  func ConnectToChildWithSocketPath(socketPath string, spec port.Spec) (int, error) {
   118  	var dialer net.Dialer
   119  	conn, err := dialer.Dial("unix", socketPath)
   120  	if err != nil {
   121  		return 0, err
   122  	}
   123  	defer conn.Close()
   124  	c := conn.(*net.UnixConn)
   125  	return ConnectToChild(c, spec)
   126  }
   127  
   128  // ConnectToChildWithRetry retries ConnectToChild every (i*5) milliseconds.
   129  func ConnectToChildWithRetry(socketPath string, spec port.Spec, retries int) (int, error) {
   130  	for i := 0; i < retries; i++ {
   131  		fd, err := ConnectToChildWithSocketPath(socketPath, spec)
   132  		if i == retries-1 && err != nil {
   133  			return 0, err
   134  		}
   135  		if err == nil {
   136  			return fd, err
   137  		}
   138  		// TODO: backoff
   139  		time.Sleep(time.Duration(i*5) * time.Millisecond)
   140  	}
   141  	// NOT REACHED
   142  	return 0, errors.New("reached max retry")
   143  }
   144  
   145  func parseFDFromOOB(oob []byte) (int, error) {
   146  	scms, err := unix.ParseSocketControlMessage(oob)
   147  	if err != nil {
   148  		return 0, err
   149  	}
   150  	if len(scms) != 1 {
   151  		return 0, fmt.Errorf("unexpected scms: %v", scms)
   152  	}
   153  	scm := scms[0]
   154  	fds, err := unix.ParseUnixRights(&scm)
   155  	if err != nil {
   156  		return 0, err
   157  	}
   158  	if len(fds) != 1 {
   159  		return 0, fmt.Errorf("unexpected fds: %v", fds)
   160  	}
   161  	return fds[0], nil
   162  }