// Source: github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/socket/unix/transport/connectioned.go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
	"github.com/SagerNet/gvisor/pkg/abi/linux"
	"github.com/SagerNet/gvisor/pkg/context"
	"github.com/SagerNet/gvisor/pkg/sync"
	"github.com/SagerNet/gvisor/pkg/syserr"
	"github.com/SagerNet/gvisor/pkg/tcpip"
	"github.com/SagerNet/gvisor/pkg/waiter"
)

// UniqueIDProvider generates a sequence of unique identifiers useful for,
// among other things, lock ordering.
type UniqueIDProvider interface {
	// UniqueID returns a new unique identifier.
	UniqueID() uint64
}

// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
// establish a bidirectional connection with a BoundEndpoint.
type ConnectingEndpoint interface {
	// ID returns the endpoint's globally unique identifier. This identifier
	// must be used to determine locking order if more than one endpoint is
	// to be locked in the same codepath. The endpoint with the smaller
	// identifier must be locked before endpoints with larger identifiers.
	ID() uint64

	// Passcred implements socket.Credentialer.Passcred.
	Passcred() bool

	// Type returns the socket type as a linux.SockType, typically a stream
	// or seqpacket type. The connection attempt must be aborted if this
	// value doesn't match the type of the endpoint being connected to.
	Type() linux.SockType

	// GetLocalAddress returns the bound path.
	GetLocalAddress() (tcpip.FullAddress, tcpip.Error)

	// Locker protects the following methods. While locked, only the holder of
	// the lock can change the return value of the protected methods.
	sync.Locker

	// Connected returns true iff the ConnectingEndpoint is in the connected
	// state. ConnectingEndpoints can only be connected to a single endpoint,
	// so the connection attempt must be aborted if this returns true.
	Connected() bool

	// Listening returns true iff the ConnectingEndpoint is in the listening
	// state. ConnectingEndpoints cannot make connections while listening, so
	// the connection attempt must be aborted if this returns true.
	Listening() bool

	// WaiterQueue returns a pointer to the endpoint's waiter queue.
	WaiterQueue() *waiter.Queue
}

// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements
// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint.
//
// connectionedEndpoints must be in connected state in order to transfer data.
//
// This implementation includes STREAM and SEQPACKET Unix sockets created with
// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
// socketpair(2). See unix_connectionless.go for the implementation of DGRAM
// Unix sockets created with socket(2).
//
// The state is much simpler than a TCP endpoint, so it is not encoded
// explicitly. Instead we enforce the following invariants:
//
// receiver != nil, connected != nil => connected.
// path != "" && acceptedChan == nil => bound, not listening.
// path != "" && acceptedChan != nil => bound and listening.
//
// Only one of these will be true at any moment.
89 // 90 // +stateify savable 91 type connectionedEndpoint struct { 92 baseEndpoint 93 94 // id is the unique endpoint identifier. This is used exclusively for 95 // lock ordering within connect. 96 id uint64 97 98 // idGenerator is used to generate new unique endpoint identifiers. 99 idGenerator UniqueIDProvider 100 101 // stype is used by connecting sockets to ensure that they are the 102 // same type. The value is typically either tcpip.SockSeqpacket or 103 // tcpip.SockStream. 104 stype linux.SockType 105 106 // acceptedChan is per the TCP endpoint implementation. Note that the 107 // sockets in this channel are _already in the connected state_, and 108 // have another associated connectionedEndpoint. 109 // 110 // If nil, then no listen call has been made. 111 acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"` 112 } 113 114 var ( 115 _ = BoundEndpoint((*connectionedEndpoint)(nil)) 116 _ = Endpoint((*connectionedEndpoint)(nil)) 117 ) 118 119 // NewConnectioned creates a new unbound connectionedEndpoint. 120 func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { 121 return newConnectioned(ctx, stype, uid) 122 } 123 124 func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { 125 ep := &connectionedEndpoint{ 126 baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, 127 id: uid.UniqueID(), 128 idGenerator: uid, 129 stype: stype, 130 } 131 132 ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) 133 ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) 134 ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) 135 return ep 136 } 137 138 // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. 
139 func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { 140 a := newConnectioned(ctx, stype, uid) 141 b := newConnectioned(ctx, stype, uid) 142 143 q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: defaultBufferSize} 144 q1.InitRefs() 145 q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: defaultBufferSize} 146 q2.InitRefs() 147 148 if stype == linux.SOCK_STREAM { 149 a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} 150 b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} 151 } else { 152 a.receiver = &queueReceiver{q1} 153 b.receiver = &queueReceiver{q2} 154 } 155 156 q2.IncRef() 157 a.connected = &connectedEndpoint{ 158 endpoint: b, 159 writeQueue: q2, 160 } 161 q1.IncRef() 162 b.connected = &connectedEndpoint{ 163 endpoint: a, 164 writeQueue: q1, 165 } 166 167 return a, b 168 } 169 170 // NewExternal creates a new externally backed Endpoint. It behaves like a 171 // socketpair. 172 func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { 173 ep := &connectionedEndpoint{ 174 baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, 175 id: uid.UniqueID(), 176 idGenerator: uid, 177 stype: stype, 178 } 179 ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) 180 ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */) 181 ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) 182 return ep 183 } 184 185 // ID implements ConnectingEndpoint.ID. 186 func (e *connectionedEndpoint) ID() uint64 { 187 return e.id 188 } 189 190 // Type implements ConnectingEndpoint.Type and Endpoint.Type. 191 func (e *connectionedEndpoint) Type() linux.SockType { 192 return e.stype 193 } 194 195 // WaiterQueue implements ConnectingEndpoint.WaiterQueue. 
func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
	return e.Queue
}

