github.com/sagernet/sing-box@v1.9.0-rc.20/common/badtls/read_wait.go (about)

     1  //go:build go1.21 && !without_badtls
     2  
     3  package badtls
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"net"
     9  	"os"
    10  	"reflect"
    11  	"sync"
    12  	"unsafe"
    13  
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/sing/common/tls"
    18  )
    19  
// Compile-time check that *ReadWaitConn satisfies N.ReadWaiter.
var _ N.ReadWaiter = (*ReadWaitConn)(nil)

// ReadWaitConn wraps a TLS connection and implements N.ReadWaiter by copying
// data from the wrapped connection's unexported input state straight into
// buffers allocated through readWaitOptions.
//
// The pointer fields below alias unexported fields inside the wrapped
// connection; NewReadWaitConn resolves them via reflection and unsafe casts.
type ReadWaitConn struct {
	tls.Conn
	// halfAccess is the sync.Mutex embedded in the wrapped conn's "in" field;
	// WaitReadBuffer holds it for the whole read.
	halfAccess                    *sync.Mutex
	// rawInput aliases the wrapped conn's "rawInput" buffer; WaitReadBuffer
	// peeks at its first byte to detect a pending alert record.
	rawInput                      *bytes.Buffer
	// input aliases the wrapped conn's "input" reader; data is copied from it
	// into the buffer returned by WaitReadBuffer.
	input                         *bytes.Reader
	// hand aliases the wrapped conn's "hand" buffer; while it is non-empty,
	// WaitReadBuffer runs tlsHandlePostHandshakeMessage.
	hand                          *bytes.Buffer
	// readWaitOptions is set by InitializeReadWaiter and supplies NewBuffer
	// and PostReturn.
	readWaitOptions               N.ReadWaitOptions
	// tlsReadRecord and tlsHandlePostHandshakeMessage are the closures
	// returned by the tlsRegistry probe that recognized the wrapped conn.
	tlsReadRecord                 func() error
	tlsHandlePostHandshakeMessage func() error
}
    32  
    33  func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
    34  	var (
    35  		loaded                        bool
    36  		tlsReadRecord                 func() error
    37  		tlsHandlePostHandshakeMessage func() error
    38  	)
    39  	for _, tlsCreator := range tlsRegistry {
    40  		loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
    41  		if loaded {
    42  			break
    43  		}
    44  	}
    45  	if !loaded {
    46  		return nil, os.ErrInvalid
    47  	}
    48  	rawConn := reflect.Indirect(reflect.ValueOf(conn))
    49  	rawHalfConn := rawConn.FieldByName("in")
    50  	if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
    51  		return nil, E.New("badtls: invalid half conn")
    52  	}
    53  	rawHalfMutex := rawHalfConn.FieldByName("Mutex")
    54  	if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
    55  		return nil, E.New("badtls: invalid half mutex")
    56  	}
    57  	halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
    58  	rawRawInput := rawConn.FieldByName("rawInput")
    59  	if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
    60  		return nil, E.New("badtls: invalid raw input")
    61  	}
    62  	rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
    63  	rawInput0 := rawConn.FieldByName("input")
    64  	if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
    65  		return nil, E.New("badtls: invalid input")
    66  	}
    67  	input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
    68  	rawHand := rawConn.FieldByName("hand")
    69  	if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
    70  		return nil, E.New("badtls: invalid hand")
    71  	}
    72  	hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
    73  	return &ReadWaitConn{
    74  		Conn:                          conn,
    75  		halfAccess:                    halfAccess,
    76  		rawInput:                      rawInput,
    77  		input:                         input,
    78  		hand:                          hand,
    79  		tlsReadRecord:                 tlsReadRecord,
    80  		tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
    81  	}, nil
    82  }
    83  
    84  func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
    85  	c.readWaitOptions = options
    86  	return false
    87  }
    88  
    89  func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
    90  	err = c.HandshakeContext(context.Background())
    91  	if err != nil {
    92  		return
    93  	}
    94  	c.halfAccess.Lock()
    95  	defer c.halfAccess.Unlock()
    96  	for c.input.Len() == 0 {
    97  		err = c.tlsReadRecord()
    98  		if err != nil {
    99  			return
   100  		}
   101  		for c.hand.Len() > 0 {
   102  			err = c.tlsHandlePostHandshakeMessage()
   103  			if err != nil {
   104  				return
   105  			}
   106  		}
   107  	}
   108  	buffer = c.readWaitOptions.NewBuffer()
   109  	n, err := c.input.Read(buffer.FreeBytes())
   110  	if err != nil {
   111  		buffer.Release()
   112  		return
   113  	}
   114  	buffer.Truncate(n)
   115  
   116  	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
   117  		// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
   118  		c.rawInput.Bytes()[0] == 21 {
   119  		_ = c.tlsReadRecord()
   120  		// return n, err // will be io.EOF on closeNotify
   121  	}
   122  
   123  	c.readWaitOptions.PostReturn(buffer)
   124  	return
   125  }
   126  
   127  func (c *ReadWaitConn) Upstream() any {
   128  	return c.Conn
   129  }
   130  
// tlsRegistry holds probes for the TLS connection types this package can
// drive. Each probe reports whether it recognizes conn (loaded). If it does,
// it also returns closures that perform that implementation's record read and
// post-handshake message handling for conn; see the crypto/tls probe
// registered in this file's init. NewReadWaitConn uses the first probe that
// reports loaded.
var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
   132  
   133  func init() {
   134  	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
   135  		tlsConn, loaded := conn.(*tls.STDConn)
   136  		if !loaded {
   137  			return
   138  		}
   139  		return true, func() error {
   140  				return stdTLSReadRecord(tlsConn)
   141  			}, func() error {
   142  				return stdTLSHandlePostHandshakeMessage(tlsConn)
   143  			}
   144  	})
   145  }
   146  
// stdTLSReadRecord is bound via go:linkname to the unexported method
// crypto/tls.(*Conn).readRecord, with c passed in the receiver position.
// NOTE(review): assumes tls.STDConn is (an alias of) crypto/tls.Conn. Newer Go
// toolchains also restrict pull-style linknames into the standard library, so
// confirm this still links on the targeted toolchains.
//
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
func stdTLSReadRecord(c *tls.STDConn) error
   149  
// stdTLSHandlePostHandshakeMessage is bound via go:linkname to the unexported
// method crypto/tls.(*Conn).handlePostHandshakeMessage, with c passed in the
// receiver position.
// NOTE(review): assumes tls.STDConn is (an alias of) crypto/tls.Conn. Newer Go
// toolchains also restrict pull-style linknames into the standard library, so
// confirm this still links on the targeted toolchains.
//
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error