github.com/fumiama/terasu@v0.0.0-20240507144117-547a591149c0/tls.go (about) 1 //go:build go1.21 2 3 package terasu 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "hash" 10 "io" 11 "net" 12 "sync" 13 "sync/atomic" 14 "unsafe" 15 _ "unsafe" 16 ) 17 18 type recordType uint8 19 20 const ( 21 recordTypeChangeCipherSpec recordType = 20 22 recordTypeAlert recordType = 21 23 recordTypeHandshake recordType = 22 24 recordTypeApplicationData recordType = 23 25 ) 26 27 const ( 28 recordHeaderLen = 5 // record header length 29 ) 30 31 type alert uint8 32 33 //go:linkname alertError tls.(tls.alert).Error 34 func alertError(e alert) string 35 36 func (e alert) Error() string { 37 return alertError(e) 38 } 39 40 // A halfConn represents one direction of the record layer 41 // connection, either sending or receiving. 42 type halfConn struct { 43 sync.Mutex 44 45 err error // first permanent error 46 version uint16 // protocol version 47 cipher any // cipher algorithm 48 mac hash.Hash 49 seq [8]byte // 64-bit sequence number 50 51 scratchBuf [13]byte // to avoid allocs; interface method args escape 52 53 nextCipher any // next encryption state 54 nextMac hash.Hash // next MAC algorithm 55 56 level tls.QUICEncryptionLevel // current QUIC encryption level 57 trafficSecret []byte // current TLS 1.3 traffic secret 58 } 59 60 type Conn tls.Conn 61 62 // A _trsconn represents a secured connection. 63 // It implements the net._trsconn interface. 64 type _trsconn struct { 65 // constant 66 conn net.Conn 67 isClient bool 68 handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake 69 quic *uintptr // nil for non-QUIC connections 70 71 // isHandshakeComplete is true if the connection is currently transferring 72 // application data (i.e. is not currently processing a handshake). 73 // isHandshakeComplete is true implies handshakeErr == nil. 74 isHandshakeComplete atomic.Bool 75 // constant after handshake; protected by handshakeMutex 76 handshakeMutex sync.Mutex 77 handshakeErr error // error resulting from handshake 78 vers uint16 // TLS version 79 haveVers bool // version has been negotiated 80 config *tls.Config // configuration passed to constructor 81 // handshakes counts the number of handshakes performed on the 82 // connection so far. If renegotiation is disabled then this is either 83 // zero or one. 84 handshakes int 85 extMasterSecret bool 86 didResume bool // whether this connection was a session resumption 87 cipherSuite uint16 88 ocspResponse []byte // stapled OCSP response 89 scts [][]byte // signed certificate timestamps from server 90 peerCertificates []*x509.Certificate 91 // activeCertHandles contains the cache handles to certificates in 92 // peerCertificates that are used to track active references. 93 activeCertHandles []*uintptr 94 // verifiedChains contains the certificate chains that we built, as 95 // opposed to the ones presented by the server. 96 verifiedChains [][]*x509.Certificate 97 // serverName contains the server name indicated by the client, if any. 98 serverName string 99 // secureRenegotiation is true if the server echoed the secure 100 // renegotiation extension. (This is meaningless as a server because 101 // renegotiation is not supported in that case.) 102 secureRenegotiation bool 103 // ekm is a closure for exporting keying material. 104 ekm func(label string, context []byte, length int) ([]byte, error) 105 // resumptionSecret is the resumption_master_secret for handling 106 // or sending NewSessionTicket messages. 107 resumptionSecret []byte 108 109 // ticketKeys is the set of active session ticket keys for this 110 // connection. The first one is used to encrypt new tickets and 111 // all are tried to decrypt tickets. 112 ticketKeys []byte 113 114 // clientFinishedIsFirst is true if the client sent the first Finished 115 // message during the most recent handshake. This is recorded because 116 // the first transmitted Finished message is the tls-unique 117 // channel-binding value. 118 clientFinishedIsFirst bool 119 120 // closeNotifyErr is any error from sending the alertCloseNotify record. 121 closeNotifyErr error 122 // closeNotifySent is true if the Conn attempted to send an 123 // alertCloseNotify record. 124 closeNotifySent bool 125 126 // clientFinished and serverFinished contain the Finished message sent 127 // by the client or server in the most recent handshake. This is 128 // retained to support the renegotiation extension and tls-unique 129 // channel-binding. 130 clientFinished [12]byte 131 serverFinished [12]byte 132 133 // clientProtocol is the negotiated ALPN protocol. 134 clientProtocol string 135 136 // input/output 137 in, out halfConn 138 } 139 140 //go:linkname outBufPool crypto/tls.outBufPool 141 var outBufPool sync.Pool 142 143 //go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked 144 func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error) 145 146 //go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite 147 func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int 148 149 func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int { 150 return maxPayloadSizeForWrite(c, typ) 151 } 152 153 //go:linkname sliceForAppend crypto/tls.sliceForAppend 154 func sliceForAppend(in []byte, n int) (head, tail []byte) 155 156 //go:linkname encrypt crypto/tls.(*halfConn).encrypt 157 func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error) 158 159 func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { 160 return encrypt(hc, record, payload, rand) 161 } 162 163 //go:linkname rand crypto/tls.(*Config).rand 164 func rand(c *tls.Config) io.Reader 165 166 //go:linkname write crypto/tls.(*Conn).write 167 func write(c *_trsconn, data []byte) (int, error) 168 169 func (c *_trsconn) write(data []byte) (int, error) { 170 return write(c, data) 171 } 172 173 //go:linkname flush crypto/tls.(*Conn).flush 174 func flush(c *_trsconn) (int, error) 175 176 func (c *_trsconn) flush() (int, error) { 177 return flush(c) 178 } 179 180 //go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec 181 func changeCipherSpec(hc *halfConn) error 182 183 func (hc *halfConn) changeCipherSpec() error { 184 return changeCipherSpec(hc) 185 } 186 187 //go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked 188 func sendAlertLocked(c *_trsconn, err alert) error 189 190 func (c *_trsconn) sendAlertLocked(err alert) error { 191 return sendAlertLocked(c, err) 192 } 193 194 // writeRecordLocked writes a TLS record with the given type and payload to the 195 // connection and updates the record layer state. 196 func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) { 197 if c.quic != nil { 198 return tlsWriteRecordLocked(c, typ, data) 199 } 200 201 outBufPtr := outBufPool.Get().(*[]byte) 202 outBuf := *outBufPtr 203 defer func() { 204 // You might be tempted to simplify this by just passing &outBuf to Put, 205 // but that would make the local copy of the outBuf slice header escape 206 // to the heap, causing an allocation. Instead, we keep around the 207 // pointer to the slice header returned by Get, which is already on the 208 // heap, and overwrite and return that. 209 *outBufPtr = outBuf 210 outBufPool.Put(outBufPtr) 211 }() 212 213 var n int 214 isFirstLoop := true 215 for len(data) > 0 { 216 m := len(data) 217 if !isFirstLoop { 218 if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { 219 m = maxPayload 220 } 221 } else { 222 m = int(firstFragmentLen) 223 } 224 225 _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) 226 outBuf[0] = byte(typ) 227 vers := c.vers 228 if vers == 0 { 229 // Some TLS servers fail if the record version is 230 // greater than TLS 1.0 for the initial ClientHello. 231 vers = tls.VersionTLS10 232 } else if vers == tls.VersionTLS13 { 233 // TLS 1.3 froze the record layer version to 1.2. 234 // See RFC 8446, Section 5.1. 235 vers = tls.VersionTLS12 236 } 237 outBuf[1] = byte(vers >> 8) 238 outBuf[2] = byte(vers) 239 outBuf[3] = byte(m >> 8) 240 outBuf[4] = byte(m) 241 242 var err error 243 outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config)) 244 if err != nil { 245 return n, err 246 } 247 if _, err := c.write(outBuf); err != nil { 248 return n, err 249 } 250 n += m 251 data = data[m:] 252 if isFirstLoop { 253 isFirstLoop = false 254 if _, err := c.flush(); err != nil { 255 return n, err 256 } 257 } 258 } 259 260 if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 { 261 if err := c.out.changeCipherSpec(); err != nil { 262 return n, c.sendAlertLocked(alert( 263 *(*uintptr)( 264 unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))), 265 ), 266 )) 267 } 268 } 269 270 return n, nil 271 }