github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/sink/kafka/v2/gssapi_test.go (about)

     1  // Copyright 2023 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package v2
    15  
    16  import (
    17  	"context"
    18  	"encoding/binary"
    19  	"encoding/hex"
    20  	"testing"
    21  
    22  	"github.com/golang/mock/gomock"
    23  	"github.com/jcmturner/gokrb5/v8/credentials"
    24  	"github.com/jcmturner/gokrb5/v8/gssapi"
    25  	"github.com/jcmturner/gokrb5/v8/iana/nametype"
    26  	"github.com/jcmturner/gokrb5/v8/messages"
    27  	"github.com/jcmturner/gokrb5/v8/types"
    28  	"github.com/pingcap/errors"
    29  	mock "github.com/pingcap/tiflow/pkg/sink/kafka/v2/mock"
    30  	"github.com/segmentio/kafka-go/sasl"
    31  	"github.com/stretchr/testify/require"
    32  )
    33  
    34  const (
    35  	// What a kerberized server might send
    36  	testChallengeFromAcceptor = "050401ff000c0000000000" +
    37  		"00575e85d601010000853b728d5268525a1386c19f"
    38  	// session key used to sign the tokens above
    39  	sessionKey     = "14f9bde6b50ec508201a97f74c4e5bd3"
    40  	sessionKeyType = 17
    41  )
    42  
    43  func getSessionKey() types.EncryptionKey {
    44  	key, _ := hex.DecodeString(sessionKey)
    45  	return types.EncryptionKey{
    46  		KeyType:  sessionKeyType,
    47  		KeyValue: key,
    48  	}
    49  }
    50  
    51  func getChallengeReference() *gssapi.WrapToken {
    52  	challenge, _ := hex.DecodeString(testChallengeFromAcceptor)
    53  	return &gssapi.WrapToken{
    54  		Flags:     0x01,
    55  		EC:        12,
    56  		RRC:       0,
    57  		SndSeqNum: binary.BigEndian.Uint64(challenge[8:16]),
    58  		Payload:   []byte{0x01, 0x01, 0x00, 0x00},
    59  		CheckSum:  challenge[20:32],
    60  	}
    61  }
    62  
    63  func TestStart(t *testing.T) {
    64  	ctrl := gomock.NewController(t)
    65  	client := mock.NewMockGokrb5v8Client(ctrl)
    66  	m := Gokrb5v8(client, "kafka")
    67  	require.Equal(t, "GSSAPI", m.Name())
    68  	_, _, err := m.Start(sasl.WithMetadata(context.Background(),
    69  		&sasl.Metadata{}))
    70  	require.Equal(t, err, StartWithoutHostError{})
    71  
    72  	ctx := sasl.WithMetadata(context.Background(), &sasl.Metadata{
    73  		Host: "localhost",
    74  		Port: 9092,
    75  	})
    76  	client.EXPECT().GetServiceTicket(gomock.Any()).Return(
    77  		messages.Ticket{}, getSessionKey(), errors.New("fake errors"))
    78  	_, _, err = m.Start(ctx)
    79  	require.Contains(t, "fake errors", err.Error())
    80  
    81  	client.EXPECT().GetServiceTicket(gomock.Any()).Return(
    82  		messages.Ticket{
    83  			SName: types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL,
    84  				"kafka@EXAMPLE.COM"),
    85  		}, getSessionKey(), nil)
    86  	client.EXPECT().Credentials().AnyTimes().
    87  		Return(credentials.New("user", "kafka"))
    88  	stm, data, err := m.Start(ctx)
    89  	require.NotNil(t, stm)
    90  	require.NotNil(t, data)
    91  	require.Nil(t, err)
    92  }
    93  
    94  func TestNext(t *testing.T) {
    95  	session := gokrb5v8Session{key: getSessionKey()}
    96  	wrapToken := getChallengeReference()
    97  	data, err := wrapToken.Marshal()
    98  	require.Nil(t, err)
    99  	ok, resp, err := session.Next(context.Background(), data)
   100  	require.False(t, ok)
   101  	require.NotNil(t, resp)
   102  	require.Nil(t, err)
   103  	ok, resp, err = session.Next(context.Background(), data)
   104  	require.True(t, ok)
   105  	require.Nil(t, resp)
   106  	require.Nil(t, err)
   107  }