github.com/fumiama/terasu@v0.0.0-20240507144117-547a591149c0/tls.go (about)

     1  //go:build go1.21
     2  
     3  package terasu
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"hash"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"unsafe"
    15  	_ "unsafe"
    16  )
    17  
    18  type recordType uint8
    19  
    20  const (
    21  	recordTypeChangeCipherSpec recordType = 20
    22  	recordTypeAlert            recordType = 21
    23  	recordTypeHandshake        recordType = 22
    24  	recordTypeApplicationData  recordType = 23
    25  )
    26  
    27  const (
    28  	recordHeaderLen = 5 // record header length
    29  )
    30  
    31  type alert uint8
    32  
    33  //go:linkname alertError tls.(tls.alert).Error
    34  func alertError(e alert) string
    35  
    36  func (e alert) Error() string {
    37  	return alertError(e)
    38  }
    39  
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
//
// It mirrors the unexported halfConn type of crypto/tls: *halfConn values are
// handed to the linknamed crypto/tls.(*halfConn).encrypt and
// crypto/tls.(*halfConn).changeCipherSpec, whose compiled code accesses the
// fields at the offsets of the standard library's definition. Do not reorder,
// add, remove or retype fields.
// NOTE(review): the layout presumably matches crypto/tls as of Go 1.21 (see
// the go1.21 build constraint); verify it against every Go version this
// package is built with.
type halfConn struct {
	sync.Mutex

	err     error  // first permanent error
	version uint16 // protocol version
	cipher  any    // cipher algorithm
	mac     hash.Hash
	seq     [8]byte // 64-bit sequence number

	scratchBuf [13]byte // to avoid allocs; interface method args escape

	nextCipher any       // next encryption state
	nextMac    hash.Hash // next MAC algorithm

	level         tls.QUICEncryptionLevel // current QUIC encryption level
	trafficSecret []byte                  // current TLS 1.3 traffic secret
}
    59  
// Conn is a defined type whose underlying type is that of crypto/tls.Conn, so
// a *tls.Conn can be converted to *Conn and back. The methods of tls.Conn are
// not part of Conn's method set.
type Conn tls.Conn
    61  
// A _trsconn re-declares the leading fields of crypto/tls.Conn so that this
// package can read them directly (quic, vers, config, out) and can pass a
// *_trsconn as the receiver of the linknamed crypto/tls.(*Conn) functions
// declared in this file.
//
// The field order, types and sizes must match crypto/tls.Conn exactly.
// NOTE(review): the placeholder types (*uintptr, []*uintptr, []byte)
// presumably stand in for unexported crypto/tls types of the same size;
// confirm against each targeted Go version. crypto/tls.Conn has further
// fields after in/out that are not declared here and may be used by the
// linknamed functions, so a *_trsconn should only ever point at a real
// crypto/tls.Conn value (never a _trsconn allocated on its own); verify
// this at the conversion sites.
type _trsconn struct {
	// constant
	conn        net.Conn
	isClient    bool
	handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
	quic        *uintptr                    // nil for non-QUIC connections

	// isHandshakeComplete is true if the connection is currently transferring
	// application data (i.e. is not currently processing a handshake).
	// isHandshakeComplete is true implies handshakeErr == nil.
	isHandshakeComplete atomic.Bool
	// constant after handshake; protected by handshakeMutex
	handshakeMutex sync.Mutex
	handshakeErr   error       // error resulting from handshake
	vers           uint16      // TLS version
	haveVers       bool        // version has been negotiated
	config         *tls.Config // configuration passed to constructor
	// handshakes counts the number of handshakes performed on the
	// connection so far. If renegotiation is disabled then this is either
	// zero or one.
	handshakes       int
	extMasterSecret  bool
	didResume        bool // whether this connection was a session resumption
	cipherSuite      uint16
	ocspResponse     []byte   // stapled OCSP response
	scts             [][]byte // signed certificate timestamps from server
	peerCertificates []*x509.Certificate
	// activeCertHandles contains the cache handles to certificates in
	// peerCertificates that are used to track active references.
	activeCertHandles []*uintptr
	// verifiedChains contains the certificate chains that we built, as
	// opposed to the ones presented by the server.
	verifiedChains [][]*x509.Certificate
	// serverName contains the server name indicated by the client, if any.
	serverName string
	// secureRenegotiation is true if the server echoed the secure
	// renegotiation extension. (This is meaningless as a server because
	// renegotiation is not supported in that case.)
	secureRenegotiation bool
	// ekm is a closure for exporting keying material.
	ekm func(label string, context []byte, length int) ([]byte, error)
	// resumptionSecret is the resumption_master_secret for handling
	// or sending NewSessionTicket messages.
	resumptionSecret []byte

	// ticketKeys is the set of active session ticket keys for this
	// connection. The first one is used to encrypt new tickets and
	// all are tried to decrypt tickets.
	ticketKeys []byte

	// clientFinishedIsFirst is true if the client sent the first Finished
	// message during the most recent handshake. This is recorded because
	// the first transmitted Finished message is the tls-unique
	// channel-binding value.
	clientFinishedIsFirst bool

	// closeNotifyErr is any error from sending the alertCloseNotify record.
	closeNotifyErr error
	// closeNotifySent is true if the Conn attempted to send an
	// alertCloseNotify record.
	closeNotifySent bool

	// clientFinished and serverFinished contain the Finished message sent
	// by the client or server in the most recent handshake. This is
	// retained to support the renegotiation extension and tls-unique
	// channel-binding.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is the negotiated ALPN protocol.
	clientProtocol string

	// input/output (the last fields mirrored from crypto/tls.Conn)
	in, out halfConn
}
   139  
