github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/handshake.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/rs/zerolog/log"
    21  
    22  	"github.com/datastax/go-cassandra-native-protocol/frame"
    23  	"github.com/datastax/go-cassandra-native-protocol/message"
    24  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    25  )
    26  
    27  // PerformHandshake performs a handshake between the given client and server connections, using the provided protocol
    28  // version. The handshake will use stream id 1, unless the client connection is in managed mode.
    29  func PerformHandshake(clientConn *CqlClientConnection, serverConn *CqlServerConnection, version primitive.ProtocolVersion, streamId int16) error {
    30  	clientChan := make(chan error)
    31  	serverChan := make(chan error)
    32  	go func() {
    33  		clientChan <- clientConn.InitiateHandshake(version, streamId)
    34  	}()
    35  	go func() {
    36  		serverChan <- serverConn.AcceptHandshake()
    37  	}()
    38  	for clientChan != nil || serverChan != nil {
    39  		select {
    40  		case err := <-clientChan:
    41  			if err != nil {
    42  				return fmt.Errorf("client handshake failed: %w", err)
    43  			}
    44  			clientChan = nil
    45  		case err := <-serverChan:
    46  			if err != nil {
    47  				return fmt.Errorf("server handshake failed %w", err)
    48  			}
    49  			serverChan = nil
    50  		}
    51  	}
    52  	return nil
    53  }
    54  
    55  // InitiateHandshake initiates the handshake procedure to initialize the client connection, using the given protocol
    56  // version. The handshake will use authentication if the connection was created with auth credentials; otherwise it will
    57  // proceed without authentication. Use stream id zero to activate automatic stream id management.
    58  func (c *CqlClientConnection) InitiateHandshake(version primitive.ProtocolVersion, streamId int16) (err error) {
    59  	log.Debug().Msgf("%v: performing handshake", c)
    60  	if startup, err := c.NewStartupRequest(version, streamId); err != nil {
    61  		return err
    62  	} else {
    63  		var response *frame.Frame
    64  		if response, err = c.SendAndReceive(startup); err == nil {
    65  			if c.credentials == nil {
    66  				if _, authSuccess := response.Body.Message.(*message.Ready); !authSuccess {
    67  					err = fmt.Errorf("expected READY, got %v", response.Body.Message)
    68  				}
    69  			} else {
    70  				switch msg := response.Body.Message.(type) {
    71  				case *message.Ready:
    72  					log.Warn().Msgf("%v: expected AUTHENTICATE, got READY – is authentication required?", c)
    73  					break
    74  				case *message.Authenticate:
    75  					authenticator := &PlainTextAuthenticator{c.credentials}
    76  					var initialResponse []byte
    77  					if initialResponse, err = authenticator.InitialResponse(msg.Authenticator); err == nil {
    78  						authResponse := frame.NewFrame(version, streamId, &message.AuthResponse{Token: initialResponse})
    79  						if response, err = c.SendAndReceive(authResponse); err != nil {
    80  							err = fmt.Errorf("could not send AUTH RESPONSE: %w", err)
    81  						} else {
    82  							switch msg := response.Body.Message.(type) {
    83  							case *message.AuthSuccess:
    84  								break
    85  							case *message.AuthChallenge:
    86  								var challenge []byte
    87  								if challenge, err = authenticator.EvaluateChallenge(msg.Token); err == nil {
    88  									authResponse := frame.NewFrame(version, streamId, &message.AuthResponse{Token: challenge})
    89  									if response, err = c.SendAndReceive(authResponse); err != nil {
    90  										err = fmt.Errorf("could not send AUTH RESPONSE: %w", err)
    91  									} else if _, authSuccess := response.Body.Message.(*message.AuthSuccess); !authSuccess {
    92  										err = fmt.Errorf("expected AUTH_SUCCESS, got %v", response.Body.Message)
    93  									}
    94  								}
    95  							default:
    96  								err = fmt.Errorf("expected AUTH_CHALLENGE or AUTH_SUCCESS, got %v", response.Body.Message)
    97  							}
    98  						}
    99  					}
   100  				default:
   101  					err = fmt.Errorf("expected AUTHENTICATE or READY, got %v", response.Body.Message)
   102  				}
   103  			}
   104  		}
   105  		if err == nil {
   106  			log.Info().Msgf("%v: handshake successful", c)
   107  		} else {
   108  			log.Error().Err(err).Msgf("%v: handshake failed", c)
   109  		}
   110  		return err
   111  	}
   112  }
   113  
   114  // AcceptHandshake Listens for a client STARTUP request and proceeds with the server-side handshake procedure.
   115  // Authentication will be required if the connection was created with auth credentials; otherwise the handshake will
   116  // proceed without authentication.
   117  // This method is intended for use when server-side handshake should be triggered manually. For automatic server-side
   118  // handshake, consider using HandshakeHandler instead.
   119  func (c *CqlServerConnection) AcceptHandshake() (err error) {
   120  	log.Debug().Msgf("%v: performing handshake", c)
   121  	var request *frame.Frame
   122  	authSuccess := false
   123  	done := false
   124  	for !done && err == nil {
   125  		if request, err = c.Receive(); err == nil {
   126  			switch request.Body.Message.(type) {
   127  			case *message.Options:
   128  				supported := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Supported{})
   129  				err = c.Send(supported)
   130  				continue
   131  			case *message.Startup:
   132  				if c.credentials == nil {
   133  					authSuccess = true
   134  					ready := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Ready{})
   135  					err = c.Send(ready)
   136  				} else {
   137  					authenticate := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.Authenticate{Authenticator: "org.apache.cassandra.auth.PasswordAuthenticator"})
   138  					if err = c.Send(authenticate); err == nil {
   139  						if request, err = c.Receive(); err == nil {
   140  							if authResponse, ok := request.Body.Message.(*message.AuthResponse); !ok {
   141  								err = fmt.Errorf("expected AUTH RESPONSE, got %v", request.Body.Message)
   142  							} else {
   143  								credentials := &AuthCredentials{}
   144  								if err = credentials.Unmarshal(authResponse.Token); err == nil {
   145  									if credentials.Username == c.credentials.Username && credentials.Password == c.credentials.Password {
   146  										authSuccess = true
   147  										authSuccess := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.AuthSuccess{})
   148  										err = c.Send(authSuccess)
   149  									} else {
   150  										authError := frame.NewFrame(request.Header.Version, request.Header.StreamId, &message.AuthenticationError{ErrorMessage: "invalid credentials"})
   151  										err = c.Send(authError)
   152  									}
   153  								}
   154  							}
   155  						}
   156  					}
   157  				}
   158  				done = true
   159  			default:
   160  				err = fmt.Errorf("expected STARTUP or OPTIONS, got %v", request.Body.Message)
   161  				done = true
   162  			}
   163  		}
   164  	}
   165  	if err == nil {
   166  		if authSuccess {
   167  			log.Info().Msgf("%v: handshake successful", c)
   168  		} else {
   169  			log.Error().Msgf("%v: authentication error: invalid credentials", c)
   170  		}
   171  	} else {
   172  		log.Error().Err(err).Msgf("%v: handshake failed", c)
   173  	}
   174  	return err
   175  }
   176  
   177  const (
   178  	handshakeStateKey     = "HANDSHAKE"
   179  	handshakeStateStarted = "STARTED"
   180  	handshakeStateDone    = "DONE"
   181  )
   182  
   183  // HandshakeHandler is a RequestHandler to handle server-side handshakes. This is an alternative to
   184  // CqlServerConnection.AcceptHandshake to make the server connection automatically handle all incoming handshake
   185  // attempts.
   186  var HandshakeHandler RequestHandler = func(request *frame.Frame, conn *CqlServerConnection, ctx RequestHandlerContext) (response *frame.Frame) {
   187  	if ctx.GetAttribute(handshakeStateKey) == handshakeStateDone {
   188  		return
   189  	}
   190  	version := request.Header.Version
   191  	id := request.Header.StreamId
   192  	switch msg := request.Body.Message.(type) {
   193  	case *message.Options:
   194  		log.Debug().Msgf("%v: [handshake handler]: intercepted OPTIONS before STARTUP", conn)
   195  		response = frame.NewFrame(version, id, &message.Supported{})
   196  	case *message.Startup:
   197  		if conn.Credentials() == nil {
   198  			ctx.PutAttribute(handshakeStateKey, handshakeStateDone)
   199  			log.Info().Msgf("%v: [handshake handler]: handshake successful", conn)
   200  			response = frame.NewFrame(version, id, &message.Ready{})
   201  		} else {
   202  			ctx.PutAttribute(handshakeStateKey, handshakeStateStarted)
   203  			response = frame.NewFrame(version, id, &message.Authenticate{Authenticator: "org.apache.cassandra.auth.PasswordAuthenticator"})
   204  		}
   205  	case *message.AuthResponse:
   206  		if ctx.GetAttribute(handshakeStateKey) == handshakeStateStarted {
   207  			userCredentials := &AuthCredentials{}
   208  			if err := userCredentials.Unmarshal(msg.Token); err == nil {
   209  				serverCredentials := conn.Credentials()
   210  				if userCredentials.Username == serverCredentials.Username &&
   211  					userCredentials.Password == serverCredentials.Password {
   212  					log.Info().Msgf("%v: [handshake handler]: handshake successful", conn)
   213  					response = frame.NewFrame(version, id, &message.AuthSuccess{})
   214  				} else {
   215  					log.Error().Msgf("%v: [handshake handler]: authentication error: invalid credentials", conn)
   216  					response = frame.NewFrame(version, id, &message.AuthenticationError{ErrorMessage: "invalid credentials"})
   217  				}
   218  				ctx.PutAttribute(handshakeStateKey, handshakeStateDone)
   219  			}
   220  		} else {
   221  			ctx.PutAttribute(handshakeStateKey, handshakeStateDone)
   222  			log.Error().Msgf("%v: [handshake handler]: expected STARTUP, got AUTH_RESPONSE", conn)
   223  			response = frame.NewFrame(version, id, &message.ProtocolError{ErrorMessage: "handshake failed"})
   224  		}
   225  	default:
   226  		ctx.PutAttribute(handshakeStateKey, handshakeStateDone)
   227  		log.Error().Msgf("%v: [handshake handler]: expected OPTIONS, STARTUP or AUTH_RESPONSE, got %v", conn, msg)
   228  		response = frame.NewFrame(version, id, &message.ProtocolError{ErrorMessage: "handshake failed"})
   229  	}
   230  	return
   231  }
   232  
   233  func isReady(f *frame.Frame) bool {
   234  	_, ok := f.Body.Message.(*message.Ready)
   235  	return ok
   236  }
   237  
   238  func isAuthenticate(f *frame.Frame) bool {
   239  	_, ok := f.Body.Message.(*message.Authenticate)
   240  	return ok
   241  }