// isBound returns true iff the connectionedEndpoint is bound (but not
// listening), i.e. it has a non-empty path and no accept channel.
func (e *connectionedEndpoint) isBound() bool {
	return e.path != "" && e.acceptedChan == nil
}

// Listening implements ConnectingEndpoint.Listening. The endpoint is
// listening exactly when its accept channel has been created.
func (e *connectionedEndpoint) Listening() bool {
	return e.acceptedChan != nil
}

// Close puts the connectionedEndpoint in a closed state and frees all
// resources associated with it.
//
// The socket will be in a fresh state after a call to close and may be reused.
// That is, close may be used to "unbind" or "disconnect" the socket in error
// paths.
func (e *connectionedEndpoint) Close(ctx context.Context) {
	e.Lock()
	// c and r capture the connected endpoint and receiver (if any) so that
	// they can be notified and released after the lock is dropped.
	var c ConnectedEndpoint
	var r Receiver
	switch {
	case e.Connected():
		e.connected.CloseSend()
		e.receiver.CloseRecv()
		// Still have unread data? If yes, we set this into the write
		// end so that the peer can get ECONNRESET when it does read.
		if e.receiver.RecvQueuedSize() > 0 {
			e.connected.CloseUnread()
		}
		c = e.connected
		r = e.receiver
		e.connected = nil
		e.receiver = nil
	case e.isBound():
		e.path = ""
	case e.Listening():
		// Closing the channel first lets the range loop below terminate
		// once every pending, not-yet-accepted endpoint has been closed.
		close(e.acceptedChan)
		for n := range e.acceptedChan {
			n.Close(ctx)
		}
		e.acceptedChan = nil
		e.path = ""
	}
	e.Unlock()
	if c != nil {
		c.CloseNotify()
		c.Release(ctx)
	}
	if r != nil {
		r.CloseNotify()
		r.Release(ctx)
	}
}

// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
//
// It connects ce to a newly created endpoint that is queued on e's accept
// channel. On success, returnConnect is called (while both endpoints are
// still locked) with the receiver and connected endpoint that ce should
// adopt, and nil is returned.
//
// Errors:
//   - syserr.ErrWrongProtocolForSocket if ce's type differs from e's.
//   - syserr.ErrInvalidEndpointState if ce is e itself, or ce is listening.
//   - syserr.ErrAlreadyConnected if ce is already connected.
//   - syserr.ErrConnectionRefused if e is not listening.
//   - syserr.ErrTryAgain if the accept channel cannot take another endpoint
//     without blocking.
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
	if ce.Type() != e.stype {
		return syserr.ErrWrongProtocolForSocket
	}

	// Check if ce is e to avoid a deadlock.
	if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
		return syserr.ErrInvalidEndpointState
	}

	// Do a dance to safely acquire locks on both endpoints: the endpoint
	// with the smaller ID is always locked first.
	if e.id < ce.ID() {
		e.Lock()
		ce.Lock()
	} else {
		ce.Lock()
		e.Lock()
	}

	// Check connecting state.
	if ce.Connected() {
		e.Unlock()
		ce.Unlock()
		return syserr.ErrAlreadyConnected
	}
	if ce.Listening() {
		e.Unlock()
		ce.Unlock()
		return syserr.ErrInvalidEndpointState
	}

	// Check bound state.
	if !e.Listening() {
		e.Unlock()
		ce.Unlock()
		return syserr.ErrConnectionRefused
	}

	// Create a newly bound connectionedEndpoint. It shares e's path, ID
	// generator and socket type.
	ne := &connectionedEndpoint{
		baseEndpoint: baseEndpoint{
			path:  e.path,
			Queue: &waiter.Queue{},
		},
		id:          e.idGenerator.UniqueID(),
		idGenerator: e.idGenerator,
		stype:       e.stype,
	}
	ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
	ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
	ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)

	// The queue names are from ce's point of view: readQueue is read by ce
	// and written by ne, so it serves as ne's write queue.
	readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize}
	readQueue.InitRefs()
	ne.connected = &connectedEndpoint{
		endpoint:   ce,
		writeQueue: readQueue,
	}

	// Make sure the accepted endpoint inherits this listening socket's SO_SNDBUF.
	// writeQueue is written by ce and read by ne, so it backs ne's receiver.
	writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: e.ops.GetSendBufferSize()}
	writeQueue.InitRefs()
	if e.stype == linux.SOCK_STREAM {
		ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
	} else {
		ne.receiver = &queueReceiver{readQueue: writeQueue}
	}

	// Non-blocking send: either ne is queued for accept, or the attempt
	// fails immediately.
	select {
	case e.acceptedChan <- ne:
		// Commit state. Each queue gains a reference for the handle that
		// is about to be given to ce.
		writeQueue.IncRef()
		connected := &connectedEndpoint{
			endpoint:   ne,
			writeQueue: writeQueue,
		}
		readQueue.IncRef()
		if e.stype == linux.SOCK_STREAM {
			returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
		} else {
			returnConnect(&queueReceiver{readQueue: readQueue}, connected)
		}

		// Notify can deadlock if we are holding these locks.
		e.Unlock()
		ce.Unlock()

		// Notify on both ends.
		e.Notify(waiter.ReadableEvents)
		ce.WaiterQueue().Notify(waiter.WritableEvents)

		return nil
	default:
		// Busy; return EAGAIN per spec. The endpoint that was prepared
		// for this attempt is closed before the locks are released.
		ne.Close(ctx)
		e.Unlock()
		ce.Unlock()
		return syserr.ErrTryAgain
	}
}

// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
// connectionedEndpoints always refuse it with syserr.ErrConnectionRefused.
func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
	return nil, syserr.ErrConnectionRefused
}

// Connect attempts to directly connect to another Endpoint.
// Implements Endpoint.Connect.
//
// The work is delegated to server.BidirectionalConnect; the callback passed
// to it installs the resulting receiver and connected endpoint on e.
func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
	returnConnect := func(r Receiver, ce ConnectedEndpoint) {
		e.receiver = r
		e.connected = ce
		// Make sure the newly created connected endpoint's write queue is updated
		// to reflect this endpoint's send buffer size. If the connected endpoint
		// reports a different size than requested, that size is adopted for
		// both the send and receive buffer options.
		if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() {
			e.ops.SetSendBufferSize(bufSz, false /* notify */)
			e.ops.SetReceiveBufferSize(bufSz, false /* notify */)
		}
	}

	return server.BidirectionalConnect(ctx, e, returnConnect)
}