// outBufPool is crypto/tls's pool of *[]byte record-output buffers, shared via
// linkname so that writeRecordLocked reuses the standard library's buffers.
//
//go:linkname outBufPool crypto/tls.outBufPool
var outBufPool sync.Pool
   142  
// tlsWriteRecordLocked is the unmodified crypto/tls.(*Conn).writeRecordLocked,
// with the receiver passed as the first argument. writeRecordLocked falls back
// to it for QUIC connections.
//
//go:linkname tlsWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked
func tlsWriteRecordLocked(c *_trsconn, typ recordType, data []byte) (int, error)
   145  
   146  //go:linkname maxPayloadSizeForWrite crypto/tls.(*Conn).maxPayloadSizeForWrite
   147  func maxPayloadSizeForWrite(c *_trsconn, typ recordType) int
   148  
   149  func (c *_trsconn) maxPayloadSizeForWrite(typ recordType) int {
   150  	return maxPayloadSizeForWrite(c, typ)
   151  }
   152  
// sliceForAppend is crypto/tls.sliceForAppend. writeRecordLocked uses its
// tail result as a recordHeaderLen-byte header buffer. NOTE(review): head is
// presumably in extended by n bytes and tail its last n bytes; confirm
// against crypto/tls.
//
//go:linkname sliceForAppend crypto/tls.sliceForAppend
func sliceForAppend(in []byte, n int) (head, tail []byte)
   155  
   156  //go:linkname encrypt crypto/tls.(*halfConn).encrypt
   157  func encrypt(hc *halfConn, record, payload []byte, rand io.Reader) ([]byte, error)
   158  
   159  func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
   160  	return encrypt(hc, record, payload, rand)
   161  }
   162  
