github.com/criyle/go-sandbox@v0.10.3/pkg/unixsocket/socket_linux.go (about)

// Package unixsocket provides a wrapper around Linux unix domain sockets for
// sending and receiving out-of-band (ancillary) data, including file
// descriptors and user credentials.
     3  package unixsocket
     4  
     5  import (
     6  	"bytes"
     7  	"fmt"
     8  	"net"
     9  	"os"
    10  	"syscall"
    11  )
    12  
    13  // oob size default to page size
    14  const oobSize = 4 << 10 // 4kb
    15  
    16  // Socket wrappers a unix socket connection
    17  type Socket struct {
    18  	*net.UnixConn
    19  	sendBuff []byte
    20  	recvBuff []byte
    21  }
    22  
    23  // Msg is the oob msg with the message
    24  type Msg struct {
    25  	Fds  []int          // unix rights
    26  	Cred *syscall.Ucred // unix credential
    27  }
    28  
    29  func newSocket(conn *net.UnixConn) *Socket {
    30  	return &Socket{
    31  		UnixConn: conn,
    32  		sendBuff: make([]byte, oobSize),
    33  		recvBuff: make([]byte, oobSize),
    34  	}
    35  }
    36  
    37  // NewSocket creates Socket conn struct using existing unix socket fd
    38  // creates by socketpair or net.DialUnix and mark it as close_on_exec (avoid fd leak)
    39  // it need SOCK_SEQPACKET socket for reliable transfer
    40  // it will need SO_PASSCRED to pass unix credential, Notice: in the documentation,
    41  // if cred is not specified, self information will be sent
    42  func NewSocket(fd int) (*Socket, error) {
    43  	syscall.SetNonblock(fd, true)
    44  	syscall.CloseOnExec(fd)
    45  
    46  	file := os.NewFile(uintptr(fd), "unix-socket")
    47  	if file == nil {
    48  		return nil, fmt.Errorf("NewSocket: %d is not a valid fd", fd)
    49  	}
    50  	defer file.Close()
    51  
    52  	conn, err := net.FileConn(file)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	unixConn, ok := conn.(*net.UnixConn)
    58  	if !ok {
    59  		conn.Close()
    60  		return nil, fmt.Errorf("NewSocket: %d is not a valid unix socket connection", fd)
    61  	}
    62  	return newSocket(unixConn), nil
    63  }
    64  
    65  // NewSocketPair creates connected unix socketpair using SOCK_SEQPACKET
    66  func NewSocketPair() (*Socket, *Socket, error) {
    67  	fd, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC, 0)
    68  	if err != nil {
    69  		return nil, nil, fmt.Errorf("NewSocketPair: failed to call socketpair %v", err)
    70  	}
    71  
    72  	ins, err := NewSocket(fd[0])
    73  	if err != nil {
    74  		syscall.Close(fd[0])
    75  		syscall.Close(fd[1])
    76  		return nil, nil, fmt.Errorf("NewSocketPair: failed to call NewSocket on sender %v", err)
    77  	}
    78  
    79  	outs, err := NewSocket(fd[1])
    80  	if err != nil {
    81  		ins.Close()
    82  		syscall.Close(fd[1])
    83  		return nil, nil, fmt.Errorf("NewSocketPair: failed to call NewSocket receiver %v", err)
    84  	}
    85  
    86  	return ins, outs, nil
    87  }
    88  
    89  // SetPassCred set sockopt for pass cred for unix socket
    90  func (s *Socket) SetPassCred(option int) error {
    91  	sysconn, err := s.SyscallConn()
    92  	if err != nil {
    93  		return err
    94  	}
    95  	return sysconn.Control(func(fd uintptr) {
    96  		syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_PASSCRED, option)
    97  	})
    98  }
    99  
   100  // SendMsg sendmsg to unix socket and encode possible unix right / credential
   101  func (s *Socket) SendMsg(b []byte, m Msg) error {
   102  	oob := bytes.NewBuffer(s.sendBuff[:0])
   103  	if len(m.Fds) > 0 {
   104  		oob.Write(syscall.UnixRights(m.Fds...))
   105  	}
   106  	if m.Cred != nil {
   107  		oob.Write(syscall.UnixCredentials(m.Cred))
   108  	}
   109  
   110  	_, _, err := s.WriteMsgUnix(b, oob.Bytes(), nil)
   111  	if err != nil {
   112  		return err
   113  	}
   114  	return nil
   115  }
   116  
   117  // RecvMsg recvmsg from unix socket and parse possible unix right / credential
   118  func (s *Socket) RecvMsg(b []byte) (int, Msg, error) {
   119  	var msg Msg
   120  	n, oobn, _, _, err := s.ReadMsgUnix(b, s.recvBuff)
   121  	if err != nil {
   122  		return 0, msg, err
   123  	}
   124  	// parse oob msg
   125  	msgs, err := syscall.ParseSocketControlMessage(s.recvBuff[:oobn])
   126  	if err != nil {
   127  		return 0, msg, err
   128  	}
   129  	msg, err = parseMsg(msgs)
   130  	if err != nil {
   131  		return 0, msg, err
   132  	}
   133  	return n, msg, nil
   134  }
   135  
   136  func parseMsg(msgs []syscall.SocketControlMessage) (msg Msg, err error) {
   137  	defer func() {
   138  		if err != nil {
   139  			for _, f := range msg.Fds {
   140  				syscall.Close(f)
   141  			}
   142  			msg.Fds = nil
   143  		}
   144  	}()
   145  	for _, m := range msgs {
   146  		if m.Header.Level != syscall.SOL_SOCKET {
   147  			continue
   148  		}
   149  
   150  		switch m.Header.Type {
   151  		case syscall.SCM_CREDENTIALS:
   152  			cred, err := syscall.ParseUnixCredentials(&m)
   153  			if err != nil {
   154  				return msg, err
   155  			}
   156  			msg.Cred = cred
   157  
   158  		case syscall.SCM_RIGHTS:
   159  			fds, err := syscall.ParseUnixRights(&m)
   160  			if err != nil {
   161  				return msg, err
   162  			}
   163  			msg.Fds = fds
   164  		}
   165  	}
   166  	return msg, nil
   167  }