github.com/nycdavid/zeus@v0.0.0-20201208104106-9ba439429e03/go/unixsocket/usock.go (about)

     1  package unixsocket
     2  
     3  import (
     4  	"bufio"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"strings"
    11  	"sync"
    12  	"syscall"
    13  )
    14  
// Usock wraps a Unix-domain socket connection (*net.UnixConn, see New) and
// exchanges NUL-terminated string messages and file descriptors over it.
//
// The embedded Mutex is taken by ReadMessage, WriteMessage, ReadFD and
// WriteFD so that each of those calls runs exclusively; Close does not take it.
type Usock struct {
	reader *oobReader    // connection plus a 32-byte oob scratch slice (set in New); ReadFD calls it directly
	rbuf   *bufio.Reader // buffered view of reader that ReadMessage splits on NUL bytes

	sync.Mutex
}
    21  
    22  func New(conn *net.UnixConn) *Usock {
    23  	u := &Usock{
    24  		reader: &oobReader{
    25  			oob:  make([]byte, 32),
    26  			Conn: conn,
    27  		},
    28  	}
    29  	u.rbuf = bufio.NewReader(u.reader)
    30  	return u
    31  }
    32  
    33  func NewFromFile(f *os.File) (*Usock, error) {
    34  	fileConn, err := net.FileConn(f)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	unixConn, ok := fileConn.(*net.UnixConn)
    40  	if !ok {
    41  		return nil, errors.New(fmt.Sprintf("unexpected FileConn type; expected UnixConn, got %T", unixConn))
    42  	}
    43  
    44  	return New(unixConn), nil
    45  }
    46  
    47  func (u *Usock) Close() {
    48  	u.reader.Conn.Close()
    49  }
    50  
    51  func (u *Usock) ReadMessage() (s string, err error) {
    52  	u.Lock()
    53  	defer u.Unlock()
    54  
    55  	for {
    56  		s, err = u.rbuf.ReadString(0)
    57  		if err == nil {
    58  			s = strings.TrimRight(s, "\000")
    59  		}
    60  		if err != nil || s != "" {
    61  			return
    62  		}
    63  	}
    64  }
    65  
    66  func (u *Usock) WriteMessage(msg string) (int, error) {
    67  	u.Lock()
    68  	defer u.Unlock()
    69  
    70  	completeMessage := strings.NewReader(msg + "\000")
    71  	n, err := io.Copy(u.reader.Conn, completeMessage)
    72  	return int(n - 1), err
    73  }
    74  
    75  func (u *Usock) ReadFD() (int, error) {
    76  	u.Lock()
    77  	defer u.Unlock()
    78  
    79  	return u.reader.ReadFD()
    80  }
    81  
    82  func (u *Usock) WriteFD(fd int) error {
    83  	u.Lock()
    84  	defer u.Unlock()
    85  
    86  	rights := syscall.UnixRights(fd)
    87  
    88  	dummyByte := []byte{0}
    89  	n, oobn, err := u.reader.Conn.WriteMsgUnix(dummyByte, rights, nil)
    90  	if err != nil {
    91  		str := fmt.Sprintf("Usock#WriteFD:WriteMsgUnix: %v / %v\n", err, syscall.EINVAL)
    92  		return errors.New(str)
    93  	}
    94  	if n != 1 || oobn != len(rights) {
    95  		str := fmt.Sprintf("Usock#WriteFD:WriteMsgUnix = %d, %d; want 1, %d\n", n, oobn, len(rights))
    96  		return errors.New(str)
    97  	}
    98  	return nil
    99  }