// rand is crypto/tls.(*Config).rand, the randomness source the standard
// library uses for config c; writeRecordLocked passes it to encrypt.
// NOTE(review): in crypto/tls this presumably returns c.Rand, or
// crypto/rand.Reader when c.Rand is nil; confirm for the targeted Go version.
//
//go:linkname rand crypto/tls.(*Config).rand
func rand(c *tls.Config) io.Reader
   165  
   166  //go:linkname write crypto/tls.(*Conn).write
   167  func write(c *_trsconn, data []byte) (int, error)
   168  
   169  func (c *_trsconn) write(data []byte) (int, error) {
   170  	return write(c, data)
   171  }
   172  
   173  //go:linkname flush crypto/tls.(*Conn).flush
   174  func flush(c *_trsconn) (int, error)
   175  
   176  func (c *_trsconn) flush() (int, error) {
   177  	return flush(c)
   178  }
   179  
   180  //go:linkname changeCipherSpec crypto/tls.(*halfConn).changeCipherSpec
   181  func changeCipherSpec(hc *halfConn) error
   182  
   183  func (hc *halfConn) changeCipherSpec() error {
   184  	return changeCipherSpec(hc)
   185  }
   186  
   187  //go:linkname sendAlertLocked crypto/tls.(*Conn).sendAlertLocked
   188  func sendAlertLocked(c *_trsconn, err alert) error
   189  
   190  func (c *_trsconn) sendAlertLocked(err alert) error {
   191  	return sendAlertLocked(c, err)
   192  }
   193  
   194  // writeRecordLocked writes a TLS record with the given type and payload to the
   195  // connection and updates the record layer state.
   196  func (c *_trsconn) writeRecordLocked(typ recordType, firstFragmentLen uint8, data []byte) (int, error) {
   197  	if c.quic != nil {
   198  		return tlsWriteRecordLocked(c, typ, data)
   199  	}
   200  
   201  	outBufPtr := outBufPool.Get().(*[]byte)
   202  	outBuf := *outBufPtr
   203  	defer func() {
   204  		// You might be tempted to simplify this by just passing &outBuf to Put,
   205  		// but that would make the local copy of the outBuf slice header escape
   206  		// to the heap, causing an allocation. Instead, we keep around the
   207  		// pointer to the slice header returned by Get, which is already on the
   208  		// heap, and overwrite and return that.
   209  		*outBufPtr = outBuf
   210  		outBufPool.Put(outBufPtr)
   211  	}()
   212  
   213  	var n int
   214  	isFirstLoop := true
   215  	for len(data) > 0 {
   216  		m := len(data)
   217  		if !isFirstLoop {
   218  			if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
   219  				m = maxPayload
   220  			}
   221  		} else {
   222  			m = int(firstFragmentLen)
   223  		}
   224  
   225  		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
   226  		outBuf[0] = byte(typ)
   227  		vers := c.vers
   228  		if vers == 0 {
   229  			// Some TLS servers fail if the record version is
   230  			// greater than TLS 1.0 for the initial ClientHello.
   231  			vers = tls.VersionTLS10
   232  		} else if vers == tls.VersionTLS13 {
   233  			// TLS 1.3 froze the record layer version to 1.2.
   234  			// See RFC 8446, Section 5.1.
   235  			vers = tls.VersionTLS12
   236  		}
   237  		outBuf[1] = byte(vers >> 8)
   238  		outBuf[2] = byte(vers)
   239  		outBuf[3] = byte(m >> 8)
   240  		outBuf[4] = byte(m)
   241  
   242  		var err error
   243  		outBuf, err = c.out.encrypt(outBuf, data[:m], rand(c.config))
   244  		if err != nil {
   245  			return n, err
   246  		}
   247  		if _, err := c.write(outBuf); err != nil {
   248  			return n, err
   249  		}
   250  		n += m
   251  		data = data[m:]
   252  		if isFirstLoop {
   253  			isFirstLoop = false
   254  			if _, err := c.flush(); err != nil {
   255  				return n, err
   256  			}
   257  		}
   258  	}
   259  
   260  	if typ == recordTypeChangeCipherSpec && c.vers != tls.VersionTLS13 {
   261  		if err := c.out.changeCipherSpec(); err != nil {
   262  			return n, c.sendAlertLocked(alert(
   263  				*(*uintptr)(
   264  					unsafe.Add(unsafe.Pointer(&err), unsafe.Sizeof(uintptr(0))),
   265  				),
   266  			))
   267  		}
   268  	}
   269  
   270  	return n, nil
   271  }