github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/sink/kafka/v2/factory_test.go (about) 1 // Copyright 2023 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package v2 15 16 import ( 17 "context" 18 "crypto/tls" 19 "testing" 20 21 "github.com/golang/mock/gomock" 22 "github.com/pingcap/errors" 23 "github.com/pingcap/tiflow/cdc/model" 24 cerror "github.com/pingcap/tiflow/pkg/errors" 25 "github.com/pingcap/tiflow/pkg/security" 26 pkafka "github.com/pingcap/tiflow/pkg/sink/kafka" 27 v2mock "github.com/pingcap/tiflow/pkg/sink/kafka/v2/mock" 28 "github.com/pingcap/tiflow/pkg/util" 29 "github.com/segmentio/kafka-go" 30 "github.com/segmentio/kafka-go/sasl/plain" 31 "github.com/stretchr/testify/require" 32 ) 33 34 func newOptions4Test() *pkafka.Options { 35 o := pkafka.NewOptions() 36 o.BrokerEndpoints = []string{"127.0.0.1:9092"} 37 o.ClientID = "kafka-go-test" 38 o.EnableTLS = true 39 o.Credential = &security.Credential{ 40 CAPath: "", 41 CertPath: "", 42 KeyPath: "", 43 CertAllowedCN: []string{""}, 44 } 45 return o 46 } 47 48 func newFactory4Test(o *pkafka.Options, t *testing.T) *factory { 49 f, err := NewFactory(o, model.DefaultChangeFeedID("kafka-go-sink")) 50 require.NoError(t, err) 51 52 return f.(*factory) 53 } 54 55 func TestSyncProducer(t *testing.T) { 56 t.Parallel() 57 58 o := newOptions4Test() 59 factory := newFactory4Test(o, t) 60 61 sync, err := factory.SyncProducer(context.Background()) 62 require.NoError(t, err) 63 64 p, ok := sync.(*syncWriter) 65 require.True(t, ok) 66 require.False(t, p.w.(*kafka.Writer).Async) 67 } 68 69 func TestCompression(t *testing.T) { 70 t.Parallel() 71 72 o := newOptions4Test() 73 factory := newFactory4Test(o, t) 74 factory.newWriter(false) 75 cases := []struct { 76 compression string 77 expected kafka.Compression 78 }{ 79 {"none", 0}, 80 {"gzip", kafka.Gzip}, 81 {"snappy", kafka.Snappy}, 82 {"lz4", kafka.Lz4}, 83 {"zstd", kafka.Zstd}, 84 {"xxxx", 0}, 85 } 86 for _, cs := range cases { 87 o.Compression = cs.compression 88 w := factory.newWriter(false) 89 require.Equal(t, cs.expected, w.Compression) 90 } 91 } 92 93 func TestAsyncProducer(t *testing.T) { 94 t.Parallel() 95 96 o := newOptions4Test() 97 factory := newFactory4Test(o, t) 98 require.Equal( 99 t, factory.transport.TLS, 100 &tls.Config{ 101 MinVersion: tls.VersionTLS12, 102 NextProtos: []string{"h2", "http/1.1"}, 103 }, 104 ) 105 106 ctx := context.Background() 107 async, err := factory.AsyncProducer(ctx, make(chan error, 1)) 108 require.NoError(t, err) 109 110 asyncP, ok := async.(*asyncWriter) 111 w := asyncP.w.(*kafka.Writer) 112 require.True(t, ok) 113 require.True(t, w.Async) 114 115 require.Equal(t, w.ReadTimeout, o.ReadTimeout) 116 require.Equal(t, w.WriteTimeout, o.WriteTimeout) 117 require.Equal(t, w.RequiredAcks, kafka.RequiredAcks(o.RequiredAcks)) 118 require.Equal(t, w.BatchBytes, int64(o.MaxMessageBytes)) 119 } 120 121 func TestAsyncCompletion(t *testing.T) { 122 o := newOptions4Test() 123 factory := newFactory4Test(o, t) 124 ctx := context.Background() 125 async, err := factory.AsyncProducer(ctx, make(chan error, 1)) 126 require.NoError(t, err) 127 asyncP, ok := async.(*asyncWriter) 128 require.True(t, ok) 129 w := asyncP.w.(*kafka.Writer) 130 acked := 0 131 callback := func() { 132 acked++ 133 } 134 msgs := []kafka.Message{ 135 { 136 WriterData: callback, 137 }, 138 { 139 WriterData: callback, 140 }, 141 } 142 w.Completion(msgs, nil) 143 require.Equal(t, 2, acked) 144 asyncP.errorsChan = make(chan error, 2) 145 w.Completion(msgs, errors.New("fake")) 146 require.Equal(t, 1, len(asyncP.errorsChan)) 147 asyncP.errorsChan <- errors.New("fake 2") 148 w.Completion(msgs, errors.New("fake")) 149 require.Equal(t, 2, len(asyncP.errorsChan)) 150 require.Equal(t, 2, acked) 151 } 152 153 func TestNewMetricsCollector(t *testing.T) { 154 require.NotNil(t, NewMetricsCollector(model.DefaultChangeFeedID("1"), util.RoleOwner, nil)) 155 } 156 157 func TestCompleteSASLConfig(t *testing.T) { 158 m, err := completeSASLConfig(&pkafka.Options{ 159 SASL: nil, 160 }) 161 require.Nil(t, m) 162 require.Nil(t, err) 163 m, err = completeSASLConfig(&pkafka.Options{ 164 SASL: &security.SASL{ 165 SASLUser: "user", 166 SASLPassword: "pass", 167 SASLMechanism: pkafka.SASLTypeSCRAMSHA256, 168 }, 169 }) 170 require.Nil(t, err) 171 require.Equal(t, pkafka.SASLTypeSCRAMSHA256, m.Name()) 172 m, err = completeSASLConfig(&pkafka.Options{ 173 SASL: &security.SASL{ 174 SASLUser: "user", 175 SASLPassword: "pass", 176 SASLMechanism: pkafka.SASLTypeSCRAMSHA512, 177 }, 178 }) 179 require.NotNil(t, m) 180 require.Equal(t, pkafka.SASLTypeSCRAMSHA512, m.Name()) 181 require.Nil(t, err) 182 require.Equal(t, pkafka.SASLTypeSCRAMSHA512, m.Name()) 183 m, err = completeSASLConfig(&pkafka.Options{ 184 SASL: &security.SASL{ 185 SASLUser: "user", 186 SASLPassword: "pass", 187 SASLMechanism: pkafka.SASLTypePlaintext, 188 }, 189 }) 190 pm, ok := m.(plain.Mechanism) 191 require.True(t, ok) 192 require.Nil(t, err) 193 require.Equal(t, pkafka.SASLTypePlaintext, m.Name()) 194 require.Equal(t, "user", pm.Username) 195 require.Equal(t, "pass", pm.Password) 196 197 // Unsupported OAUTHBEARER mechanism 198 m, err = completeSASLConfig(&pkafka.Options{ 199 SASL: &security.SASL{ 200 SASLMechanism: security.OAuthMechanism, 201 }, 202 }) 203 require.NotNil(t, err) 204 require.Contains(t, err.Error(), "OAuth is not yet supported in Kafka sink v2") 205 } 206 207 func TestSyncWriterSendMessage(t *testing.T) { 208 mw := v2mock.NewMockWriter(gomock.NewController(t)) 209 w := syncWriter{w: mw} 210 mw.EXPECT().WriteMessages(gomock.Any(), gomock.Any()). 211 DoAndReturn(func(ctx context.Context, msgs ...kafka.Message) error { 212 require.Equal(t, 1, len(msgs)) 213 require.Equal(t, 3, msgs[0].Partition) 214 return errors.New("fake") 215 }) 216 require.NotNil(t, w.SendMessage(context.Background(), "topic", 3, []byte{'1'}, []byte{})) 217 } 218 219 func TestSyncWriterSendMessages(t *testing.T) { 220 mw := v2mock.NewMockWriter(gomock.NewController(t)) 221 w := syncWriter{w: mw} 222 mw.EXPECT().WriteMessages(gomock.Any(), gomock.Any()). 223 DoAndReturn(func(ctx context.Context, msgs ...kafka.Message) error { 224 require.Equal(t, 3, len(msgs)) 225 return errors.New("fake") 226 }) 227 require.NotNil(t, w.SendMessages(context.Background(), "topic", 3, []byte{'1'}, []byte{})) 228 } 229 230 func TestSyncWriterClose(t *testing.T) { 231 mw := v2mock.NewMockWriter(gomock.NewController(t)) 232 w := syncWriter{w: mw} 233 // close failed,no panic 234 mw.EXPECT().Close().Return(errors.New("fake")) 235 w.Close() 236 // close success 237 mw.EXPECT().Close().Return(nil) 238 w.Close() 239 } 240 241 func TestAsyncWriterAsyncSend(t *testing.T) { 242 mw := v2mock.NewMockWriter(gomock.NewController(t)) 243 w := asyncWriter{w: mw} 244 245 ctx, cancel := context.WithCancel(context.Background()) 246 247 callback := func() {} 248 mw.EXPECT().WriteMessages(gomock.Any(), gomock.Any()).Return(nil) 249 err := w.AsyncSend(ctx, "topic", 1, []byte{'1'}, []byte{}, callback) 250 require.NoError(t, err) 251 252 cancel() 253 254 err = w.AsyncSend(ctx, "topic", 1, []byte{'1'}, []byte{}, callback) 255 require.ErrorIs(t, err, context.Canceled) 256 } 257 258 func TestAsyncProducerErrorChan(t *testing.T) { 259 t.Parallel() 260 261 o := newOptions4Test() 262 factory := newFactory4Test(o, t) 263 264 ctx := context.Background() 265 asyncProducer, err := factory.AsyncProducer(ctx, make(chan error, 1)) 266 require.NoError(t, err) 267 268 mockErr := cerror.New("errors chan error") 269 go func() { 270 err = asyncProducer.AsyncRunCallback(ctx) 271 require.Equal(t, err.Error(), cerror.WrapError(cerror.ErrKafkaAsyncSendMessage, mockErr).Error()) 272 }() 273 274 asyncProducer.(*asyncWriter).errorsChan <- mockErr 275 }