github.com/database64128/shadowsocks-go@v1.7.0/zerocopy/tcp.go (about)

     1  package zerocopy
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  
    10  	"github.com/database64128/shadowsocks-go/conn"
    11  	"github.com/database64128/shadowsocks-go/magic"
    12  	"github.com/database64128/tfo-go/v2"
    13  	"go.uber.org/zap"
    14  )
    15  
    16  var (
    17  	ErrAcceptDoneNoRelay     = errors.New("the accepted connection has been handled without relaying")
    18  	ErrAcceptRequiresTCPConn = errors.New("rawRW is required to be a *net.TCPConn")
    19  )
    20  
    21  // TCPClientInfo contains information about a TCP client.
    22  type TCPClientInfo struct {
    23  	// Name is the name of the TCP client.
    24  	Name string
    25  
    26  	// NativeInitialPayload reports whether the protocol natively supports
    27  	// sending the initial payload within or along with the request header.
    28  	NativeInitialPayload bool
    29  }
    30  
    31  // TCPClient is a protocol's TCP client.
    32  type TCPClient interface {
    33  	// ClientInfo returns information about the TCP client.
    34  	Info() TCPClientInfo
    35  
    36  	// Dial creates a connection to the target address under the protocol's
    37  	// encapsulation and returns the established connection and a ReadWriter for read-write access.
    38  	Dial(targetAddr conn.Addr, payload []byte) (rawRW DirectReadWriteCloser, rw ReadWriter, err error)
    39  }
    40  
    41  // TCPServerInfo contains information about a TCP server.
    42  type TCPServerInfo struct {
    43  	// NativeInitialPayload reports whether the protocol natively supports
    44  	// sending the initial payload within or along with the request header.
    45  	NativeInitialPayload bool
    46  
    47  	// DefaultTCPConnCloser is the server's default function for handling a potentially malicious TCP connection.
    48  	DefaultTCPConnCloser TCPConnCloser
    49  }
    50  
    51  // TCPServer provides a protocol's TCP service.
    52  type TCPServer interface {
    53  	// ServerInfo returns information about the TCP server.
    54  	Info() TCPServerInfo
    55  
    56  	// Accept takes a newly-accepted TCP connection and wraps it into a protocol stream server.
    57  	//
    58  	// To make it easier to write tests, rawRW is of type [DirectReadWriteCloser].
    59  	// If the stream server needs to access TCP-specific features, it must type-assert and return
    60  	// [ErrAcceptRequiresTCPConn] on error.
    61  	//
    62  	// If the returned error is [ErrAcceptDoneNoRelay], the connection has been handled by this method.
    63  	// Two-way relay is not needed.
    64  	//
    65  	// If accept fails, the returned payload must be either nil/empty or the data that has been read
    66  	// from the connection.
    67  	Accept(rawRW DirectReadWriteCloser) (rw ReadWriter, targetAddr conn.Addr, payload []byte, username string, err error)
    68  }
    69  
    70  // TCPConnOpener stores information for opening TCP connections.
    71  //
    72  // TCPConnOpener implements the DirectReadWriteCloserOpener interface.
    73  type TCPConnOpener struct {
    74  	dialer           tfo.Dialer
    75  	network, address string
    76  }
    77  
    78  // NewTCPConnOpener returns a new TCPConnOpener using the specified dialer, network and address.
    79  func NewTCPConnOpener(dialer tfo.Dialer, network, address string) *TCPConnOpener {
    80  	return &TCPConnOpener{
    81  		dialer:  dialer,
    82  		network: network,
    83  		address: address,
    84  	}
    85  }
    86  
    87  // Open implements the DirectReadWriteCloserOpener Open method.
    88  func (o *TCPConnOpener) Open(b []byte) (DirectReadWriteCloser, error) {
    89  	c, err := o.dialer.Dial(o.network, o.address, b)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	return c.(DirectReadWriteCloser), nil
    94  }
    95  
    96  // TCPConnCloser handles a potentially malicious TCP connection.
    97  // Upon returning, the TCP connection is safe to close.
    98  type TCPConnCloser func(conn *net.TCPConn, serverName, listenAddress, clientAddress string, logger *zap.Logger)
    99  
   100  // JustClose closes the TCP connection without any special handling.
   101  func JustClose(conn *net.TCPConn, serverName, listenAddress, clientAddress string, logger *zap.Logger) {
   102  }
   103  
   104  // ForceReset forces a reset of the TCP connection, regardless of
   105  // whether there's unread data or not.
   106  func ForceReset(conn *net.TCPConn, serverName, listenAddress, clientAddress string, logger *zap.Logger) {
   107  	if err := conn.SetLinger(0); err != nil {
   108  		logger.Warn("Failed to set SO_LINGER on TCP connection",
   109  			zap.String("server", serverName),
   110  			zap.String("listenAddress", listenAddress),
   111  			zap.String("clientAddress", clientAddress),
   112  			zap.Error(err),
   113  		)
   114  	}
   115  
   116  	logger.Info("Forcing RST on TCP connection",
   117  		zap.String("server", serverName),
   118  		zap.String("listenAddress", listenAddress),
   119  		zap.String("clientAddress", clientAddress),
   120  	)
   121  }
   122  
   123  // CloseWriteDrain closes the write end of the TCP connection,
   124  // then drain the read end.
   125  func CloseWriteDrain(conn *net.TCPConn, serverName, listenAddress, clientAddress string, logger *zap.Logger) {
   126  	if err := conn.CloseWrite(); err != nil {
   127  		logger.Warn("Failed to close write half of TCP connection",
   128  			zap.String("server", serverName),
   129  			zap.String("listenAddress", listenAddress),
   130  			zap.String("clientAddress", clientAddress),
   131  			zap.Error(err),
   132  		)
   133  	}
   134  
   135  	n, err := io.Copy(io.Discard, conn)
   136  	logger.Info("Drained TCP connection",
   137  		zap.String("server", serverName),
   138  		zap.String("listenAddress", listenAddress),
   139  		zap.String("clientAddress", clientAddress),
   140  		zap.Int64("bytesRead", n),
   141  		zap.Error(err),
   142  	)
   143  }
   144  
   145  // ReplyWithGibberish keeps reading and replying with random garbage until EOF or error.
   146  func ReplyWithGibberish(conn *net.TCPConn, serverName, listenAddress, clientAddress string, logger *zap.Logger) {
   147  	const (
   148  		riBits = 7
   149  		riMask = 1<<riBits - 1
   150  		riMax  = 64 / riBits
   151  	)
   152  
   153  	var (
   154  		ri        uint64
   155  		remaining int
   156  	)
   157  
   158  	const (
   159  		bufBaseSize    = 1 << 14
   160  		bufVarSizeMask = bufBaseSize - 1
   161  	)
   162  
   163  	var (
   164  		bytesRead    int64
   165  		bytesWritten int64
   166  		n            int
   167  		err          error
   168  	)
   169  
   170  	b := make([]byte, bufBaseSize+magic.Fastrandu()&bufVarSizeMask) // [16k, 32k)
   171  
   172  	for {
   173  		n, err = conn.Read(b)
   174  		bytesRead += int64(n)
   175  		if err != nil { // For TCPConn, when err == io.EOF, n == 0.
   176  			break
   177  		}
   178  
   179  		// n is in [129, 256].
   180  		// getrandom(2) won't block if the request size is not greater than 256.
   181  		if remaining == 0 {
   182  			ri = magic.Fastrand64()
   183  			remaining = riMax
   184  		}
   185  		n = 129 + int(ri&riMask)
   186  		ri >>= riBits
   187  		remaining--
   188  
   189  		garbage := b[:n]
   190  		_, err = rand.Read(garbage)
   191  		if err != nil {
   192  			panic(err)
   193  		}
   194  
   195  		n, err = conn.Write(garbage)
   196  		bytesWritten += int64(n)
   197  		if err != nil {
   198  			break
   199  		}
   200  	}
   201  
   202  	logger.Info("Replied with gibberish",
   203  		zap.String("server", serverName),
   204  		zap.String("listenAddress", listenAddress),
   205  		zap.String("clientAddress", clientAddress),
   206  		zap.Int64("bytesRead", bytesRead),
   207  		zap.Int64("bytesWritten", bytesWritten),
   208  		zap.Error(err),
   209  	)
   210  }
   211  
   212  // ParseRejectPolicy parses a string representation of a reject policy.
   213  func ParseRejectPolicy(rejectPolicy string, serverDefault TCPConnCloser) (TCPConnCloser, error) {
   214  	switch rejectPolicy {
   215  	case "":
   216  		return serverDefault, nil
   217  	case "JustClose":
   218  		return JustClose, nil
   219  	case "ForceReset":
   220  		return ForceReset, nil
   221  	case "CloseWriteDrain":
   222  		return CloseWriteDrain, nil
   223  	case "ReplyWithGibberish":
   224  		return ReplyWithGibberish, nil
   225  	default:
   226  		return nil, fmt.Errorf("invalid reject policy: %s", rejectPolicy)
   227  	}
   228  }