github.com/pion/dtls/v2@v2.2.12/handshaker.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "context" 8 "crypto/tls" 9 "crypto/x509" 10 "fmt" 11 "io" 12 "sync" 13 "time" 14 15 "github.com/pion/dtls/v2/pkg/crypto/elliptic" 16 "github.com/pion/dtls/v2/pkg/crypto/signaturehash" 17 "github.com/pion/dtls/v2/pkg/protocol/alert" 18 "github.com/pion/dtls/v2/pkg/protocol/handshake" 19 "github.com/pion/logging" 20 ) 21 22 // [RFC6347 Section-4.2.4] 23 // +-----------+ 24 // +---> | PREPARING | <--------------------+ 25 // | +-----------+ | 26 // | | | 27 // | | Buffer next flight | 28 // | | | 29 // | \|/ | 30 // | +-----------+ | 31 // | | SENDING |<------------------+ | Send 32 // | +-----------+ | | HelloRequest 33 // Receive | | | | 34 // next | | Send flight | | or 35 // flight | +--------+ | | 36 // | | | Set retransmit timer | | Receive 37 // | | \|/ | | HelloRequest 38 // | | +-----------+ | | Send 39 // +--)--| WAITING |-------------------+ | ClientHello 40 // | | +-----------+ Timer expires | | 41 // | | | | | 42 // | | +------------------------+ | 43 // Receive | | Send Read retransmit | 44 // last | | last | 45 // flight | | flight | 46 // | | | 47 // \|/\|/ | 48 // +-----------+ | 49 // | FINISHED | -------------------------------+ 50 // +-----------+ 51 // | /|\ 52 // | | 53 // +---+ 54 // Read retransmit 55 // Retransmit last flight 56 57 type handshakeState uint8 58 59 const ( 60 handshakeErrored handshakeState = iota 61 handshakePreparing 62 handshakeSending 63 handshakeWaiting 64 handshakeFinished 65 ) 66 67 func (s handshakeState) String() string { 68 switch s { 69 case handshakeErrored: 70 return "Errored" 71 case handshakePreparing: 72 return "Preparing" 73 case handshakeSending: 74 return "Sending" 75 case handshakeWaiting: 76 return "Waiting" 77 case handshakeFinished: 78 return "Finished" 79 default: 80 return "Unknown" 81 } 82 } 83 84 type handshakeFSM struct { 85 currentFlight flightVal 86 flights []*packet 87 retransmit bool 88 state *State 89 cache *handshakeCache 90 cfg *handshakeConfig 91 closed chan struct{} 92 } 93 94 type handshakeConfig struct { 95 localPSKCallback PSKCallback 96 localPSKIdentityHint []byte 97 localCipherSuites []CipherSuite // Available CipherSuites 98 localSignatureSchemes []signaturehash.Algorithm // Available signature schemes 99 extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension 100 localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support 101 serverName string 102 supportedProtocols []string 103 clientAuth ClientAuthType // If we are a client should we request a client certificate 104 localCertificates []tls.Certificate 105 nameToCertificate map[string]*tls.Certificate 106 insecureSkipVerify bool 107 verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error 108 verifyConnection func(*State) error 109 sessionStore SessionStore 110 rootCAs *x509.CertPool 111 clientCAs *x509.CertPool 112 retransmitInterval time.Duration 113 customCipherSuites func() []CipherSuite 114 ellipticCurves []elliptic.Curve 115 insecureSkipHelloVerify bool 116 117 onFlightState func(flightVal, handshakeState) 118 log logging.LeveledLogger 119 keyLogWriter io.Writer 120 121 localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) 122 localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) 123 124 initialEpoch uint16 125 126 mu sync.Mutex 127 } 128 129 type flightConn interface { 130 notify(ctx context.Context, level alert.Level, desc alert.Description) error 131 writePackets(context.Context, []*packet) error 132 recvHandshake() <-chan chan struct{} 133 setLocalEpoch(epoch uint16) 134 handleQueuedPackets(context.Context) error 135 sessionKey() []byte 136 } 137 138 func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) { 139 if c.keyLogWriter == nil { 140 return 141 } 142 c.mu.Lock() 143 defer c.mu.Unlock() 144 _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret))) 145 if err != nil { 146 c.log.Debugf("failed to write key log file: %s", err) 147 } 148 } 149 150 func srvCliStr(isClient bool) string { 151 if isClient { 152 return "client" 153 } 154 return "server" 155 } 156 157 func newHandshakeFSM( 158 s *State, cache *handshakeCache, cfg *handshakeConfig, 159 initialFlight flightVal, 160 ) *handshakeFSM { 161 return &handshakeFSM{ 162 currentFlight: initialFlight, 163 state: s, 164 cache: cache, 165 cfg: cfg, 166 closed: make(chan struct{}), 167 } 168 } 169 170 func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error { 171 state := initialState 172 defer func() { 173 close(s.closed) 174 }() 175 for { 176 s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) 177 if s.cfg.onFlightState != nil { 178 s.cfg.onFlightState(s.currentFlight, state) 179 } 180 var err error 181 switch state { 182 case handshakePreparing: 183 state, err = s.prepare(ctx, c) 184 case handshakeSending: 185 state, err = s.send(ctx, c) 186 case handshakeWaiting: 187 state, err = s.wait(ctx, c) 188 case handshakeFinished: 189 state, err = s.finish(ctx, c) 190 default: 191 return errInvalidFSMTransition 192 } 193 if err != nil { 194 return err 195 } 196 } 197 } 198 199 func (s *handshakeFSM) Done() <-chan struct{} { 200 return s.closed 201 } 202 203 func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) { 204 s.flights = nil 205 // Prepare flights 206 var ( 207 a *alert.Alert 208 err error 209 pkts []*packet 210 ) 211 gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() 212 if errFlight != nil { 213 err = errFlight 214 a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} 215 } else { 216 pkts, a, err = gen(c, s.state, s.cache, s.cfg) 217 s.retransmit = retransmit 218 } 219 if a != nil { 220 if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil { 221 if err != nil { 222 err = alertErr 223 } 224 } 225 } 226 if err != nil { 227 return handshakeErrored, err 228 } 229 230 s.flights = pkts 231 epoch := s.cfg.initialEpoch 232 nextEpoch := epoch 233 for _, p := range s.flights { 234 p.record.Header.Epoch += epoch 235 if p.record.Header.Epoch > nextEpoch { 236 nextEpoch = p.record.Header.Epoch 237 } 238 if h, ok := p.record.Content.(*handshake.Handshake); ok { 239 h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) 240 s.state.handshakeSendSequence++ 241 } 242 } 243 if epoch != nextEpoch { 244 s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) 245 c.setLocalEpoch(nextEpoch) 246 } 247 return handshakeSending, nil 248 } 249 250 func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { 251 // Send flights 252 if err := c.writePackets(ctx, s.flights); err != nil { 253 return handshakeErrored, err 254 } 255 256 if s.currentFlight.isLastSendFlight() { 257 return handshakeFinished, nil 258 } 259 return handshakeWaiting, nil 260 } 261 262 func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit 263 parse, errFlight := s.currentFlight.getFlightParser() 264 if errFlight != nil { 265 if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { 266 if errFlight != nil { 267 return handshakeErrored, alertErr 268 } 269 } 270 return handshakeErrored, errFlight 271 } 272 273 retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) 274 for { 275 select { 276 case done := <-c.recvHandshake(): 277 nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) 278 close(done) 279 if alert != nil { 280 if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { 281 if err != nil { 282 err = alertErr 283 } 284 } 285 } 286 if err != nil { 287 return handshakeErrored, err 288 } 289 if nextFlight == 0 { 290 break 291 } 292 s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String()) 293 if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { 294 return handshakeFinished, nil 295 } 296 s.currentFlight = nextFlight 297 return handshakePreparing, nil 298 299 case <-retransmitTimer.C: 300 if !s.retransmit { 301 return handshakeWaiting, nil 302 } 303 return handshakeSending, nil 304 case <-ctx.Done(): 305 return handshakeErrored, ctx.Err() 306 } 307 } 308 } 309 310 func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { 311 parse, errFlight := s.currentFlight.getFlightParser() 312 if errFlight != nil { 313 if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { 314 if errFlight != nil { 315 return handshakeErrored, alertErr 316 } 317 } 318 return handshakeErrored, errFlight 319 } 320 321 retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) 322 select { 323 case done := <-c.recvHandshake(): 324 nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) 325 close(done) 326 if alert != nil { 327 if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { 328 if err != nil { 329 err = alertErr 330 } 331 } 332 } 333 if err != nil { 334 return handshakeErrored, err 335 } 336 if nextFlight == 0 { 337 break 338 } 339 if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { 340 return handshakeFinished, nil 341 } 342 <-retransmitTimer.C 343 // Retransmit last flight 344 return handshakeSending, nil 345 346 case <-ctx.Done(): 347 return handshakeErrored, ctx.Err() 348 } 349 return handshakeFinished, nil 350 }