github.com/jspc/eggos@v0.5.1-0.20221028160421-556c75c878a5/inet/sockfile.go (about)

     1  package inet
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"net"
     7  	"syscall"
     8  	"time"
     9  	"unsafe"
    10  
    11  	"github.com/jspc/eggos/fs"
    12  	"github.com/jspc/eggos/log"
    13  
    14  	"gvisor.dev/gvisor/pkg/abi/linux"
    15  	"gvisor.dev/gvisor/pkg/tcpip"
    16  	"gvisor.dev/gvisor/pkg/waiter"
    17  )
    18  
// evnotify reports events for fd to the kernel package. It has no Go body
// here: the symbol is bound to github.com/jspc/eggos/kernel.epollNotify via
// the go:linkname directive below. NOTE(review): judging by the target name,
// this presumably feeds the kernel's epoll readiness tracking — confirm
// against kernel.epollNotify. events is expected in Linux event-bit form
// (callers pass waiter.EventMask.ToLinux()).
//
//go:linkname evnotify github.com/jspc/eggos/kernel.epollNotify
func evnotify(fd, events uintptr)
    21  
// sockFile exposes a gVisor netstack endpoint as a file: it is installed as
// an inode's File by allocSockFile and provides Read/Write/Close plus the
// socket-specific operations below.
type sockFile struct {
	// fd is the descriptor returned by fs.AllocInode for this socket.
	fd int
	// ep is the underlying netstack transport endpoint.
	ep tcpip.Endpoint
	// wq is the endpoint's wait queue; setupEvent registers on it.
	wq *waiter.Queue
}
    27  
    28  func allocSockFile(ep tcpip.Endpoint, wq *waiter.Queue) *sockFile {
    29  	fd, ni := fs.AllocInode()
    30  
    31  	sfile := &sockFile{
    32  		fd: fd,
    33  		ep: ep,
    34  		wq: wq,
    35  	}
    36  	sfile.setupEvent()
    37  
    38  	ni.File = sfile
    39  	return sfile
    40  }
    41  
    42  func findSockFile(fd uintptr) (*sockFile, error) {
    43  	ni, err := fs.GetInode(int(fd))
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	sf, ok := ni.File.(*sockFile)
    48  	if !ok {
    49  		return nil, syscall.EBADF
    50  	}
    51  	return sf, nil
    52  }
    53  
    54  func (s *sockFile) Read(p []byte) (int, error) {
    55  	var terr tcpip.Error
    56  	var result tcpip.ReadResult
    57  
    58  	w := tcpip.SliceWriter(p)
    59  	result, terr = s.ep.Read(&w, tcpip.ReadOptions{})
    60  
    61  	switch terr.(type) {
    62  	case nil:
    63  	case *tcpip.ErrWouldBlock:
    64  		return 0, syscall.EAGAIN
    65  	case *tcpip.ErrClosedForReceive:
    66  		return 0, nil
    67  	default:
    68  		log.Infof("[socket] read error:%s", terr)
    69  		return 0, e(terr)
    70  	}
    71  	if result.Count < result.Total {
    72  		// make next epoll_wait success
    73  		s.evcallback(nil, waiter.EventIn)
    74  	}
    75  	return result.Count, nil
    76  }
    77  
    78  func (s *sockFile) Write(p []byte) (int, error) {
    79  	n, terr := s.ep.Write(bytes.NewBuffer(p), tcpip.WriteOptions{})
    80  	if n != 0 {
    81  		return int(n), nil
    82  	}
    83  
    84  	switch terr.(type) {
    85  	case *tcpip.ErrWouldBlock:
    86  		return 0, syscall.EAGAIN
    87  	case *tcpip.ErrClosedForSend:
    88  		return 0, syscall.EPIPE
    89  	default:
    90  		log.Infof("[socket] write error:%s", terr)
    91  		return 0, e(terr)
    92  	}
    93  }
    94  
    95  func (s *sockFile) Close() error {
    96  	s.ep.Close()
    97  	return nil
    98  }
    99  
   100  type evcallback func(*waiter.Entry, waiter.EventMask)
   101  
   102  func (e evcallback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
   103  	e(entry, mask)
   104  }
   105  
   106  func (s *sockFile) setupEvent() {
   107  	s.wq.EventRegister(&waiter.Entry{
   108  		Callback: evcallback(s.evcallback),
   109  	}, waiter.EventIn|waiter.EventOut|waiter.EventErr|waiter.EventHUp)
   110  }
   111  
