github.com/nycdavid/zeus@v0.0.0-20201208104106-9ba439429e03/go/unixsocket/oobreader.go (about)

     1  package unixsocket
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"syscall"
     8  )
     9  
// oobReader reads from a Unix domain socket while capturing the out-of-band
// (ancillary) data that arrives alongside ordinary bytes, so that file
// descriptors passed over the socket can be retrieved later via ReadFD.
type oobReader struct {
	// Conn is the socket that both data and ancillary messages are read from.
	Conn *net.UnixConn
	// oob is the scratch buffer handed to ReadMsgUnix for ancillary data. Its
	// length bounds how much out-of-band data a single read can receive (an
	// empty/nil buffer means none can be captured), and it is overwritten on
	// every Read. NOTE(review): presumably sized by whoever constructs the
	// reader; the zero value cannot receive descriptors — confirm at call sites.
	oob  []byte
	// oobs queues private copies of ancillary data received by earlier reads
	// that ReadFD has not consumed yet, oldest first.
	oobs [][]byte
}
    15  
    16  func (o *oobReader) ReadFD() (int, error) {
    17  	if len(o.oobs) > 0 {
    18  		oob := o.oobs[0]
    19  		o.oobs = o.oobs[1:]
    20  		return extractFileDescriptorFromOOB(oob)
    21  	}
    22  
    23  	b := make([]byte, 0)
    24  	_, err := o.Read(b)
    25  	if err != nil {
    26  		return -1, err
    27  	}
    28  
    29  	if len(o.oobs) > 0 {
    30  		oob := o.oobs[0]
    31  		o.oobs = o.oobs[1:]
    32  		return extractFileDescriptorFromOOB(oob)
    33  	}
    34  
    35  	return -1, errors.New("No FD received :(")
    36  }
    37  
    38  func (o *oobReader) Read(b []byte) (int, error) {
    39  	n, oobn, _, _, err := o.Conn.ReadMsgUnix(b, o.oob)
    40  	if oobn > 0 {
    41  		newOob := make([]byte, oobn)
    42  		copy(newOob, o.oob[:oobn])
    43  		o.oobs = append(o.oobs, newOob)
    44  	}
    45  	return n, err
    46  }
    47  
    48  func extractFileDescriptorFromOOB(oob []byte) (int, error) {
    49  	scms, err := syscall.ParseSocketControlMessage(oob)
    50  	if err != nil {
    51  		return -1, err
    52  	}
    53  	if len(scms) != 1 {
    54  		return -1, errors.New(fmt.Sprintf("expected 1 SocketControlMessage; got scms = %#v", scms))
    55  	}
    56  	scm := scms[0]
    57  	gotFds, err := syscall.ParseUnixRights(&scm)
    58  	if err != nil {
    59  		return -1, err
    60  	}
    61  	if len(gotFds) != 1 {
    62  		return -1, errors.New(fmt.Sprintf("wanted 1 fd; got %#v", gotFds))
    63  	}
    64  	return gotFds[0], nil
    65  }