github.com/nycdavid/zeus@v0.0.0-20201208104106-9ba439429e03/go/unixsocket/oobreader.go (about) 1 package unixsocket 2 3 import ( 4 "errors" 5 "fmt" 6 "net" 7 "syscall" 8 ) 9 10 type oobReader struct { 11 Conn *net.UnixConn 12 oob []byte 13 oobs [][]byte 14 } 15 16 func (o *oobReader) ReadFD() (int, error) { 17 if len(o.oobs) > 0 { 18 oob := o.oobs[0] 19 o.oobs = o.oobs[1:] 20 return extractFileDescriptorFromOOB(oob) 21 } 22 23 b := make([]byte, 0) 24 _, err := o.Read(b) 25 if err != nil { 26 return -1, err 27 } 28 29 if len(o.oobs) > 0 { 30 oob := o.oobs[0] 31 o.oobs = o.oobs[1:] 32 return extractFileDescriptorFromOOB(oob) 33 } 34 35 return -1, errors.New("No FD received :(") 36 } 37 38 func (o *oobReader) Read(b []byte) (int, error) { 39 n, oobn, _, _, err := o.Conn.ReadMsgUnix(b, o.oob) 40 if oobn > 0 { 41 newOob := make([]byte, oobn) 42 copy(newOob, o.oob[:oobn]) 43 o.oobs = append(o.oobs, newOob) 44 } 45 return n, err 46 } 47 48 func extractFileDescriptorFromOOB(oob []byte) (int, error) { 49 scms, err := syscall.ParseSocketControlMessage(oob) 50 if err != nil { 51 return -1, err 52 } 53 if len(scms) != 1 { 54 return -1, errors.New(fmt.Sprintf("expected 1 SocketControlMessage; got scms = %#v", scms)) 55 } 56 scm := scms[0] 57 gotFds, err := syscall.ParseUnixRights(&scm) 58 if err != nil { 59 return -1, err 60 } 61 if len(gotFds) != 1 { 62 return -1, errors.New(fmt.Sprintf("wanted 1 fd; got %#v", gotFds)) 63 } 64 return gotFds[0], nil 65 }