// stopEvent is intended to undo setupEvent.
//
// NOTE(review): this passes a nil entry to EventUnregister, but setupEvent
// registers a freshly allocated *waiter.Entry that it never stores, so this
// call cannot remove that registration. Depending on the waiter.Queue
// implementation, unregistering nil may even panic — confirm before using.
// A real fix needs the registered entry kept on sockFile.
func (s *sockFile) stopEvent() {
	s.wq.EventUnregister(nil)
}
   115  
   116  func (s *sockFile) evcallback(e *waiter.Entry, mask waiter.EventMask) {
   117  	// log.Infof("ev:%x fd:%d", mask, s.fd)
   118  	// syscall.Syscall(kernel.SYS_EPOLL_NOTIFY, uintptr(s.fd), uintptr(mask.ToLinux()), 0)
   119  	evnotify(uintptr(s.fd), uintptr(mask.ToLinux()))
   120  }
   121  
   122  func (s *sockFile) Bind(uaddr, uaddrlen uintptr) error {
   123  	var saddr *linux.SockAddrInet
   124  	if uaddrlen < unsafe.Sizeof(*saddr) {
   125  		return errors.New("bad bind address")
   126  	}
   127  	saddr = (*linux.SockAddrInet)(unsafe.Pointer(uaddr))
   128  	ip := net.IPv4(saddr.Addr[0], saddr.Addr[1], saddr.Addr[2], saddr.Addr[3])
   129  	addr := tcpip.FullAddress{
   130  		// NIC:  defaultNIC,
   131  		Addr: tcpip.Address(ip),
   132  		Port: ntohs(saddr.Port),
   133  	}
   134  	err := s.ep.Bind(addr)
   135  	if err != nil {
   136  		log.Infof("[socket] bind error:%s", err)
   137  		return e(err)
   138  	}
   139  	return nil
   140  }
   141  
   142  func (s *sockFile) Connect(uaddr, uaddrlen uintptr) error {
   143  	var saddr *linux.SockAddrInet
   144  	if uaddrlen < unsafe.Sizeof(*saddr) {
   145  		return syscall.EINVAL
   146  	}
   147  	saddr = (*linux.SockAddrInet)(unsafe.Pointer(uaddr))
   148  	addr := tcpip.FullAddress{
   149  		Addr: tcpip.Address(saddr.Addr[:]),
   150  		Port: ntohs(saddr.Port),
   151  	}
   152  	err := s.ep.Connect(addr)
   153  	if _, ok := err.(*tcpip.ErrConnectStarted); ok {
   154  		return syscall.EINPROGRESS
   155  	}
   156  	if err != nil {
   157  		log.Infof("[socket] connect error:%s", err)
   158  		return e(err)
   159  	}
   160  	return nil
   161  }
   162  
   163  func (s *sockFile) Listen(n uintptr) error {
   164  	err := s.ep.Listen(int(n))
   165  	if err != nil {
   166  		log.Infof("[socket] listen error:%s", err)
   167  	}
   168  	return e(err)
   169  }
   170  
   171  func (s *sockFile) Accept4(uaddr, uaddrlen, flag uintptr) (int, error) {
   172  	var saddr *linux.SockAddrInet
   173  	if uaddrlen < unsafe.Sizeof(*saddr) {
   174  		return 0, syscall.EINVAL
   175  	}
   176  	saddr = (*linux.SockAddrInet)(unsafe.Pointer(uaddr))
   177  	newep, wq, err := s.ep.Accept(nil)
   178  	switch err.(type) {
   179  	case nil:
   180  	case *tcpip.ErrWouldBlock:
   181  		return 0, syscall.EAGAIN
   182  	default:
   183  		log.Infof("[socket] accept error:%s", err)
   184  		return 0, e(err)
   185  	}
   186  
   187  	newaddr, err := newep.GetRemoteAddress()
   188  	if err != nil {
   189  		log.Infof("[socket] accept getRemoteAddress error:%s", err)
   190  		return 0, e(err)
   191  	}
   192  	saddr.Family = syscall.AF_INET
   193  	saddr.Port = htons(newaddr.Port)
   194  	copy(saddr.Addr[:], newaddr.Addr)
   195  	sfile := allocSockFile(newep, wq)
   196  	return sfile.fd, nil
   197  }
   198  
   199  func (s *sockFile) Setsockopt(level, opt, vptr, vlen uintptr) error {
   200  	switch level {
   201  	case syscall.SOL_SOCKET, syscall.IPPROTO_TCP:
   202  	default:
   203  		log.Infof("[socket] setsockopt:unsupport socket opt level:%d", level)
   204  		return syscall.EINVAL
   205  	}
   206  
   207  	if vlen != 4 {
   208  		log.Infof("[socket] setsockopt:bad opt value length:%d", vlen)
   209  		return syscall.EINVAL
   210  	}
   211  
   212  	var terr tcpip.Error
   213  	value := *(*uint32)(unsafe.Pointer(vptr))
   214  	sockopt := s.ep.SocketOptions()
   215  
   216  	switch opt {
   217  	case syscall.SO_REUSEADDR:
   218  		sockopt.SetReuseAddress(value != 0)
   219  	case syscall.SO_BROADCAST:
   220  		sockopt.SetBroadcast(value != 0)
   221  	case syscall.TCP_NODELAY:
   222  		sockopt.SetDelayOption(value != 0)
   223  	case syscall.SO_KEEPALIVE:
   224  		sockopt.SetKeepAlive(value != 0)
   225  	case syscall.TCP_KEEPINTVL:
   226  		v := tcpip.KeepaliveIntervalOption(time.Duration(value) * time.Second)
   227  		terr = s.ep.SetSockOpt(&v)
   228  	case syscall.TCP_KEEPIDLE:
   229  		v := tcpip.KeepaliveIdleOption(time.Duration(value) * time.Second)
   230  		terr = s.ep.SetSockOpt(&v)
   231  	default:
   232  		log.Infof("[socket] setsockopt:unknow socket option:%d", opt)
   233  		return syscall.EINVAL
   234  	}
   235  
   236  	if terr != nil {
   237  		return e(terr)
   238  	}
   239  	return nil
   240  }
   241  
   242  func (s *sockFile) Getsockopt(level, opt, vptr, vlenptr uintptr) error {
   243  	if level != syscall.SOL_SOCKET {
   244  		log.Infof("[socket] getsockopt:unsupport socket opt level:%d", level)
   245  		return syscall.EINVAL
   246  	}
   247  	vlen := (*int)(unsafe.Pointer(vlenptr))
   248  	if *vlen != 4 {
   249  		log.Infof("[socket] getsockopt:bad opt value length:%d", vlen)
   250  		return syscall.EINVAL
   251  	}
   252  	value := (*uint32)(unsafe.Pointer(vptr))
   253  
   254  	switch opt {
   255  	case syscall.SO_ERROR:
   256  		terr := s.ep.SocketOptions().GetLastError()
   257  		switch terr.(type) {
   258  		case nil:
   259  		case *tcpip.ErrConnectionRefused:
   260  			*value = uint32(syscall.ECONNREFUSED)
   261  		case *tcpip.ErrNoRoute:
   262  			*value = uint32(syscall.EHOSTUNREACH)
   263  		default:
   264  			log.Infof("[socket] getsockopt:unknow socket error:%s", terr)
   265  			return e(terr)
   266  		}
   267  	default:
   268  		log.Infof("[socket] getsockopt:unknow socket option:%d", opt)
   269  		return syscall.EINVAL
   270  	}
   271  	return nil
   272  }
   273  
   274  func (s *sockFile) Getpeername(uaddr, uaddrlen uintptr) error {
   275  	saddr := (*linux.SockAddrInet)(unsafe.Pointer(uaddr))
   276  	addr, err := s.ep.GetRemoteAddress()
   277  	if err != nil {
   278  		log.Infof("[socket] getpeername error:%s", err)
   279  		return e(err)
   280  	}
   281  	saddr.Family = syscall.AF_INET
   282  	copy(saddr.Addr[:], addr.Addr)
   283  	saddr.Port = ntohs(addr.Port)
   284  	return nil
   285  }
   286  
   287  func (s *sockFile) Getsockname(uaddr, uaddrlen uintptr) error {
   288  	saddr := (*linux.SockAddrInet)(unsafe.Pointer(uaddr))
   289  	addr, err := s.ep.GetLocalAddress()
   290  	if err != nil {
   291  		log.Infof("[socket] getsockname error:%s", err)
   292  		return e(err)
   293  	}
   294  	saddr.Family = syscall.AF_INET
   295  	copy(saddr.Addr[:], addr.Addr)
   296  	saddr.Port = htons(addr.Port)
   297  	return nil
   298  }