github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/flipcall/ctrl_futex.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package flipcall
    16  
    17  import (
    18  	"encoding/json"
    19  	"fmt"
    20  	"math"
    21  	"sync/atomic"
    22  
    23  	"github.com/SagerNet/gvisor/pkg/log"
    24  )
    25  
// endpointControlImpl holds the state used by the futex-based control
// methods in this file.
// NOTE(review): presumably stored in Endpoint as the ctrl field (the methods
// below access ep.ctrl.state) — confirm against the Endpoint definition.
type endpointControlImpl struct {
	// state is a bitmask of the eps* constants. The methods in this file
	// read and modify it only through sync/atomic operations.
	state int32
}
    29  
    30  // Bits in endpointControlImpl.state.
    31  const (
    32  	epsBlocked = 1 << iota
    33  	epsShutdown
    34  )
    35  
    36  func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error {
    37  	if len(opts) != 0 {
    38  		return fmt.Errorf("unknown EndpointOption: %T", opts[0])
    39  	}
    40  	return nil
    41  }
    42  
    43  func (ep *Endpoint) ctrlConnect() error {
    44  	if err := ep.enterFutexWait(); err != nil {
    45  		return err
    46  	}
    47  	defer ep.exitFutexWait()
    48  
    49  	// Write the connection request.
    50  	w := ep.NewWriter()
    51  	if err := json.NewEncoder(w).Encode(struct{}{}); err != nil {
    52  		return fmt.Errorf("error writing connection request: %v", err)
    53  	}
    54  	*ep.dataLen() = w.Len()
    55  
    56  	// Exchange control with the server.
    57  	if err := ep.futexSetPeerActive(); err != nil {
    58  		return err
    59  	}
    60  	if err := ep.futexWakePeer(); err != nil {
    61  		return err
    62  	}
    63  	if err := ep.futexWaitUntilActive(); err != nil {
    64  		return err
    65  	}
    66  
    67  	// Read the connection response.
    68  	var resp struct{}
    69  	respLen := atomic.LoadUint32(ep.dataLen())
    70  	if respLen > ep.dataCap {
    71  		return fmt.Errorf("invalid connection response length %d (maximum %d)", respLen, ep.dataCap)
    72  	}
    73  	if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil {
    74  		return fmt.Errorf("error reading connection response: %v", err)
    75  	}
    76  
    77  	return nil
    78  }
    79  
    80  func (ep *Endpoint) ctrlWaitFirst() error {
    81  	if err := ep.enterFutexWait(); err != nil {
    82  		return err
    83  	}
    84  	defer ep.exitFutexWait()
    85  
    86  	// Wait for the connection request.
    87  	if err := ep.futexWaitUntilActive(); err != nil {
    88  		return err
    89  	}
    90  
    91  	// Read the connection request.
    92  	reqLen := atomic.LoadUint32(ep.dataLen())
    93  	if reqLen > ep.dataCap {
    94  		return fmt.Errorf("invalid connection request length %d (maximum %d)", reqLen, ep.dataCap)
    95  	}
    96  	var req struct{}
    97  	if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil {
    98  		return fmt.Errorf("error reading connection request: %v", err)
    99  	}
   100  
   101  	// Write the connection response.
   102  	w := ep.NewWriter()
   103  	if err := json.NewEncoder(w).Encode(struct{}{}); err != nil {
   104  		return fmt.Errorf("error writing connection response: %v", err)
   105  	}
   106  	*ep.dataLen() = w.Len()
   107  
   108  	// Return control to the client.
   109  	raceBecomeInactive()
   110  	if err := ep.futexSetPeerActive(); err != nil {
   111  		return err
   112  	}
   113  	if err := ep.futexWakePeer(); err != nil {
   114  		return err
   115  	}
   116  
   117  	// Wait for the first non-connection message.
   118  	return ep.futexWaitUntilActive()
   119  }
   120  
   121  func (ep *Endpoint) ctrlRoundTrip() error {
   122  	if err := ep.enterFutexWait(); err != nil {
   123  		return err
   124  	}
   125  	defer ep.exitFutexWait()
   126  
   127  	if err := ep.futexSetPeerActive(); err != nil {
   128  		return err
   129  	}
   130  	if err := ep.futexWakePeer(); err != nil {
   131  		return err
   132  	}
   133  	return ep.futexWaitUntilActive()
   134  }
   135  
   136  func (ep *Endpoint) ctrlWakeLast() error {
   137  	if err := ep.futexSetPeerActive(); err != nil {
   138  		return err
   139  	}
   140  	return ep.futexWakePeer()
   141  }
   142  
   143  func (ep *Endpoint) enterFutexWait() error {
   144  	switch eps := atomic.AddInt32(&ep.ctrl.state, epsBlocked); eps {
   145  	case epsBlocked:
   146  		return nil
   147  	case epsBlocked | epsShutdown:
   148  		atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
   149  		return ShutdownError{}
   150  	default:
   151  		// Most likely due to ep.enterFutexWait() being called concurrently
   152  		// from multiple goroutines.
   153  		panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state before flipcall.Endpoint.enterFutexWait(): %v", eps-epsBlocked))
   154  	}
   155  }
   156  
   157  func (ep *Endpoint) exitFutexWait() {
   158  	switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps {
   159  	case 0:
   160  		return
   161  	case epsShutdown:
   162  		// ep.ctrlShutdown() was called while we were blocked, so we are
   163  		// repsonsible for indicating connection shutdown.
   164  		ep.shutdownConn()
   165  	default:
   166  		panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked))
   167  	}
   168  }
   169  
   170  func (ep *Endpoint) ctrlShutdown() {
   171  	// Set epsShutdown to ensure that future calls to ep.enterFutexWait() fail.
   172  	if atomic.AddInt32(&ep.ctrl.state, epsShutdown)&epsBlocked != 0 {
   173  		// Wake the blocked thread. This must loop because it's possible that
   174  		// FUTEX_WAKE occurs after the waiter sets epsBlocked, but before it
   175  		// blocks in FUTEX_WAIT.
   176  		for {
   177  			// Wake MaxInt32 threads to prevent a broken or malicious peer from
   178  			// swallowing our wakeup by FUTEX_WAITing from multiple threads.
   179  			if err := ep.futexWakeConnState(math.MaxInt32); err != nil {
   180  				log.Warningf("failed to FUTEX_WAKE Endpoints: %v", err)
   181  				break
   182  			}
   183  			yieldThread()
   184  			if atomic.LoadInt32(&ep.ctrl.state)&epsBlocked == 0 {
   185  				break
   186  			}
   187  		}
   188  	} else {
   189  		// There is no blocked thread, so we are responsible for indicating
   190  		// connection shutdown.
   191  		ep.shutdownConn()
   192  	}
   193  }
   194  
   195  func (ep *Endpoint) shutdownConn() {
   196  	switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs {
   197  	case ep.activeState:
   198  		if err := ep.futexWakeConnState(1); err != nil {
   199  			log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err)
   200  		}
   201  	case ep.inactiveState:
   202  		// The peer is currently active and will detect shutdown when it tries
   203  		// to update the connection state.
   204  	case csShutdown:
   205  		// The peer also called Endpoint.Shutdown().
   206  	default:
   207  		log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs)
   208  	}
   209  }