github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/zerocopy/tcp.go (about)

     1  package zerocopy
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	mrand "math/rand/v2"
    10  	"net"
    11  
    12  	"github.com/database64128/shadowsocks-go/conn"
    13  	"go.uber.org/zap"
    14  )
    15  
    16  var (
    17  	ErrAcceptDoneNoRelay     = errors.New("the accepted connection has been handled without relaying")
    18  	ErrAcceptRequiresTCPConn = errors.New("rawRW is required to be a *net.TCPConn")
    19  )
    20  
    21  // TCPClientInfo contains information about a TCP client.
    22  type TCPClientInfo struct {
    23  	// Name is the name of the TCP client.
    24  	Name string
    25  
    26  	// NativeInitialPayload reports whether the protocol natively supports
    27  	// sending the initial payload within or along with the request header.
    28  	NativeInitialPayload bool
    29  }
    30  
    31  // TCPClient is a protocol's TCP client.
    32  type TCPClient interface {
    33  	// ClientInfo returns information about the TCP client.
    34  	Info() TCPClientInfo
    35  
    36  	// Dial creates a connection to the target address under the protocol's
    37  	// encapsulation and returns the established connection and a ReadWriter for read-write access.
    38  	Dial(ctx context.Context, targetAddr conn.Addr, payload []byte) (rawRW DirectReadWriteCloser, rw ReadWriter, err error)
    39  }
    40  
    41  // TCPServerInfo contains information about a TCP server.
    42  type TCPServerInfo struct {
    43  	// NativeInitialPayload reports whether the protocol natively supports
    44  	// sending the initial payload within or along with the request header.
    45  	NativeInitialPayload bool
    46  
    47  	// DefaultTCPConnCloser is the server's default function for handling a potentially malicious TCP connection.
    48  	DefaultTCPConnCloser TCPConnCloser
    49  }
    50  
    51  // TCPServer provides a protocol's TCP service.
    52  type TCPServer interface {
    53  	// ServerInfo returns information about the TCP server.
    54  	Info() TCPServerInfo
    55  
    56  	// Accept takes a newly-accepted TCP connection and wraps it into a protocol stream server.
    57  	//
    58  	// To make it easier to write tests, rawRW is of type [DirectReadWriteCloser].
    59  	// If the stream server needs to access TCP-specific features, it must type-assert and return
    60  	// [ErrAcceptRequiresTCPConn] on error.
    61  	//
    62  	// If the returned error is [ErrAcceptDoneNoRelay], the connection has been handled by this method.
    63  	// Two-way relay is not needed.
    64  	//
    65  	// If accept fails, the returned payload must be either nil/empty or the data that has been read
    66  	// from the connection.
    67  	Accept(rawRW DirectReadWriteCloser) (rw ReadWriter, targetAddr conn.Addr, payload []byte, username string, err error)
    68  }
    69  
    70  // TCPConnOpener stores information for opening TCP connections.
    71  //
    72  // TCPConnOpener implements the DirectReadWriteCloserOpener interface.
    73  type TCPConnOpener struct {
    74  	dialer           conn.Dialer
    75  	network, address string
    76  }
    77  
    78  // NewTCPConnOpener returns a new TCPConnOpener using the specified dialer, network and address.
    79  func NewTCPConnOpener(dialer conn.Dialer, network, address string) *TCPConnOpener {
    80  	return &TCPConnOpener{
    81  		dialer:  dialer,
    82  		network: network,
    83  		address: address,
    84  	}
    85  }
    86  
    87  // Open implements the DirectReadWriteCloserOpener Open method.
    88  func (o *TCPConnOpener) Open(ctx context.Context, b []byte) (DirectReadWriteCloser, error) {
    89  	return o.dialer.DialTCP(ctx, o.network, o.address, b)
    90  }
    91  
    92  // TCPConnCloser handles a potentially malicious TCP connection.
    93  // Upon returning, the TCP connection is safe to close.
    94  type TCPConnCloser func(conn *net.TCPConn, logger *zap.Logger)
    95  
    96  // JustClose closes the TCP connection without any special handling.
    97  func JustClose(conn *net.TCPConn, logger *zap.Logger) {
    98  }
    99  
   100  // ForceReset forces a reset of the TCP connection, regardless of
   101  // whether there's unread data or not.
   102  func ForceReset(conn *net.TCPConn, logger *zap.Logger) {
   103  	if err := conn.SetLinger(0); err != nil {
   104  		logger.Warn("Failed to set SO_LINGER on TCP connection", zap.Error(err))
   105  	}
   106  	logger.Info("Forcing RST on TCP connection")
   107  }
   108  
   109  // CloseWriteDrain closes the write end of the TCP connection,
   110  // then drain the read end.
   111  func CloseWriteDrain(conn *net.TCPConn, logger *zap.Logger) {
   112  	if err := conn.CloseWrite(); err != nil {
   113  		logger.Warn("Failed to close write half of TCP connection", zap.Error(err))
   114  	}
   115  
   116  	n, err := io.Copy(io.Discard, conn)
   117  	logger.Info("Drained TCP connection",
   118  		zap.Int64("bytesRead", n),
   119  		zap.Error(err),
   120  	)
   121  }
   122  
   123  // ReplyWithGibberish keeps reading and replying with random garbage until EOF or error.
   124  func ReplyWithGibberish(conn *net.TCPConn, logger *zap.Logger) {
   125  	const (
   126  		riBits = 7
   127  		riMask = 1<<riBits - 1
   128  		riMax  = 64 / riBits
   129  	)
   130  
   131  	var (
   132  		ri        uint64
   133  		remaining int
   134  	)
   135  
   136  	const (
   137  		bufBaseSize    = 1 << 14
   138  		bufVarSizeMask = bufBaseSize - 1
   139  	)
   140  
   141  	var (
   142  		bytesRead    int64
   143  		bytesWritten int64
   144  		n            int
   145  		err          error
   146  	)
   147  
   148  	b := make([]byte, bufBaseSize+mrand.Uint64()&bufVarSizeMask) // [16k, 32k)
   149  
   150  	for {
   151  		n, err = conn.Read(b)
   152  		bytesRead += int64(n)
   153  		if err != nil { // For TCPConn, when err == io.EOF, n == 0.
   154  			break
   155  		}
   156  
   157  		// n is in [129, 256].
   158  		// getrandom(2) won't block if the request size is not greater than 256.
   159  		if remaining == 0 {
   160  			ri = mrand.Uint64()
   161  			remaining = riMax
   162  		}
   163  		n = 129 + int(ri&riMask)
   164  		ri >>= riBits
   165  		remaining--
   166  
   167  		garbage := b[:n]
   168  		_, err = rand.Read(garbage)
   169  		if err != nil {
   170  			logger.Error("Failed to generate random garbage", zap.Error(err))
   171  			break
   172  		}
   173  
   174  		n, err = conn.Write(garbage)
   175  		bytesWritten += int64(n)
   176  		if err != nil {
   177  			break
   178  		}
   179  	}
   180  
   181  	logger.Info("Replied with gibberish",
   182  		zap.Int64("bytesRead", bytesRead),
   183  		zap.Int64("bytesWritten", bytesWritten),
   184  		zap.Error(err),
   185  	)
   186  }
   187  
   188  // ParseRejectPolicy parses a string representation of a reject policy.
   189  func ParseRejectPolicy(rejectPolicy string, serverDefault TCPConnCloser) (TCPConnCloser, error) {
   190  	switch rejectPolicy {
   191  	case "":
   192  		return serverDefault, nil
   193  	case "JustClose":
   194  		return JustClose, nil
   195  	case "ForceReset":
   196  		return ForceReset, nil
   197  	case "CloseWriteDrain":
   198  		return CloseWriteDrain, nil
   199  	case "ReplyWithGibberish":
   200  		return ReplyWithGibberish, nil
   201  	default:
   202  		return nil, fmt.Errorf("invalid reject policy: %s", rejectPolicy)
   203  	}
   204  }