github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/internal/qtls/qtls.go (about)

     1  package qtls
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"fmt"
     7  
     8  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
     9  )
    10  
    11  func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
    12  	conf := qconf.TLSConfig
    13  
    14  	// Workaround for https://github.com/golang/go/issues/60506.
    15  	// This initializes the session tickets _before_ cloning the config.
    16  	_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
    17  
    18  	conf = conf.Clone()
    19  	conf.MinVersion = tls.VersionTLS13
    20  	qconf.TLSConfig = conf
    21  
    22  	// add callbacks to save transport parameters into the session ticket
    23  	origWrapSession := conf.WrapSession
    24  	conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
    25  		// Add QUIC session ticket
    26  		state.Extra = append(state.Extra, addExtraPrefix(getData()))
    27  
    28  		if origWrapSession != nil {
    29  			return origWrapSession(cs, state)
    30  		}
    31  		b, err := conf.EncryptTicket(cs, state)
    32  		return b, err
    33  	}
    34  	origUnwrapSession := conf.UnwrapSession
    35  	// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
    36  	// However, using 0-RTT is only possible with the first session ticket.
    37  	// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
    38  	var unwrapCount int
    39  	conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
    40  		unwrapCount++
    41  		var state *tls.SessionState
    42  		var err error
    43  		if origUnwrapSession != nil {
    44  			state, err = origUnwrapSession(identity, connState)
    45  		} else {
    46  			state, err = conf.DecryptTicket(identity, connState)
    47  		}
    48  		if err != nil || state == nil {
    49  			return nil, err
    50  		}
    51  
    52  		extra := findExtraData(state.Extra)
    53  		if extra != nil {
    54  			state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
    55  		} else {
    56  			state.EarlyData = false
    57  		}
    58  
    59  		return state, nil
    60  	}
    61  }
    62  
    63  func SetupConfigForClient(
    64  	qconf *tls.QUICConfig,
    65  	getData func(earlyData bool) []byte,
    66  	setData func(data []byte, earlyData bool) (allowEarlyData bool),
    67  ) {
    68  	conf := qconf.TLSConfig
    69  	if conf.ClientSessionCache != nil {
    70  		origCache := conf.ClientSessionCache
    71  		conf.ClientSessionCache = &clientSessionCache{
    72  			wrapped: origCache,
    73  			getData: getData,
    74  			setData: setData,
    75  		}
    76  	}
    77  }
    78  
    79  func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
    80  	switch e {
    81  	case protocol.EncryptionInitial:
    82  		return tls.QUICEncryptionLevelInitial
    83  	case protocol.EncryptionHandshake:
    84  		return tls.QUICEncryptionLevelHandshake
    85  	case protocol.Encryption1RTT:
    86  		return tls.QUICEncryptionLevelApplication
    87  	case protocol.Encryption0RTT:
    88  		return tls.QUICEncryptionLevelEarly
    89  	default:
    90  		panic(fmt.Sprintf("unexpected encryption level: %s", e))
    91  	}
    92  }
    93  
    94  func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
    95  	switch e {
    96  	case tls.QUICEncryptionLevelInitial:
    97  		return protocol.EncryptionInitial
    98  	case tls.QUICEncryptionLevelHandshake:
    99  		return protocol.EncryptionHandshake
   100  	case tls.QUICEncryptionLevelApplication:
   101  		return protocol.Encryption1RTT
   102  	case tls.QUICEncryptionLevelEarly:
   103  		return protocol.Encryption0RTT
   104  	default:
   105  		panic(fmt.Sprintf("unexpect encryption level: %s", e))
   106  	}
   107  }
   108  
   109  const extraPrefix = "quic-go1"
   110  
   111  func addExtraPrefix(b []byte) []byte {
   112  	return append([]byte(extraPrefix), b...)
   113  }
   114  
   115  func findExtraData(extras [][]byte) []byte {
   116  	prefix := []byte(extraPrefix)
   117  	for _, extra := range extras {
   118  		if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
   119  			continue
   120  		}
   121  		return extra[len(prefix):]
   122  	}
   123  	return nil
   124  }