github.com/pion/dtls/v2@v2.2.12/flight1handler.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"context"
     8  
     9  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    10  	"github.com/pion/dtls/v2/pkg/protocol"
    11  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    12  	"github.com/pion/dtls/v2/pkg/protocol/extension"
    13  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    14  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    15  )
    16  
    17  func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
    18  	// HelloVerifyRequest can be skipped by the server,
    19  	// so allow ServerHello during flight1 also
    20  	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
    21  		handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
    22  		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true},
    23  	)
    24  	if !ok {
    25  		// No valid message received. Keep reading
    26  		return 0, nil, nil
    27  	}
    28  
    29  	if _, ok := msgs[handshake.TypeServerHello]; ok {
    30  		// Flight1 and flight2 were skipped.
    31  		// Parse as flight3.
    32  		return flight3Parse(ctx, c, state, cache, cfg)
    33  	}
    34  
    35  	if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok {
    36  		// DTLS 1.2 clients must not assume that the server will use the protocol version
    37  		// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
    38  		if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
    39  			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
    40  		}
    41  		state.cookie = append([]byte{}, h.Cookie...)
    42  		state.handshakeRecvSequence = seq
    43  		return flight3, nil, nil
    44  	}
    45  
    46  	return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
    47  }
    48  
    49  func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
    50  	var zeroEpoch uint16
    51  	state.localEpoch.Store(zeroEpoch)
    52  	state.remoteEpoch.Store(zeroEpoch)
    53  	state.namedCurve = defaultNamedCurve
    54  	state.cookie = nil
    55  
    56  	if err := state.localRandom.Populate(); err != nil {
    57  		return nil, nil, err
    58  	}
    59  
    60  	extensions := []extension.Extension{
    61  		&extension.SupportedSignatureAlgorithms{
    62  			SignatureHashAlgorithms: cfg.localSignatureSchemes,
    63  		},
    64  		&extension.RenegotiationInfo{
    65  			RenegotiatedConnection: 0,
    66  		},
    67  	}
    68  
    69  	var setEllipticCurveCryptographyClientHelloExtensions bool
    70  	for _, c := range cfg.localCipherSuites {
    71  		if c.ECC() {
    72  			setEllipticCurveCryptographyClientHelloExtensions = true
    73  			break
    74  		}
    75  	}
    76  
    77  	if setEllipticCurveCryptographyClientHelloExtensions {
    78  		extensions = append(extensions, []extension.Extension{
    79  			&extension.SupportedEllipticCurves{
    80  				EllipticCurves: cfg.ellipticCurves,
    81  			},
    82  			&extension.SupportedPointFormats{
    83  				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
    84  			},
    85  		}...)
    86  	}
    87  
    88  	if len(cfg.localSRTPProtectionProfiles) > 0 {
    89  		extensions = append(extensions, &extension.UseSRTP{
    90  			ProtectionProfiles: cfg.localSRTPProtectionProfiles,
    91  		})
    92  	}
    93  
    94  	if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
    95  		cfg.extendedMasterSecret == RequireExtendedMasterSecret {
    96  		extensions = append(extensions, &extension.UseExtendedMasterSecret{
    97  			Supported: true,
    98  		})
    99  	}
   100  
   101  	if len(cfg.serverName) > 0 {
   102  		extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
   103  	}
   104  
   105  	if len(cfg.supportedProtocols) > 0 {
   106  		extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
   107  	}
   108  
   109  	if cfg.sessionStore != nil {
   110  		cfg.log.Tracef("[handshake] try to resume session")
   111  		if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil {
   112  			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
   113  		} else if s.ID != nil {
   114  			cfg.log.Tracef("[handshake] get saved session: %x", s.ID)
   115  
   116  			state.SessionID = s.ID
   117  			state.masterSecret = s.Secret
   118  		}
   119  	}
   120  
   121  	return []*packet{
   122  		{
   123  			record: &recordlayer.RecordLayer{
   124  				Header: recordlayer.Header{
   125  					Version: protocol.Version1_2,
   126  				},
   127  				Content: &handshake.Handshake{
   128  					Message: &handshake.MessageClientHello{
   129  						Version:            protocol.Version1_2,
   130  						SessionID:          state.SessionID,
   131  						Cookie:             state.cookie,
   132  						Random:             state.localRandom,
   133  						CipherSuiteIDs:     cipherSuiteIDs(cfg.localCipherSuites),
   134  						CompressionMethods: defaultCompressionMethods(),
   135  						Extensions:         extensions,
   136  					},
   137  				},
   138  			},
   139  		},
   140  	}, nil, nil
   141  }