github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/fs/host/socket_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package host 16 17 import ( 18 "reflect" 19 "testing" 20 21 "golang.org/x/sys/unix" 22 "github.com/SagerNet/gvisor/pkg/fd" 23 "github.com/SagerNet/gvisor/pkg/fdnotifier" 24 "github.com/SagerNet/gvisor/pkg/sentry/contexttest" 25 ktime "github.com/SagerNet/gvisor/pkg/sentry/kernel/time" 26 "github.com/SagerNet/gvisor/pkg/sentry/socket" 27 "github.com/SagerNet/gvisor/pkg/sentry/socket/unix/transport" 28 "github.com/SagerNet/gvisor/pkg/syserr" 29 "github.com/SagerNet/gvisor/pkg/tcpip" 30 "github.com/SagerNet/gvisor/pkg/usermem" 31 "github.com/SagerNet/gvisor/pkg/waiter" 32 ) 33 34 var ( 35 // Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint. 36 _ = transport.ConnectedEndpoint(new(ConnectedEndpoint)) 37 38 // Make sure that ConnectedEndpoint implements transport.Receiver. 39 _ = transport.Receiver(new(ConnectedEndpoint)) 40 ) 41 42 func getFl(fd int) (uint32, error) { 43 fl, _, err := unix.RawSyscall(unix.SYS_FCNTL, uintptr(fd), unix.F_GETFL, 0) 44 if err == 0 { 45 return uint32(fl), nil 46 } 47 return 0, err 48 } 49 50 func TestSocketIsBlocking(t *testing.T) { 51 // Using socketpair here because it's already connected. 52 pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) 53 if err != nil { 54 t.Fatalf("host socket creation failed: %v", err) 55 } 56 57 fl, err := getFl(pair[0]) 58 if err != nil { 59 t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) 60 } 61 if fl&unix.O_NONBLOCK == unix.O_NONBLOCK { 62 t.Fatalf("Expected socket %v to be blocking", pair[0]) 63 } 64 if fl, err = getFl(pair[1]); err != nil { 65 t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) 66 } 67 if fl&unix.O_NONBLOCK == unix.O_NONBLOCK { 68 t.Fatalf("Expected socket %v to be blocking", pair[1]) 69 } 70 ctx := contexttest.Context(t) 71 sock, err := newSocket(ctx, pair[0], false) 72 if err != nil { 73 t.Fatalf("newSocket(%v) failed => %v", pair[0], err) 74 } 75 defer sock.DecRef(ctx) 76 // Test that the socket now is non-blocking. 77 if fl, err = getFl(pair[0]); err != nil { 78 t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) 79 } 80 if fl&unix.O_NONBLOCK != unix.O_NONBLOCK { 81 t.Errorf("Expected socket %v to have become non-blocking", pair[0]) 82 } 83 if fl, err = getFl(pair[1]); err != nil { 84 t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) 85 } 86 if fl&unix.O_NONBLOCK == unix.O_NONBLOCK { 87 t.Errorf("Did not expect socket %v to become non-blocking", pair[1]) 88 } 89 } 90 91 func TestSocketWritev(t *testing.T) { 92 // Using socketpair here because it's already connected. 93 pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) 94 if err != nil { 95 t.Fatalf("host socket creation failed: %v", err) 96 } 97 ctx := contexttest.Context(t) 98 socket, err := newSocket(ctx, pair[0], false) 99 if err != nil { 100 t.Fatalf("newSocket(%v) => %v", pair[0], err) 101 } 102 defer socket.DecRef(ctx) 103 buf := []byte("hello world\n") 104 n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf)) 105 if err != nil { 106 t.Fatalf("socket writev failed: %v", err) 107 } 108 109 if n != int64(len(buf)) { 110 t.Fatalf("socket writev wrote incorrect bytes: %d", n) 111 } 112 } 113 114 func TestSocketWritevLen0(t *testing.T) { 115 // Using socketpair here because it's already connected. 116 pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) 117 if err != nil { 118 t.Fatalf("host socket creation failed: %v", err) 119 } 120 ctx := contexttest.Context(t) 121 socket, err := newSocket(ctx, pair[0], false) 122 if err != nil { 123 t.Fatalf("newSocket(%v) => %v", pair[0], err) 124 } 125 defer socket.DecRef(ctx) 126 n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil)) 127 if err != nil { 128 t.Fatalf("socket writev failed: %v", err) 129 } 130 131 if n != 0 { 132 t.Fatalf("socket writev wrote incorrect bytes: %d", n) 133 } 134 } 135 136 func TestSocketSendMsgLen0(t *testing.T) { 137 // Using socketpair here because it's already connected. 138 pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) 139 if err != nil { 140 t.Fatalf("host socket creation failed: %v", err) 141 } 142 ctx := contexttest.Context(t) 143 sfile, err := newSocket(ctx, pair[0], false) 144 if err != nil { 145 t.Fatalf("newSocket(%v) => %v", pair[0], err) 146 } 147 defer sfile.DecRef(ctx) 148 149 s := sfile.FileOperations.(socket.Socket) 150 n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) 151 if n != 0 { 152 t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n) 153 } 154 155 if terr != nil { 156 t.Fatalf("socket sendmsg() failed: %v", terr) 157 } 158 } 159 160 func TestListen(t *testing.T) { 161 pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) 162 if err != nil { 163 t.Fatalf("unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) => %v", err) 164 } 165 ctx := contexttest.Context(t) 166 sfile1, err := newSocket(ctx, pair[0], false) 167 if err != nil { 168 t.Fatalf("newSocket(%v) => %v", pair[0], err) 169 } 170 defer sfile1.DecRef(ctx) 171 socket1 := sfile1.FileOperations.(socket.Socket) 172 173 sfile2, err := newSocket(ctx, pair[1], false) 174 if err != nil { 175 t.Fatalf("newSocket(%v) => %v", pair[1], err) 176 } 177 defer sfile2.DecRef(ctx) 178 socket2 := sfile2.FileOperations.(socket.Socket) 179 180 // Socketpairs can not be listened to. 181 if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { 182 t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) 183 } 184 if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { 185 t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) 186 } 187 188 // Create a Unix socket, do not bind it. 189 sock, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) 190 if err != nil { 191 t.Fatalf("unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) => %v", err) 192 } 193 sfile3, err := newSocket(ctx, sock, false) 194 if err != nil { 195 t.Fatalf("newSocket(%v) => %v", sock, err) 196 } 197 defer sfile3.DecRef(ctx) 198 socket3 := sfile3.FileOperations.(socket.Socket) 199 200 // This socket is not bound so we can't listen on it. 201 if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { 202 t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) 203 } 204 } 205 206 func TestPasscred(t *testing.T) { 207 e := &ConnectedEndpoint{} 208 if got, want := e.Passcred(), false; got != want { 209 t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) 210 } 211 } 212 213 func TestGetLocalAddress(t *testing.T) { 214 e := &ConnectedEndpoint{path: "foo"} 215 want := tcpip.FullAddress{Addr: tcpip.Address("foo")} 216 if got, err := e.GetLocalAddress(); err != nil || got != want { 217 t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) 218 } 219 } 220 221 func TestQueuedSize(t *testing.T) { 222 e := &ConnectedEndpoint{} 223 tests := []struct { 224 name string 225 f func() int64 226 }{ 227 {"SendQueuedSize", e.SendQueuedSize}, 228 {"RecvQueuedSize", e.RecvQueuedSize}, 229 } 230 231 for _, test := range tests { 232 if got, want := test.f(), int64(-1); got != want { 233 t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want) 234 } 235 } 236 } 237 238 func TestRelease(t *testing.T) { 239 f, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0) 240 if err != nil { 241 t.Fatal("Creating socket:", err) 242 } 243 c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} 244 want := &ConnectedEndpoint{queue: c.queue} 245 ctx := contexttest.Context(t) 246 want.ref.DecRef(ctx) 247 fdnotifier.AddFD(int32(c.file.FD()), nil) 248 c.Release(ctx) 249 if !reflect.DeepEqual(c, want) { 250 t.Errorf("got = %#v, want = %#v", c, want) 251 } 252 }