github.com/sagernet/sing-box@v1.9.0-rc.20/common/badtls/read_wait.go (about) 1 //go:build go1.21 && !without_badtls 2 3 package badtls 4 5 import ( 6 "bytes" 7 "context" 8 "net" 9 "os" 10 "reflect" 11 "sync" 12 "unsafe" 13 14 "github.com/sagernet/sing/common/buf" 15 E "github.com/sagernet/sing/common/exceptions" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/sing/common/tls" 18 ) 19 20 var _ N.ReadWaiter = (*ReadWaitConn)(nil) 21 22 type ReadWaitConn struct { 23 tls.Conn 24 halfAccess *sync.Mutex 25 rawInput *bytes.Buffer 26 input *bytes.Reader 27 hand *bytes.Buffer 28 readWaitOptions N.ReadWaitOptions 29 tlsReadRecord func() error 30 tlsHandlePostHandshakeMessage func() error 31 } 32 33 func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { 34 var ( 35 loaded bool 36 tlsReadRecord func() error 37 tlsHandlePostHandshakeMessage func() error 38 ) 39 for _, tlsCreator := range tlsRegistry { 40 loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn) 41 if loaded { 42 break 43 } 44 } 45 if !loaded { 46 return nil, os.ErrInvalid 47 } 48 rawConn := reflect.Indirect(reflect.ValueOf(conn)) 49 rawHalfConn := rawConn.FieldByName("in") 50 if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { 51 return nil, E.New("badtls: invalid half conn") 52 } 53 rawHalfMutex := rawHalfConn.FieldByName("Mutex") 54 if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct { 55 return nil, E.New("badtls: invalid half mutex") 56 } 57 halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr())) 58 rawRawInput := rawConn.FieldByName("rawInput") 59 if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct { 60 return nil, E.New("badtls: invalid raw input") 61 } 62 rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr())) 63 rawInput0 := rawConn.FieldByName("input") 64 if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct { 65 return nil, E.New("badtls: invalid input") 66 } 67 input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr())) 68 rawHand := rawConn.FieldByName("hand") 69 if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct { 70 return nil, E.New("badtls: invalid hand") 71 } 72 hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr())) 73 return &ReadWaitConn{ 74 Conn: conn, 75 halfAccess: halfAccess, 76 rawInput: rawInput, 77 input: input, 78 hand: hand, 79 tlsReadRecord: tlsReadRecord, 80 tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage, 81 }, nil 82 } 83 84 func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { 85 c.readWaitOptions = options 86 return false 87 } 88 89 func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { 90 err = c.HandshakeContext(context.Background()) 91 if err != nil { 92 return 93 } 94 c.halfAccess.Lock() 95 defer c.halfAccess.Unlock() 96 for c.input.Len() == 0 { 97 err = c.tlsReadRecord() 98 if err != nil { 99 return 100 } 101 for c.hand.Len() > 0 { 102 err = c.tlsHandlePostHandshakeMessage() 103 if err != nil { 104 return 105 } 106 } 107 } 108 buffer = c.readWaitOptions.NewBuffer() 109 n, err := c.input.Read(buffer.FreeBytes()) 110 if err != nil { 111 buffer.Release() 112 return 113 } 114 buffer.Truncate(n) 115 116 if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && 117 // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { 118 c.rawInput.Bytes()[0] == 21 { 119 _ = c.tlsReadRecord() 120 // return n, err // will be io.EOF on closeNotify 121 } 122 123 c.readWaitOptions.PostReturn(buffer) 124 return 125 } 126 127 func (c *ReadWaitConn) Upstream() any { 128 return c.Conn 129 } 130 131 var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) 132 133 func init() { 134 tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { 135 tlsConn, loaded := conn.(*tls.STDConn) 136 if !loaded { 137 return 138 } 139 return true, func() error { 140 return stdTLSReadRecord(tlsConn) 141 }, func() error { 142 return stdTLSHandlePostHandshakeMessage(tlsConn) 143 } 144 }) 145 } 146 147 //go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord 148 func stdTLSReadRecord(c *tls.STDConn) error 149 150 //go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage 151 func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error