github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/fs/host/socket_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package host
    16  
    17  import (
    18  	"reflect"
    19  	"testing"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"github.com/SagerNet/gvisor/pkg/fd"
    23  	"github.com/SagerNet/gvisor/pkg/fdnotifier"
    24  	"github.com/SagerNet/gvisor/pkg/sentry/contexttest"
    25  	ktime "github.com/SagerNet/gvisor/pkg/sentry/kernel/time"
    26  	"github.com/SagerNet/gvisor/pkg/sentry/socket"
    27  	"github.com/SagerNet/gvisor/pkg/sentry/socket/unix/transport"
    28  	"github.com/SagerNet/gvisor/pkg/syserr"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip"
    30  	"github.com/SagerNet/gvisor/pkg/usermem"
    31  	"github.com/SagerNet/gvisor/pkg/waiter"
    32  )
    33  
    34  var (
    35  	// Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint.
    36  	_ = transport.ConnectedEndpoint(new(ConnectedEndpoint))
    37  
    38  	// Make sure that ConnectedEndpoint implements transport.Receiver.
    39  	_ = transport.Receiver(new(ConnectedEndpoint))
    40  )
    41  
    42  func getFl(fd int) (uint32, error) {
    43  	fl, _, err := unix.RawSyscall(unix.SYS_FCNTL, uintptr(fd), unix.F_GETFL, 0)
    44  	if err == 0 {
    45  		return uint32(fl), nil
    46  	}
    47  	return 0, err
    48  }
    49  
    50  func TestSocketIsBlocking(t *testing.T) {
    51  	// Using socketpair here because it's already connected.
    52  	pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
    53  	if err != nil {
    54  		t.Fatalf("host socket creation failed: %v", err)
    55  	}
    56  
    57  	fl, err := getFl(pair[0])
    58  	if err != nil {
    59  		t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
    60  	}
    61  	if fl&unix.O_NONBLOCK == unix.O_NONBLOCK {
    62  		t.Fatalf("Expected socket %v to be blocking", pair[0])
    63  	}
    64  	if fl, err = getFl(pair[1]); err != nil {
    65  		t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err)
    66  	}
    67  	if fl&unix.O_NONBLOCK == unix.O_NONBLOCK {
    68  		t.Fatalf("Expected socket %v to be blocking", pair[1])
    69  	}
    70  	ctx := contexttest.Context(t)
    71  	sock, err := newSocket(ctx, pair[0], false)
    72  	if err != nil {
    73  		t.Fatalf("newSocket(%v) failed => %v", pair[0], err)
    74  	}
    75  	defer sock.DecRef(ctx)
    76  	// Test that the socket now is non-blocking.
    77  	if fl, err = getFl(pair[0]); err != nil {
    78  		t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
    79  	}
    80  	if fl&unix.O_NONBLOCK != unix.O_NONBLOCK {
    81  		t.Errorf("Expected socket %v to have become non-blocking", pair[0])
    82  	}
    83  	if fl, err = getFl(pair[1]); err != nil {
    84  		t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err)
    85  	}
    86  	if fl&unix.O_NONBLOCK == unix.O_NONBLOCK {
    87  		t.Errorf("Did not expect socket %v to become non-blocking", pair[1])
    88  	}
    89  }
    90  
    91  func TestSocketWritev(t *testing.T) {
    92  	// Using socketpair here because it's already connected.
    93  	pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
    94  	if err != nil {
    95  		t.Fatalf("host socket creation failed: %v", err)
    96  	}
    97  	ctx := contexttest.Context(t)
    98  	socket, err := newSocket(ctx, pair[0], false)
    99  	if err != nil {
   100  		t.Fatalf("newSocket(%v) => %v", pair[0], err)
   101  	}
   102  	defer socket.DecRef(ctx)
   103  	buf := []byte("hello world\n")
   104  	n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf))
   105  	if err != nil {
   106  		t.Fatalf("socket writev failed: %v", err)
   107  	}
   108  
   109  	if n != int64(len(buf)) {
   110  		t.Fatalf("socket writev wrote incorrect bytes: %d", n)
   111  	}
   112  }
   113  
   114  func TestSocketWritevLen0(t *testing.T) {
   115  	// Using socketpair here because it's already connected.
   116  	pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
   117  	if err != nil {
   118  		t.Fatalf("host socket creation failed: %v", err)
   119  	}
   120  	ctx := contexttest.Context(t)
   121  	socket, err := newSocket(ctx, pair[0], false)
   122  	if err != nil {
   123  		t.Fatalf("newSocket(%v) => %v", pair[0], err)
   124  	}
   125  	defer socket.DecRef(ctx)
   126  	n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil))
   127  	if err != nil {
   128  		t.Fatalf("socket writev failed: %v", err)
   129  	}
   130  
   131  	if n != 0 {
   132  		t.Fatalf("socket writev wrote incorrect bytes: %d", n)
   133  	}
   134  }
   135  
   136  func TestSocketSendMsgLen0(t *testing.T) {
   137  	// Using socketpair here because it's already connected.
   138  	pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
   139  	if err != nil {
   140  		t.Fatalf("host socket creation failed: %v", err)
   141  	}
   142  	ctx := contexttest.Context(t)
   143  	sfile, err := newSocket(ctx, pair[0], false)
   144  	if err != nil {
   145  		t.Fatalf("newSocket(%v) => %v", pair[0], err)
   146  	}
   147  	defer sfile.DecRef(ctx)
   148  
   149  	s := sfile.FileOperations.(socket.Socket)
   150  	n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{})
   151  	if n != 0 {
   152  		t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n)
   153  	}
   154  
   155  	if terr != nil {
   156  		t.Fatalf("socket sendmsg() failed: %v", terr)
   157  	}
   158  }
   159  
   160  func TestListen(t *testing.T) {
   161  	pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
   162  	if err != nil {
   163  		t.Fatalf("unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) => %v", err)
   164  	}
   165  	ctx := contexttest.Context(t)
   166  	sfile1, err := newSocket(ctx, pair[0], false)
   167  	if err != nil {
   168  		t.Fatalf("newSocket(%v) => %v", pair[0], err)
   169  	}
   170  	defer sfile1.DecRef(ctx)
   171  	socket1 := sfile1.FileOperations.(socket.Socket)
   172  
   173  	sfile2, err := newSocket(ctx, pair[1], false)
   174  	if err != nil {
   175  		t.Fatalf("newSocket(%v) => %v", pair[1], err)
   176  	}
   177  	defer sfile2.DecRef(ctx)
   178  	socket2 := sfile2.FileOperations.(socket.Socket)
   179  
   180  	// Socketpairs can not be listened to.
   181  	if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
   182  		t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
   183  	}
   184  	if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
   185  		t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
   186  	}
   187  
   188  	// Create a Unix socket, do not bind it.
   189  	sock, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0)
   190  	if err != nil {
   191  		t.Fatalf("unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) => %v", err)
   192  	}
   193  	sfile3, err := newSocket(ctx, sock, false)
   194  	if err != nil {
   195  		t.Fatalf("newSocket(%v) => %v", sock, err)
   196  	}
   197  	defer sfile3.DecRef(ctx)
   198  	socket3 := sfile3.FileOperations.(socket.Socket)
   199  
   200  	// This socket is not bound so we can't listen on it.
   201  	if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
   202  		t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
   203  	}
   204  }
   205  
   206  func TestPasscred(t *testing.T) {
   207  	e := &ConnectedEndpoint{}
   208  	if got, want := e.Passcred(), false; got != want {
   209  		t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
   210  	}
   211  }
   212  
   213  func TestGetLocalAddress(t *testing.T) {
   214  	e := &ConnectedEndpoint{path: "foo"}
   215  	want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
   216  	if got, err := e.GetLocalAddress(); err != nil || got != want {
   217  		t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
   218  	}
   219  }
   220  
   221  func TestQueuedSize(t *testing.T) {
   222  	e := &ConnectedEndpoint{}
   223  	tests := []struct {
   224  		name string
   225  		f    func() int64
   226  	}{
   227  		{"SendQueuedSize", e.SendQueuedSize},
   228  		{"RecvQueuedSize", e.RecvQueuedSize},
   229  	}
   230  
   231  	for _, test := range tests {
   232  		if got, want := test.f(), int64(-1); got != want {
   233  			t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want)
   234  		}
   235  	}
   236  }
   237  
   238  func TestRelease(t *testing.T) {
   239  	f, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0)
   240  	if err != nil {
   241  		t.Fatal("Creating socket:", err)
   242  	}
   243  	c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
   244  	want := &ConnectedEndpoint{queue: c.queue}
   245  	ctx := contexttest.Context(t)
   246  	want.ref.DecRef(ctx)
   247  	fdnotifier.AddFD(int32(c.file.FD()), nil)
   248  	c.Release(ctx)
   249  	if !reflect.DeepEqual(c, want) {
   250  		t.Errorf("got = %#v, want = %#v", c, want)
   251  	}
   252  }