// Listen starts listening on the connection.
380 func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { 381 e.Lock() 382 defer e.Unlock() 383 if e.Listening() { 384 // Adjust the size of the channel iff we can fix existing 385 // pending connections into the new one. 386 if len(e.acceptedChan) > backlog { 387 return syserr.ErrInvalidEndpointState 388 } 389 origChan := e.acceptedChan 390 e.acceptedChan = make(chan *connectionedEndpoint, backlog) 391 close(origChan) 392 for ep := range origChan { 393 e.acceptedChan <- ep 394 } 395 return nil 396 } 397 if !e.isBound() { 398 return syserr.ErrInvalidEndpointState 399 } 400 401 // Normal case. 402 e.acceptedChan = make(chan *connectionedEndpoint, backlog) 403 return nil 404 } 405 406 // Accept accepts a new connection. 407 func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { 408 e.Lock() 409 defer e.Unlock() 410 411 if !e.Listening() { 412 return nil, syserr.ErrInvalidEndpointState 413 } 414 415 select { 416 case ne := <-e.acceptedChan: 417 if peerAddr != nil { 418 ne.Lock() 419 c := ne.connected 420 ne.Unlock() 421 if c != nil { 422 addr, err := c.GetLocalAddress() 423 if err != nil { 424 return nil, syserr.TranslateNetstackError(err) 425 } 426 *peerAddr = addr 427 } 428 } 429 return ne, nil 430 431 default: 432 // Nothing left. 433 return nil, syserr.ErrWouldBlock 434 } 435 } 436 437 // Bind binds the connection. 438 // 439 // For Unix connectionedEndpoints, this _only sets the address associated with 440 // the socket_. Work associated with sockets in the filesystem or finding those 441 // sockets must be done by a higher level. 442 // 443 // Bind will fail only if the socket is connected, bound or the passed address 444 // is invalid (the empty string). 
445 func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error { 446 e.Lock() 447 defer e.Unlock() 448 if e.isBound() || e.Listening() { 449 return syserr.ErrAlreadyBound 450 } 451 if addr.Addr == "" { 452 // The empty string is not permitted. 453 return syserr.ErrBadLocalAddress 454 } 455 if commit != nil { 456 if err := commit(); err != nil { 457 return err 458 } 459 } 460 461 // Save the bound address. 462 e.path = string(addr.Addr) 463 return nil 464 } 465 466 // SendMsg writes data and a control message to the endpoint's peer. 467 // This method does not block if the data cannot be written. 468 func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) { 469 // Stream sockets do not support specifying the endpoint. Seqpacket 470 // sockets ignore the passed endpoint. 471 if e.stype == linux.SOCK_STREAM && to != nil { 472 return 0, syserr.ErrNotSupported 473 } 474 return e.baseEndpoint.SendMsg(ctx, data, c, to) 475 } 476 477 // Readiness returns the current readiness of the connectionedEndpoint. For 478 // example, if waiter.EventIn is set, the connectionedEndpoint is immediately 479 // readable. 480 func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { 481 e.Lock() 482 defer e.Unlock() 483 484 ready := waiter.EventMask(0) 485 switch { 486 case e.Connected(): 487 if mask&waiter.ReadableEvents != 0 && e.receiver.Readable() { 488 ready |= waiter.ReadableEvents 489 } 490 if mask&waiter.WritableEvents != 0 && e.connected.Writable() { 491 ready |= waiter.WritableEvents 492 } 493 case e.Listening(): 494 if mask&waiter.ReadableEvents != 0 && len(e.acceptedChan) > 0 { 495 ready |= waiter.ReadableEvents 496 } 497 } 498 499 return ready 500 } 501 502 // State implements socket.Socket.State. 
503 func (e *connectionedEndpoint) State() uint32 { 504 e.Lock() 505 defer e.Unlock() 506 507 if e.Connected() { 508 return linux.SS_CONNECTED 509 } 510 return linux.SS_UNCONNECTED 511 } 512 513 // OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. 514 func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { 515 if e.Connected() { 516 return e.baseEndpoint.connected.SetSendBufferSize(v) 517 } 518 return v 519 }