github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/badtls/badtls.go (about) 1 //go:build go1.20 && !go1.21 2 3 package badtls 4 5 import ( 6 "crypto/cipher" 7 "crypto/rand" 8 "crypto/tls" 9 "encoding/binary" 10 "io" 11 "net" 12 "reflect" 13 "sync" 14 "sync/atomic" 15 "unsafe" 16 17 "github.com/inazumav/sing-box/log" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/buf" 20 "github.com/sagernet/sing/common/bufio" 21 E "github.com/sagernet/sing/common/exceptions" 22 N "github.com/sagernet/sing/common/network" 23 aTLS "github.com/sagernet/sing/common/tls" 24 ) 25 26 type Conn struct { 27 *tls.Conn 28 writer N.ExtendedWriter 29 isHandshakeComplete *atomic.Bool 30 activeCall *atomic.Int32 31 closeNotifySent *bool 32 version *uint16 33 rand io.Reader 34 halfAccess *sync.Mutex 35 halfError *error 36 cipher cipher.AEAD 37 explicitNonceLen int 38 halfPtr uintptr 39 halfSeq []byte 40 halfScratchBuf []byte 41 } 42 43 func TryCreate(conn aTLS.Conn) aTLS.Conn { 44 tlsConn, ok := conn.(*tls.Conn) 45 if !ok { 46 return conn 47 } 48 badConn, err := Create(tlsConn) 49 if err != nil { 50 log.Warn("initialize badtls: ", err) 51 return conn 52 } 53 return badConn 54 } 55 56 func Create(conn *tls.Conn) (aTLS.Conn, error) { 57 rawConn := reflect.Indirect(reflect.ValueOf(conn)) 58 rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete") 59 if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct { 60 return nil, E.New("badtls: invalid isHandshakeComplete") 61 } 62 isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr())) 63 if !isHandshakeComplete.Load() { 64 return nil, E.New("handshake not finished") 65 } 66 rawActiveCall := rawConn.FieldByName("activeCall") 67 if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct { 68 return nil, E.New("badtls: invalid active call") 69 } 70 activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) 71 rawHalfConn := rawConn.FieldByName("out") 72 if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { 73 return nil, E.New("badtls: invalid half conn") 74 } 75 rawVersion := rawConn.FieldByName("vers") 76 if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 { 77 return nil, E.New("badtls: invalid version") 78 } 79 version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr())) 80 rawCloseNotifySent := rawConn.FieldByName("closeNotifySent") 81 if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool { 82 return nil, E.New("badtls: invalid notify") 83 } 84 closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr())) 85 rawConfig := reflect.Indirect(rawConn.FieldByName("config")) 86 if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct { 87 return nil, E.New("badtls: bad config") 88 } 89 config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr())) 90 randReader := config.Rand 91 if randReader == nil { 92 randReader = rand.Reader 93 } 94 rawHalfMutex := rawHalfConn.FieldByName("Mutex") 95 if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct { 96 return nil, E.New("badtls: invalid half mutex") 97 } 98 halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr())) 99 rawHalfError := rawHalfConn.FieldByName("err") 100 if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface { 101 return nil, E.New("badtls: invalid half error") 102 } 103 halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr())) 104 rawHalfCipherInterface := rawHalfConn.FieldByName("cipher") 105 if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface { 106 return nil, E.New("badtls: invalid cipher interface") 107 } 108 rawHalfCipher := rawHalfCipherInterface.Elem() 109 aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD) 110 if !loaded { 111 return nil, E.New("badtls: invalid AEAD cipher") 112 } 113 var explicitNonceLen int 114 switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName { 115 case "tls.prefixNonceAEAD": 116 explicitNonceLen = aeadCipher.NonceSize() 117 case "tls.xorNonceAEAD": 118 default: 119 return nil, E.New("badtls: unknown cipher type: ", cipherName) 120 } 121 rawHalfSeq := rawHalfConn.FieldByName("seq") 122 if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array { 123 return nil, E.New("badtls: invalid seq") 124 } 125 halfSeq := rawHalfSeq.Bytes() 126 rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf") 127 if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array { 128 return nil, E.New("badtls: invalid scratchBuf") 129 } 130 halfScratchBuf := rawHalfScratchBuf.Bytes() 131 return &Conn{ 132 Conn: conn, 133 writer: bufio.NewExtendedWriter(conn.NetConn()), 134 isHandshakeComplete: isHandshakeComplete, 135 activeCall: activeCall, 136 closeNotifySent: closeNotifySent, 137 version: version, 138 halfAccess: halfAccess, 139 halfError: halfError, 140 cipher: aeadCipher, 141 explicitNonceLen: explicitNonceLen, 142 rand: randReader, 143 halfPtr: rawHalfConn.UnsafeAddr(), 144 halfSeq: halfSeq, 145 halfScratchBuf: halfScratchBuf, 146 }, nil 147 } 148 149 func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { 150 if buffer.Len() > maxPlaintext { 151 defer buffer.Release() 152 return common.Error(c.Write(buffer.Bytes())) 153 } 154 for { 155 x := c.activeCall.Load() 156 if x&1 != 0 { 157 return net.ErrClosed 158 } 159 if c.activeCall.CompareAndSwap(x, x+2) { 160 break 161 } 162 } 163 defer c.activeCall.Add(-2) 164 c.halfAccess.Lock() 165 defer c.halfAccess.Unlock() 166 if err := *c.halfError; err != nil { 167 return err 168 } 169 if *c.closeNotifySent { 170 return errShutdown 171 } 172 dataLen := buffer.Len() 173 dataBytes := buffer.Bytes() 174 outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen) 175 outBuf[0] = 23 176 version := *c.version 177 if version == 0 { 178 version = tls.VersionTLS10 179 } else if version == tls.VersionTLS13 { 180 version = tls.VersionTLS12 181 } 182 binary.BigEndian.PutUint16(outBuf[1:], version) 183 var nonce []byte 184 if c.explicitNonceLen > 0 { 185 nonce = outBuf[5 : 5+c.explicitNonceLen] 186 if c.explicitNonceLen < 16 { 187 copy(nonce, c.halfSeq) 188 } else { 189 if _, err := io.ReadFull(c.rand, nonce); err != nil { 190 return err 191 } 192 } 193 } 194 if len(nonce) == 0 { 195 nonce = c.halfSeq 196 } 197 if *c.version == tls.VersionTLS13 { 198 buffer.FreeBytes()[0] = 23 199 binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead())) 200 c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen]) 201 buffer.Extend(1 + c.cipher.Overhead()) 202 } else { 203 binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen)) 204 additionalData := append(c.halfScratchBuf[:0], c.halfSeq...) 205 additionalData = append(additionalData, outBuf[:recordHeaderLen]...) 206 c.cipher.Seal(outBuf, nonce, dataBytes, additionalData) 207 buffer.Extend(c.cipher.Overhead()) 208 binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead())) 209 } 210 incSeq(c.halfPtr) 211 log.Trace("badtls write ", buffer.Len()) 212 return c.writer.WriteBuffer(buffer) 213 } 214 215 func (c *Conn) FrontHeadroom() int { 216 return recordHeaderLen + c.explicitNonceLen 217 } 218 219 func (c *Conn) RearHeadroom() int { 220 return 1 + c.cipher.Overhead() 221 } 222 223 func (c *Conn) WriterMTU() int { 224 return maxPlaintext 225 } 226 227 func (c *Conn) Upstream() any { 228 return c.Conn 229 } 230 231 func (c *Conn) UpstreamWriter() any { 232 return c.NetConn() 233 }