github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/handshake.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package client 16 17 import ( 18 "fmt" 19 20 "github.com/rs/zerolog/log" 21 22 "github.com/datastax/go-cassandra-native-protocol/frame" 23 "github.com/datastax/go-cassandra-native-protocol/message" 24 "github.com/datastax/go-cassandra-native-protocol/primitive" 25 ) 26 27 // PerformHandshake performs a handshake between the given client and server connections, using the provided protocol 28 // version. The handshake will use stream id 1, unless the client connection is in managed mode. 29 func PerformHandshake(clientConn *CqlClientConnection, serverConn *CqlServerConnection, version primitive.ProtocolVersion, streamId int16) error { 30 clientChan := make(chan error) 31 serverChan := make(chan error) 32 go func() { 33 clientChan <- clientConn.InitiateHandshake(version, streamId) 34 }() 35 go func() { 36 serverChan <- serverConn.AcceptHandshake() 37 }() 38 for clientChan != nil || serverChan != nil { 39 select { 40 case err := <-clientChan: 41 if err != nil { 42 return fmt.Errorf("client handshake failed: %w", err) 43 } 44 clientChan = nil 45 case err := <-serverChan: 46 if err != nil { 47 return fmt.Errorf("server handshake failed %w", err) 48 } 49 serverChan = nil 50 } 51 } 52 return nil 53 } 54 55 // InitiateHandshake initiates the handshake procedure to initialize the client connection, using the given protocol 56 // version. The handshake will use authentication if the connection was created with auth credentials; otherwise it will 57 // proceed without authentication. Use stream id zero to activate automatic stream id management. 58 func (c *CqlClientConnection) InitiateHandshake(version primitive.ProtocolVersion, streamId int16) (err error) { 59 log.Debug().Msgf("%v: performing handshake", c) 60 if startup, err := c.NewStartupRequest(version, streamId); err != nil { 61 return err 62 } else { 63 var response *frame.Frame 64 if response, err = c.SendAndReceive(startup); err == nil { 65 if c.credentials == nil { 66 if _, authSuccess := response.Body.Message.(*message.Ready); !authSuccess { 67 err = fmt.Errorf("expected READY, got %v", response.Body.Message) 68 } 69 } else { 70 switch msg := response.Body.Message.(type) { 71 case *message.Ready: 72 log.Warn().Msgf("%v: expected AUTHENTICATE, got READY – is authentication required?", c) 73 break 74 case *message.Authenticate: 75 authenticator := &PlainTextAuthenticator{c.credentials} 76 var initialResponse []byte 77 if initialResponse, err = authenticator.InitialResponse(msg.Authenticator); err == nil { 78 authResponse := frame.NewFrame(version, streamId, &message.AuthResponse{Token: initialResponse}) 79 if response, err = c.SendAndReceive(authResponse); err != nil { 80 err = fmt.Errorf("could not send AUTH RESPONSE: %w", err) 81 } else { 82 switch msg := response.Body.Message.(type) { 83 case *message.AuthSuccess: 84 break 85 case *message.AuthChallenge: 86 var challenge []byte 87 if challenge, err = authenticator.EvaluateChallenge(msg.Token); err == nil { 88 authResponse := frame.NewFrame(version, streamId, &message.AuthResponse{Token: challenge}) 89 if response, err = c.SendAndReceive(authResponse); err != nil { 90 err = fmt.Errorf("could not send AUTH RESPONSE: %w", err) 91 } else if _, authSuccess := response.Body.Message.(*message.AuthSuccess); !authSuccess { 92 err = fmt.Errorf("expected AUTH_SUCCESS, got %v", response.Body.Message) 93 } 94 } 95 default: 96 err = fmt.Errorf("expected AUTH_CHALLENGE or AUTH_SUCCESS, got %v", response.Body.Message) 97 } 98 } 99 } 100 default: 101 err = fmt.Errorf("expected AUTHENTICATE or READY, got %v", response.Body.Message) 102 } 103 } 104 } 105 if err == nil { 106 log.Info().Msgf("%v: handshake successful", c) 107 } else { 108 log.Error().Err(err).Msgf("%v: handshake failed", c) 109 } 110 return err 111 } 112 } 113 114 // AcceptHandshake Listens for a client STARTUP request and proceeds with the server-side handshake procedure. 115 // Authentication will be required if the connection was created with auth credentials; otherwise the handshake will 116 // proceed without authentication. 117 // This method is intended for use when server-side handshake should be triggered manually. For automatic server-side 118 // handshake, consider using HandshakeHandler instead. 119 func (c *CqlServerConnection) AcceptHandshake() (err error) { 120 log.Debug().Msgf("%v: performing handshake", c) 121 var request *frame.Frame 122 authSuccess := false 123 done := false 124 for !done && err == nil { 125 if request, err = c.Receive(); err == nil { 126 switch request.Body.Message.(type) { 127 case *message.Options: 128 supported := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Supported{}) 129 err = c.Send(supported) 130 continue 131 case *message.Startup: 132 if c.credentials == nil { 133 authSuccess = true 134 ready := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{}) 135 err = c.Send(ready) 136 } else { 137 authenticate := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Authenticate{Authenticator: "org.apache.cassandra.auth.PasswordAuthenticator"}) 138 if err = c.Send(authenticate); err == nil { 139 if request, err = c.Receive(); err == nil { 140 if authResponse, ok := request.Body.Message.(*message.AuthResponse); !ok { 141 err = fmt.Errorf("expected AUTH RESPONSE, got %v", request.Body.Message) 142 } else { 143 credentials := &AuthCredentials{} 144 if err = credentials.Unmarshal(authResponse.Token); err == nil { 145 if credentials.Username == c.credentials.Username && credentials.Password == c.credentials.Password { 146 authSuccess = true 147 authSuccess := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.AuthSuccess{}) 148 err = c.Send(authSuccess) 149 } else { 150 authError := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.AuthenticationError{ErrorMessage: "invalid credentials"}) 151 err = c.Send(authError) 152 } 153 } 154 } 155 } 156 } 157 } 158 done = true 159 default: 160 err = fmt.Errorf("expected STARTUP or OPTIONS, got %v", request.Body.Message) 161 done = true 162 } 163 } 164 } 165 if err == nil { 166 if authSuccess { 167 log.Info().Msgf("%v: handshake successful", c) 168 } else { 169 log.Error().Msgf("%v: authentication error: invalid credentials", c) 170 } 171 } else { 172 log.Error().Err(err).Msgf("%v: handshake failed", c) 173 } 174 return err 175 } 176 177 const ( 178 handshakeStateKey = "HANDSHAKE" 179 handshakeStateStarted = "STARTED" 180 handshakeStateDone = "DONE" 181 ) 182 183 // HandshakeHandler is a RequestHandler to handle server-side handshakes. This is an alternative to 184 // CqlServerConnection.AcceptHandshake to make the server connection automatically handle all incoming handshake 185 // attempts. 186 var HandshakeHandler RequestHandler = func(request *frame.Frame, conn *CqlServerConnection, ctx RequestHandlerContext) (response *frame.Frame) { 187 if ctx.GetAttribute(handshakeStateKey) == handshakeStateDone { 188 return 189 } 190 version := request.Header.Version 191 id := request.Header.StreamId 192 switch msg := request.Body.Message.(type) { 193 case *message.Options: 194 log.Debug().Msgf("%v: [handshake handler]: intercepted OPTIONS before STARTUP", conn) 195 response = frame.NewFrame(version, id, &message.Supported{}) 196 case *message.Startup: 197 if conn.Credentials() == nil { 198 ctx.PutAttribute(handshakeStateKey, handshakeStateDone) 199 log.Info().Msgf("%v: [handshake handler]: handshake successful", conn) 200 response = frame.NewFrame(version, id, &message.Ready{}) 201 } else { 202 ctx.PutAttribute(handshakeStateKey, handshakeStateStarted) 203 response = frame.NewFrame(version, id, &message.Authenticate{Authenticator: "org.apache.cassandra.auth.PasswordAuthenticator"}) 204 } 205 case *message.AuthResponse: 206 if ctx.GetAttribute(handshakeStateKey) == handshakeStateStarted { 207 userCredentials := &AuthCredentials{} 208 if err := userCredentials.Unmarshal(msg.Token); err == nil { 209 serverCredentials := conn.Credentials() 210 if userCredentials.Username == serverCredentials.Username && 211 userCredentials.Password == serverCredentials.Password { 212 log.Info().Msgf("%v: [handshake handler]: handshake successful", conn) 213 response = frame.NewFrame(version, id, &message.AuthSuccess{}) 214 } else { 215 log.Error().Msgf("%v: [handshake handler]: authentication error: invalid credentials", conn) 216 response = frame.NewFrame(version, id, &message.AuthenticationError{ErrorMessage: "invalid credentials"}) 217 } 218 ctx.PutAttribute(handshakeStateKey, handshakeStateDone) 219 } 220 } else { 221 ctx.PutAttribute(handshakeStateKey, handshakeStateDone) 222 log.Error().Msgf("%v: [handshake handler]: expected STARTUP, got AUTH_RESPONSE", conn) 223 response = frame.NewFrame(version, id, &message.ProtocolError{ErrorMessage: "handshake failed"}) 224 } 225 default: 226 ctx.PutAttribute(handshakeStateKey, handshakeStateDone) 227 log.Error().Msgf("%v: [handshake handler]: expected OPTIONS, STARTUP or AUTH_RESPONSE, got %v", conn, msg) 228 response = frame.NewFrame(version, id, &message.ProtocolError{ErrorMessage: "handshake failed"}) 229 } 230 return 231 } 232 233 func isReady(f *frame.Frame) bool { 234 _, ok := f.Body.Message.(*message.Ready) 235 return ok 236 } 237 238 func isAuthenticate(f *frame.Frame) bool { 239 _, ok := f.Body.Message.(*message.Authenticate) 240 return ok 241 }