github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/sink/kafka/v2/factory_test.go (about)

     1  // Copyright 2023 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package v2
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"testing"
    20  
    21  	"github.com/golang/mock/gomock"
    22  	"github.com/pingcap/errors"
    23  	"github.com/pingcap/tiflow/cdc/model"
    24  	cerror "github.com/pingcap/tiflow/pkg/errors"
    25  	"github.com/pingcap/tiflow/pkg/security"
    26  	pkafka "github.com/pingcap/tiflow/pkg/sink/kafka"
    27  	v2mock "github.com/pingcap/tiflow/pkg/sink/kafka/v2/mock"
    28  	"github.com/pingcap/tiflow/pkg/util"
    29  	"github.com/segmentio/kafka-go"
    30  	"github.com/segmentio/kafka-go/sasl/plain"
    31  	"github.com/stretchr/testify/require"
    32  )
    33  
    34  func newOptions4Test() *pkafka.Options {
    35  	o := pkafka.NewOptions()
    36  	o.BrokerEndpoints = []string{"127.0.0.1:9092"}
    37  	o.ClientID = "kafka-go-test"
    38  	o.EnableTLS = true
    39  	o.Credential = &security.Credential{
    40  		CAPath:        "",
    41  		CertPath:      "",
    42  		KeyPath:       "",
    43  		CertAllowedCN: []string{""},
    44  	}
    45  	return o
    46  }
    47  
    48  func newFactory4Test(o *pkafka.Options, t *testing.T) *factory {
    49  	f, err := NewFactory(o, model.DefaultChangeFeedID("kafka-go-sink"))
    50  	require.NoError(t, err)
    51  
    52  	return f.(*factory)
    53  }
    54  
    55  func TestSyncProducer(t *testing.T) {
    56  	t.Parallel()
    57  
    58  	o := newOptions4Test()
    59  	factory := newFactory4Test(o, t)
    60  
    61  	sync, err := factory.SyncProducer(context.Background())
    62  	require.NoError(t, err)
    63  
    64  	p, ok := sync.(*syncWriter)
    65  	require.True(t, ok)
    66  	require.False(t, p.w.(*kafka.Writer).Async)
    67  }
    68  
    69  func TestCompression(t *testing.T) {
    70  	t.Parallel()
    71  
    72  	o := newOptions4Test()
    73  	factory := newFactory4Test(o, t)
    74  	factory.newWriter(false)
    75  	cases := []struct {
    76  		compression string
    77  		expected    kafka.Compression
    78  	}{
    79  		{"none", 0},
    80  		{"gzip", kafka.Gzip},
    81  		{"snappy", kafka.Snappy},
    82  		{"lz4", kafka.Lz4},
    83  		{"zstd", kafka.Zstd},
    84  		{"xxxx", 0},
    85  	}
    86  	for _, cs := range cases {
    87  		o.Compression = cs.compression
    88  		w := factory.newWriter(false)
    89  		require.Equal(t, cs.expected, w.Compression)
    90  	}
    91  }
    92  
    93  func TestAsyncProducer(t *testing.T) {
    94  	t.Parallel()
    95  
    96  	o := newOptions4Test()
    97  	factory := newFactory4Test(o, t)
    98  	require.Equal(
    99  		t, factory.transport.TLS,
   100  		&tls.Config{
   101  			MinVersion: tls.VersionTLS12,
   102  			NextProtos: []string{"h2", "http/1.1"},
   103  		},
   104  	)
   105  
   106  	ctx := context.Background()
   107  	async, err := factory.AsyncProducer(ctx, make(chan error, 1))
   108  	require.NoError(t, err)
   109  
   110  	asyncP, ok := async.(*asyncWriter)
   111  	w := asyncP.w.(*kafka.Writer)
   112  	require.True(t, ok)
   113  	require.True(t, w.Async)
   114  
   115  	require.Equal(t, w.ReadTimeout, o.ReadTimeout)
   116  	require.Equal(t, w.WriteTimeout, o.WriteTimeout)
   117  	require.Equal(t, w.RequiredAcks, kafka.RequiredAcks(o.RequiredAcks))
   118  	require.Equal(t, w.BatchBytes, int64(o.MaxMessageBytes))
   119  }
   120  
   121  func TestAsyncCompletion(t *testing.T) {
   122  	o := newOptions4Test()
   123  	factory := newFactory4Test(o, t)
   124  	ctx := context.Background()
   125  	async, err := factory.AsyncProducer(ctx, make(chan error, 1))
   126  	require.NoError(t, err)
   127  	asyncP, ok := async.(*asyncWriter)
   128  	require.True(t, ok)
   129  	w := asyncP.w.(*kafka.Writer)
   130  	acked := 0
   131  	callback := func() {
   132  		acked++
   133  	}
   134  	msgs := []kafka.Message{
   135  		{
   136  			WriterData: callback,
   137  		},
   138  		{
   139  			WriterData: callback,
   140  		},
   141  	}
   142  	w.Completion(msgs, nil)
   143  	require.Equal(t, 2, acked)
   144  	asyncP.errorsChan = make(chan error, 2)
   145  	w.Completion(msgs, errors.New("fake"))
   146  	require.Equal(t, 1, len(asyncP.errorsChan))
   147  	asyncP.errorsChan <- errors.New("fake 2")
   148  	w.Completion(msgs, errors.New("fake"))
   149  	require.Equal(t, 2, len(asyncP.errorsChan))
   150  	require.Equal(t, 2, acked)
   151  }
   152  
   153  func TestNewMetricsCollector(t *testing.T) {
   154  	require.NotNil(t, NewMetricsCollector(model.DefaultChangeFeedID("1"), util.RoleOwner, nil))
   155  }
   156  
   157  func TestCompleteSASLConfig(t *testing.T) {
   158  	m, err := completeSASLConfig(&pkafka.Options{
   159  		SASL: nil,
   160  	})
   161  	require.Nil(t, m)
   162  	require.Nil(t, err)
   163  	m, err = completeSASLConfig(&pkafka.Options{
   164  		SASL: &security.SASL{
   165  			SASLUser:      "user",
   166  			SASLPassword:  "pass",
   167  			SASLMechanism: pkafka.SASLTypeSCRAMSHA256,
   168  		},
   169  	})
   170  	require.Nil(t, err)
   171  	require.Equal(t, pkafka.SASLTypeSCRAMSHA256, m.Name())
   172  	m, err = completeSASLConfig(&pkafka.Options{
   173  		SASL: &security.SASL{
   174  			SASLUser:      "user",
   175  			SASLPassword:  "pass",
   176  			SASLMechanism: pkafka.SASLTypeSCRAMSHA512,
   177  		},
   178  	})
   179  	require.NotNil(t, m)
   180  	require.Equal(t, pkafka.SASLTypeSCRAMSHA512, m.Name())
   181  	require.Nil(t, err)
   182  	require.Equal(t, pkafka.SASLTypeSCRAMSHA512, m.Name())
   183  	m, err = completeSASLConfig(&pkafka.Options{
   184  		SASL: &security.SASL{
   185  			SASLUser:      "user",
   186  			SASLPassword:  "pass",
   187  			SASLMechanism: pkafka.SASLTypePlaintext,
   188  		},
   189  	})
   190  	pm, ok := m.(plain.Mechanism)
   191  	require.True(t, ok)
   192  	require.Nil(t, err)
   193  	require.Equal(t, pkafka.SASLTypePlaintext, m.Name())
   194  	require.Equal(t, "user", pm.Username)
   195  	require.Equal(t, "pass", pm.Password)
   196  
   197  	// Unsupported OAUTHBEARER mechanism
   198  	m, err = completeSASLConfig(&pkafka.Options{
   199  		SASL: &security.SASL{
   200  			SASLMechanism: security.OAuthMechanism,
   201  		},
   202  	})
   203  	require.NotNil(t, err)
   204  	require.Contains(t, err.Error(), "OAuth is not yet supported in Kafka sink v2")
   205  }
   206  
   207  func TestSyncWriterSendMessage(t *testing.T) {
   208  	mw := v2mock.NewMockWriter(gomock.NewController(t))
   209  	w := syncWriter{w: mw}
   210  	mw.EXPECT().WriteMessages(gomock.Any(), gomock.Any()).
   211  		DoAndReturn(func(ctx context.Context, msgs ...kafka.Message) error {
   212  			require.Equal(t, 1, len(msgs))
   213  			require.Equal(t, 3, msgs[0].Partition)
   214  			return errors.New("fake")
   215  		})
   216  	require.NotNil(t, w.SendMessage(context.Background(), "topic", 3, []byte{'1'}, []byte{}))
   217  }
   218  
   219  func TestSyncWriterSendMessages(t *testing.T) {
   220  	mw := v2mock.NewMockWriter(gomock.NewController(t))
   221  	w := syncWriter{w: mw}
   222  	mw.EXPECT().WriteMessages(gomock.Any(), gomock.Any()).
   223  		DoAndReturn(func(ctx context.Context, msgs ...kafka.Message) error {
   224  			require.Equal(t, 3, len(msgs))
   225  			return errors.New("fake")
   226  		})
   227  	require.NotNil(t, w.SendMessages(context.Background(), "topic", 3, []byte{'1'}, []byte{}))
   228  }
   229  
   230  func TestSyncWriterClose(t *testing.T) {
   231  	mw := v2mock.NewMockWriter(gomock.NewController(t))
   232  	w := syncWriter{w: mw}
   233  	// close failed,no panic
   234  	mw.EXPECT().Close().Return(errors.New("fake"))
   235  	w.Close()
   236  	// close success
   237  	mw.EXPECT().Close().Return(nil)
   238  	w.Close()
   239  }
   240  
   241  func TestAsyncWriterAsyncSend(t *testing.T) {
   242  	mw := v2mock.NewMockWriter(gomock.NewController(t))
   243  	w := asyncWriter{w: mw}
   244  
   245  	ctx, cancel := context.WithCancel(context.Background())
   246  
   247  	callback := func() {}
   248  	mw.EXPECT().WriteMessages(gomock.Any(), gomock.Any()).Return(nil)
   249  	err := w.AsyncSend(ctx, "topic", 1, []byte{'1'}, []byte{}, callback)
   250  	require.NoError(t, err)
   251  
   252  	cancel()
   253  
   254  	err = w.AsyncSend(ctx, "topic", 1, []byte{'1'}, []byte{}, callback)
   255  	require.ErrorIs(t, err, context.Canceled)
   256  }
   257  
   258  func TestAsyncProducerErrorChan(t *testing.T) {
   259  	t.Parallel()
   260  
   261  	o := newOptions4Test()
   262  	factory := newFactory4Test(o, t)
   263  
   264  	ctx := context.Background()
   265  	asyncProducer, err := factory.AsyncProducer(ctx, make(chan error, 1))
   266  	require.NoError(t, err)
   267  
   268  	mockErr := cerror.New("errors chan error")
   269  	go func() {
   270  		err = asyncProducer.AsyncRunCallback(ctx)
   271  		require.Equal(t, err.Error(), cerror.WrapError(cerror.ErrKafkaAsyncSendMessage, mockErr).Error())
   272  	}()
   273  
   274  	asyncProducer.(*asyncWriter).errorsChan <- mockErr
   275  }