go.temporal.io/server@v1.23.0/common/persistence/persistence-tests/persistence_test_base.go (about) 1 // The MIT License 2 // 3 // Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. 4 // 5 // Copyright (c) 2020 Uber Technologies, Inc. 6 // 7 // Permission is hereby granted, free of charge, to any person obtaining a copy 8 // of this software and associated documentation files (the "Software"), to deal 9 // in the Software without restriction, including without limitation the rights 10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 // copies of the Software, and to permit persons to whom the Software is 12 // furnished to do so, subject to the following conditions: 13 // 14 // The above copyright notice and this permission notice shall be included in 15 // all copies or substantial portions of the Software. 16 // 17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 // THE SOFTWARE. 24 25 package persistencetests 26 27 import ( 28 "context" 29 "fmt" 30 "math/rand" 31 "strings" 32 "sync/atomic" 33 "time" 34 35 "github.com/stretchr/testify/suite" 36 persistencespb "go.temporal.io/server/api/persistence/v1" 37 replicationspb "go.temporal.io/server/api/replication/v1" 38 "go.temporal.io/server/common" 39 "go.temporal.io/server/common/backoff" 40 "go.temporal.io/server/common/clock" 41 "go.temporal.io/server/common/cluster" 42 "go.temporal.io/server/common/config" 43 "go.temporal.io/server/common/dynamicconfig" 44 "go.temporal.io/server/common/log" 45 "go.temporal.io/server/common/log/tag" 46 "go.temporal.io/server/common/metrics" 47 "go.temporal.io/server/common/persistence" 48 "go.temporal.io/server/common/persistence/cassandra" 49 "go.temporal.io/server/common/persistence/client" 50 "go.temporal.io/server/common/persistence/serialization" 51 "go.temporal.io/server/common/persistence/sql" 52 "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql" 53 "go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql" 54 "go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite" 55 "go.temporal.io/server/common/persistence/visibility" 56 "go.temporal.io/server/common/quotas" 57 "go.temporal.io/server/common/resolver" 58 "go.temporal.io/server/common/searchattribute" 59 "go.temporal.io/server/environment" 60 ) 61 62 // TimePrecision is needed to account for database timestamp precision. 63 // Cassandra only provides milliseconds timestamp precision, so we need to use tolerance when doing comparison 64 const TimePrecision = 2 * time.Millisecond 65 66 type ( 67 // TransferTaskIDGenerator generates IDs for transfer tasks written by helper methods 68 TransferTaskIDGenerator interface { 69 GenerateTransferTaskID() (int64, error) 70 } 71 72 // TestBaseOptions options to configure workflow test base. 73 TestBaseOptions struct { 74 SQLDBPluginName string 75 DBName string 76 DBUsername string 77 DBPassword string 78 DBHost string 79 DBPort int `yaml:"-"` 80 ConnectAttributes map[string]string 81 StoreType string `yaml:"-"` 82 SchemaDir string `yaml:"-"` 83 FaultInjection *config.FaultInjection `yaml:"faultinjection"` 84 Logger log.Logger `yaml:"-"` 85 } 86 87 // TestBase wraps the base setup needed to create workflows over persistence layer. 88 TestBase struct { 89 suite.Suite 90 ShardMgr persistence.ShardManager 91 AbstractDataStoreFactory client.AbstractDataStoreFactory 92 VisibilityStoreFactory visibility.VisibilityStoreFactory 93 FaultInjection *client.FaultInjectionDataStoreFactory 94 Factory client.Factory 95 ExecutionManager persistence.ExecutionManager 96 TaskMgr persistence.TaskManager 97 ClusterMetadataManager persistence.ClusterMetadataManager 98 MetadataManager persistence.MetadataManager 99 NamespaceReplicationQueue persistence.NamespaceReplicationQueue 100 ShardInfo *persistencespb.ShardInfo 101 TaskIDGenerator TransferTaskIDGenerator 102 ClusterMetadata cluster.Metadata 103 SearchAttributesManager searchattribute.Manager 104 PersistenceRateLimiter quotas.RequestRateLimiter 105 PersistenceHealthSignals persistence.HealthSignalAggregator 106 ReadLevel int64 107 ReplicationReadLevel int64 108 DefaultTestCluster PersistenceTestCluster 109 Logger log.Logger 110 } 111 112 // PersistenceTestCluster exposes management operations on a database 113 PersistenceTestCluster interface { 114 SetupTestDatabase() 115 TearDownTestDatabase() 116 Config() config.Persistence 117 } 118 119 // TestTransferTaskIDGenerator helper 120 TestTransferTaskIDGenerator struct { 121 seqNum int64 122 } 123 ) 124 125 // NewTestBaseWithCassandra returns a persistence test base backed by cassandra datastore 126 func NewTestBaseWithCassandra(options *TestBaseOptions) *TestBase { 127 logger := log.NewTestLogger() 128 testCluster := NewTestClusterForCassandra(options, logger) 129 return NewTestBaseForCluster(testCluster, logger) 130 } 131 132 func NewTestClusterForCassandra(options *TestBaseOptions, logger log.Logger) *cassandra.TestCluster { 133 if options.DBName == "" { 134 options.DBName = "test_" + GenerateRandomDBName(3) 135 } 136 testCluster := cassandra.NewTestCluster(options.DBName, options.DBUsername, options.DBPassword, options.DBHost, options.DBPort, options.SchemaDir, options.FaultInjection, logger) 137 return testCluster 138 } 139 140 // NewTestBaseWithSQL returns a new persistence test base backed by SQL 141 func NewTestBaseWithSQL(options *TestBaseOptions) *TestBase { 142 if options.DBName == "" { 143 options.DBName = "test_" + GenerateRandomDBName(3) 144 } 145 logger := options.Logger 146 if logger == nil { 147 logger = log.NewTestLogger() 148 } 149 150 if options.DBPort == 0 { 151 switch options.SQLDBPluginName { 152 case mysql.PluginName, mysql.PluginNameV8: 153 options.DBPort = environment.GetMySQLPort() 154 case postgresql.PluginName, postgresql.PluginNamePGX, postgresql.PluginNameV12, postgresql.PluginNameV12PGX: 155 options.DBPort = environment.GetPostgreSQLPort() 156 case sqlite.PluginName: 157 options.DBPort = 0 158 default: 159 panic(fmt.Sprintf("unknown sql store driver: %v", options.SQLDBPluginName)) 160 } 161 } 162 if options.DBHost == "" { 163 switch options.SQLDBPluginName { 164 case mysql.PluginName, mysql.PluginNameV8: 165 options.DBHost = environment.GetMySQLAddress() 166 case postgresql.PluginName, postgresql.PluginNamePGX: 167 options.DBHost = environment.GetPostgreSQLAddress() 168 case sqlite.PluginName: 169 options.DBHost = environment.GetLocalhostIP() 170 default: 171 panic(fmt.Sprintf("unknown sql store driver: %v", options.SQLDBPluginName)) 172 } 173 } 174 testCluster := sql.NewTestCluster(options.SQLDBPluginName, options.DBName, options.DBUsername, options.DBPassword, options.DBHost, options.DBPort, options.ConnectAttributes, options.SchemaDir, options.FaultInjection, logger) 175 return NewTestBaseForCluster(testCluster, logger) 176 } 177 178 // NewTestBase returns a persistence test base backed by either cassandra or sql 179 func NewTestBase(options *TestBaseOptions) *TestBase { 180 switch options.StoreType { 181 case config.StoreTypeSQL: 182 return NewTestBaseWithSQL(options) 183 case config.StoreTypeNoSQL: 184 return NewTestBaseWithCassandra(options) 185 default: 186 panic("invalid storeType " + options.StoreType) 187 } 188 } 189 190 func NewTestBaseForCluster(testCluster PersistenceTestCluster, logger log.Logger) *TestBase { 191 return &TestBase{ 192 DefaultTestCluster: testCluster, 193 Logger: logger, 194 } 195 } 196 197 // Setup sets up the test base, must be called as part of SetupSuite 198 func (s *TestBase) Setup(clusterMetadataConfig *cluster.Config) { 199 var err error 200 shardID := int32(10) 201 if clusterMetadataConfig == nil { 202 clusterMetadataConfig = cluster.NewTestClusterMetadataConfig(false, false) 203 } 204 if s.PersistenceHealthSignals == nil { 205 s.PersistenceHealthSignals = persistence.NoopHealthSignalAggregator 206 } 207 208 clusterName := clusterMetadataConfig.CurrentClusterName 209 210 s.DefaultTestCluster.SetupTestDatabase() 211 212 cfg := s.DefaultTestCluster.Config() 213 dataStoreFactory, faultInjection := client.DataStoreFactoryProvider( 214 client.ClusterName(clusterName), 215 resolver.NewNoopResolver(), 216 &cfg, 217 s.AbstractDataStoreFactory, 218 s.Logger, 219 metrics.NoopMetricsHandler, 220 ) 221 factory := client.NewFactory(dataStoreFactory, &cfg, s.PersistenceRateLimiter, serialization.NewSerializer(), nil, clusterName, metrics.NoopMetricsHandler, s.Logger, s.PersistenceHealthSignals) 222 223 s.TaskMgr, err = factory.NewTaskManager() 224 s.fatalOnError("NewTaskManager", err) 225 226 s.ClusterMetadataManager, err = factory.NewClusterMetadataManager() 227 s.fatalOnError("NewClusterMetadataManager", err) 228 229 s.ClusterMetadata = cluster.NewMetadataFromConfig(clusterMetadataConfig, s.ClusterMetadataManager, dynamicconfig.NewNoopCollection(), s.Logger) 230 s.SearchAttributesManager = searchattribute.NewManager(clock.NewRealTimeSource(), s.ClusterMetadataManager, dynamicconfig.GetBoolPropertyFn(true)) 231 232 s.MetadataManager, err = factory.NewMetadataManager() 233 s.fatalOnError("NewMetadataManager", err) 234 235 s.ShardMgr, err = factory.NewShardManager() 236 s.fatalOnError("NewShardManager", err) 237 238 s.ExecutionManager, err = factory.NewExecutionManager() 239 s.fatalOnError("NewExecutionManager", err) 240 241 s.Factory = factory 242 s.FaultInjection = faultInjection 243 244 s.ReadLevel = 0 245 s.ReplicationReadLevel = 0 246 s.ShardInfo = &persistencespb.ShardInfo{ 247 ShardId: shardID, 248 RangeId: 0, 249 } 250 251 s.TaskIDGenerator = &TestTransferTaskIDGenerator{} 252 _, err = s.ShardMgr.GetOrCreateShard(context.Background(), &persistence.GetOrCreateShardRequest{ 253 ShardID: shardID, 254 InitialShardInfo: s.ShardInfo, 255 }) 256 s.fatalOnError("CreateShard", err) 257 258 queue, err := factory.NewNamespaceReplicationQueue() 259 s.fatalOnError("Create NamespaceReplicationQueue", err) 260 s.NamespaceReplicationQueue = queue 261 } 262 263 func (s *TestBase) fatalOnError(msg string, err error) { 264 if err != nil { 265 s.Logger.Fatal(msg, tag.Error(err)) 266 } 267 } 268 269 // TearDownWorkflowStore to cleanup 270 func (s *TestBase) TearDownWorkflowStore() { 271 s.TaskMgr.Close() 272 s.ClusterMetadataManager.Close() 273 s.MetadataManager.Close() 274 s.ExecutionManager.Close() 275 s.ShardMgr.Close() 276 s.ExecutionManager.Close() 277 s.NamespaceReplicationQueue.Stop() 278 s.Factory.Close() 279 s.DefaultTestCluster.TearDownTestDatabase() 280 } 281 282 // EqualTimesWithPrecision assertion that two times are equal within precision 283 func (s *TestBase) EqualTimesWithPrecision(t1, t2 time.Time, precision time.Duration) { 284 s.True(timeComparator(t1, t2, precision), 285 "Not equal: \n"+ 286 "expected: %s\n"+ 287 "actual : %s%s", t1, t2, 288 ) 289 } 290 291 // EqualTimes assertion that two times are equal within two millisecond precision 292 func (s *TestBase) EqualTimes(t1, t2 time.Time) { 293 s.EqualTimesWithPrecision(t1, t2, TimePrecision) 294 } 295 296 // GenerateTransferTaskID helper 297 func (g *TestTransferTaskIDGenerator) GenerateTransferTaskID() (int64, error) { 298 return atomic.AddInt64(&g.seqNum, 1), nil 299 } 300 301 // Publish is a utility method to add messages to the queue 302 func (s *TestBase) Publish(ctx context.Context, task *replicationspb.ReplicationTask) error { 303 retryPolicy := backoff.NewExponentialRetryPolicy(100 * time.Millisecond). 304 WithBackoffCoefficient(1.5). 305 WithMaximumAttempts(5) 306 307 return backoff.ThrottleRetry( 308 func() error { 309 return s.NamespaceReplicationQueue.Publish(ctx, task) 310 }, 311 retryPolicy, 312 func(e error) bool { 313 return common.IsPersistenceTransientError(e) || isMessageIDConflictError(e) 314 }) 315 } 316 317 func isMessageIDConflictError(err error) bool { 318 _, ok := err.(*persistence.ConditionFailedError) 319 return ok 320 } 321 322 // GetReplicationMessages is a utility method to get messages from the queue 323 func (s *TestBase) GetReplicationMessages( 324 ctx context.Context, 325 lastMessageID int64, 326 pageSize int, 327 ) ([]*replicationspb.ReplicationTask, int64, error) { 328 return s.NamespaceReplicationQueue.GetReplicationMessages(ctx, lastMessageID, pageSize) 329 } 330 331 // UpdateAckLevel updates replication queue ack level 332 func (s *TestBase) UpdateAckLevel( 333 ctx context.Context, 334 lastProcessedMessageID int64, 335 clusterName string, 336 ) error { 337 return s.NamespaceReplicationQueue.UpdateAckLevel(ctx, lastProcessedMessageID, clusterName) 338 } 339 340 // GetAckLevels returns replication queue ack levels 341 func (s *TestBase) GetAckLevels( 342 ctx context.Context, 343 ) (map[string]int64, error) { 344 return s.NamespaceReplicationQueue.GetAckLevels(ctx) 345 } 346 347 // PublishToNamespaceDLQ is a utility method to add messages to the namespace DLQ 348 func (s *TestBase) PublishToNamespaceDLQ(ctx context.Context, task *replicationspb.ReplicationTask) error { 349 retryPolicy := backoff.NewExponentialRetryPolicy(100 * time.Millisecond). 350 WithBackoffCoefficient(1.5). 351 WithMaximumAttempts(5) 352 353 return backoff.ThrottleRetryContext( 354 ctx, 355 func(ctx context.Context) error { 356 return s.NamespaceReplicationQueue.PublishToDLQ(ctx, task) 357 }, 358 retryPolicy, 359 func(e error) bool { 360 return common.IsPersistenceTransientError(e) || isMessageIDConflictError(e) 361 }) 362 } 363 364 // GetMessagesFromNamespaceDLQ is a utility method to get messages from the namespace DLQ 365 func (s *TestBase) GetMessagesFromNamespaceDLQ( 366 ctx context.Context, 367 firstMessageID int64, 368 lastMessageID int64, 369 pageSize int, 370 pageToken []byte, 371 ) ([]*replicationspb.ReplicationTask, []byte, error) { 372 return s.NamespaceReplicationQueue.GetMessagesFromDLQ( 373 ctx, 374 firstMessageID, 375 lastMessageID, 376 pageSize, 377 pageToken, 378 ) 379 } 380 381 // UpdateNamespaceDLQAckLevel updates namespace dlq ack level 382 func (s *TestBase) UpdateNamespaceDLQAckLevel( 383 ctx context.Context, 384 lastProcessedMessageID int64, 385 ) error { 386 return s.NamespaceReplicationQueue.UpdateDLQAckLevel(ctx, lastProcessedMessageID) 387 } 388 389 // GetNamespaceDLQAckLevel returns namespace dlq ack level 390 func (s *TestBase) GetNamespaceDLQAckLevel( 391 ctx context.Context, 392 ) (int64, error) { 393 return s.NamespaceReplicationQueue.GetDLQAckLevel(ctx) 394 } 395 396 // DeleteMessageFromNamespaceDLQ deletes one message from namespace DLQ 397 func (s *TestBase) DeleteMessageFromNamespaceDLQ( 398 ctx context.Context, 399 messageID int64, 400 ) error { 401 return s.NamespaceReplicationQueue.DeleteMessageFromDLQ(ctx, messageID) 402 } 403 404 // RangeDeleteMessagesFromNamespaceDLQ deletes messages from namespace DLQ 405 func (s *TestBase) RangeDeleteMessagesFromNamespaceDLQ( 406 ctx context.Context, 407 firstMessageID int64, 408 lastMessageID int64, 409 ) error { 410 return s.NamespaceReplicationQueue.RangeDeleteMessagesFromDLQ(ctx, firstMessageID, lastMessageID) 411 } 412 413 func randString(length int) string { 414 const lowercaseSet = "abcdefghijklmnopqrstuvwxyz" 415 b := make([]byte, length) 416 for i := range b { 417 b[i] = lowercaseSet[rand.Int63()%int64(len(lowercaseSet))] 418 } 419 return string(b) 420 } 421 422 // GenerateRandomDBName helper 423 // Format: MMDDHHMMSS_abc 424 func GenerateRandomDBName(n int) string { 425 var prefix strings.Builder 426 prefix.WriteString(time.Now().UTC().Format("0102150405")) 427 prefix.WriteRune('_') 428 prefix.WriteString(randString(n)) 429 return prefix.String() 430 } 431 432 func timeComparator(t1, t2 time.Time, timeTolerance time.Duration) bool { 433 diff := t2.Sub(t1) 434 return diff.Nanoseconds() <= timeTolerance.Nanoseconds() 435 }