go.temporal.io/server@v1.23.0/common/persistence/persistence-tests/persistence_test_base.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package persistencetests
    26  
import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"strings"
	"sync/atomic"
	"time"

	"github.com/stretchr/testify/suite"
	persistencespb "go.temporal.io/server/api/persistence/v1"
	replicationspb "go.temporal.io/server/api/replication/v1"
	"go.temporal.io/server/common"
	"go.temporal.io/server/common/backoff"
	"go.temporal.io/server/common/clock"
	"go.temporal.io/server/common/cluster"
	"go.temporal.io/server/common/config"
	"go.temporal.io/server/common/dynamicconfig"
	"go.temporal.io/server/common/log"
	"go.temporal.io/server/common/log/tag"
	"go.temporal.io/server/common/metrics"
	"go.temporal.io/server/common/persistence"
	"go.temporal.io/server/common/persistence/cassandra"
	"go.temporal.io/server/common/persistence/client"
	"go.temporal.io/server/common/persistence/serialization"
	"go.temporal.io/server/common/persistence/sql"
	"go.temporal.io/server/common/persistence/sql/sqlplugin/mysql"
	"go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql"
	"go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite"
	"go.temporal.io/server/common/persistence/visibility"
	"go.temporal.io/server/common/quotas"
	"go.temporal.io/server/common/resolver"
	"go.temporal.io/server/common/searchattribute"
	"go.temporal.io/server/environment"
)
    61  
    62  // TimePrecision is needed to account for database timestamp precision.
    63  // Cassandra only provides milliseconds timestamp precision, so we need to use tolerance when doing comparison
    64  const TimePrecision = 2 * time.Millisecond
    65  
    66  type (
    67  	// TransferTaskIDGenerator generates IDs for transfer tasks written by helper methods
    68  	TransferTaskIDGenerator interface {
    69  		GenerateTransferTaskID() (int64, error)
    70  	}
    71  
    72  	// TestBaseOptions options to configure workflow test base.
    73  	TestBaseOptions struct {
    74  		SQLDBPluginName   string
    75  		DBName            string
    76  		DBUsername        string
    77  		DBPassword        string
    78  		DBHost            string
    79  		DBPort            int `yaml:"-"`
    80  		ConnectAttributes map[string]string
    81  		StoreType         string                 `yaml:"-"`
    82  		SchemaDir         string                 `yaml:"-"`
    83  		FaultInjection    *config.FaultInjection `yaml:"faultinjection"`
    84  		Logger            log.Logger             `yaml:"-"`
    85  	}
    86  
    87  	// TestBase wraps the base setup needed to create workflows over persistence layer.
    88  	TestBase struct {
    89  		suite.Suite
    90  		ShardMgr                  persistence.ShardManager
    91  		AbstractDataStoreFactory  client.AbstractDataStoreFactory
    92  		VisibilityStoreFactory    visibility.VisibilityStoreFactory
    93  		FaultInjection            *client.FaultInjectionDataStoreFactory
    94  		Factory                   client.Factory
    95  		ExecutionManager          persistence.ExecutionManager
    96  		TaskMgr                   persistence.TaskManager
    97  		ClusterMetadataManager    persistence.ClusterMetadataManager
    98  		MetadataManager           persistence.MetadataManager
    99  		NamespaceReplicationQueue persistence.NamespaceReplicationQueue
   100  		ShardInfo                 *persistencespb.ShardInfo
   101  		TaskIDGenerator           TransferTaskIDGenerator
   102  		ClusterMetadata           cluster.Metadata
   103  		SearchAttributesManager   searchattribute.Manager
   104  		PersistenceRateLimiter    quotas.RequestRateLimiter
   105  		PersistenceHealthSignals  persistence.HealthSignalAggregator
   106  		ReadLevel                 int64
   107  		ReplicationReadLevel      int64
   108  		DefaultTestCluster        PersistenceTestCluster
   109  		Logger                    log.Logger
   110  	}
   111  
   112  	// PersistenceTestCluster exposes management operations on a database
   113  	PersistenceTestCluster interface {
   114  		SetupTestDatabase()
   115  		TearDownTestDatabase()
   116  		Config() config.Persistence
   117  	}
   118  
   119  	// TestTransferTaskIDGenerator helper
   120  	TestTransferTaskIDGenerator struct {
   121  		seqNum int64
   122  	}
   123  )
   124  
   125  // NewTestBaseWithCassandra returns a persistence test base backed by cassandra datastore
   126  func NewTestBaseWithCassandra(options *TestBaseOptions) *TestBase {
   127  	logger := log.NewTestLogger()
   128  	testCluster := NewTestClusterForCassandra(options, logger)
   129  	return NewTestBaseForCluster(testCluster, logger)
   130  }
   131  
   132  func NewTestClusterForCassandra(options *TestBaseOptions, logger log.Logger) *cassandra.TestCluster {
   133  	if options.DBName == "" {
   134  		options.DBName = "test_" + GenerateRandomDBName(3)
   135  	}
   136  	testCluster := cassandra.NewTestCluster(options.DBName, options.DBUsername, options.DBPassword, options.DBHost, options.DBPort, options.SchemaDir, options.FaultInjection, logger)
   137  	return testCluster
   138  }
   139  
   140  // NewTestBaseWithSQL returns a new persistence test base backed by SQL
   141  func NewTestBaseWithSQL(options *TestBaseOptions) *TestBase {
   142  	if options.DBName == "" {
   143  		options.DBName = "test_" + GenerateRandomDBName(3)
   144  	}
   145  	logger := options.Logger
   146  	if logger == nil {
   147  		logger = log.NewTestLogger()
   148  	}
   149  
   150  	if options.DBPort == 0 {
   151  		switch options.SQLDBPluginName {
   152  		case mysql.PluginName, mysql.PluginNameV8:
   153  			options.DBPort = environment.GetMySQLPort()
   154  		case postgresql.PluginName, postgresql.PluginNamePGX, postgresql.PluginNameV12, postgresql.PluginNameV12PGX:
   155  			options.DBPort = environment.GetPostgreSQLPort()
   156  		case sqlite.PluginName:
   157  			options.DBPort = 0
   158  		default:
   159  			panic(fmt.Sprintf("unknown sql store driver: %v", options.SQLDBPluginName))
   160  		}
   161  	}
   162  	if options.DBHost == "" {
   163  		switch options.SQLDBPluginName {
   164  		case mysql.PluginName, mysql.PluginNameV8:
   165  			options.DBHost = environment.GetMySQLAddress()
   166  		case postgresql.PluginName, postgresql.PluginNamePGX:
   167  			options.DBHost = environment.GetPostgreSQLAddress()
   168  		case sqlite.PluginName:
   169  			options.DBHost = environment.GetLocalhostIP()
   170  		default:
   171  			panic(fmt.Sprintf("unknown sql store driver: %v", options.SQLDBPluginName))
   172  		}
   173  	}
   174  	testCluster := sql.NewTestCluster(options.SQLDBPluginName, options.DBName, options.DBUsername, options.DBPassword, options.DBHost, options.DBPort, options.ConnectAttributes, options.SchemaDir, options.FaultInjection, logger)
   175  	return NewTestBaseForCluster(testCluster, logger)
   176  }
   177  
   178  // NewTestBase returns a persistence test base backed by either cassandra or sql
   179  func NewTestBase(options *TestBaseOptions) *TestBase {
   180  	switch options.StoreType {
   181  	case config.StoreTypeSQL:
   182  		return NewTestBaseWithSQL(options)
   183  	case config.StoreTypeNoSQL:
   184  		return NewTestBaseWithCassandra(options)
   185  	default:
   186  		panic("invalid storeType " + options.StoreType)
   187  	}
   188  }
   189  
   190  func NewTestBaseForCluster(testCluster PersistenceTestCluster, logger log.Logger) *TestBase {
   191  	return &TestBase{
   192  		DefaultTestCluster: testCluster,
   193  		Logger:             logger,
   194  	}
   195  }
   196  
// Setup sets up the test base, must be called as part of SetupSuite.
//
// It sets up the database through s.DefaultTestCluster, builds a persistence
// client.Factory from that cluster's config, and populates s with:
// TaskMgr, ClusterMetadataManager, ClusterMetadata, SearchAttributesManager,
// MetadataManager, ShardMgr, ExecutionManager, Factory, FaultInjection,
// ShardInfo, TaskIDGenerator and NamespaceReplicationQueue. It also resets
// ReadLevel/ReplicationReadLevel to 0 and calls GetOrCreateShard for shard 10.
//
// If clusterMetadataConfig is nil, cluster.NewTestClusterMetadataConfig(false,
// false) is used. If s.PersistenceHealthSignals is nil, it is set to
// persistence.NoopHealthSignalAggregator. s.AbstractDataStoreFactory,
// s.PersistenceRateLimiter and s.Logger are read from s as-is.
//
// Every construction error is passed to fatalOnError, which logs it via
// s.Logger.Fatal.
func (s *TestBase) Setup(clusterMetadataConfig *cluster.Config) {
	var err error
	// Fixed shard ID used both for s.ShardInfo and for the GetOrCreateShard call below.
	shardID := int32(10)
	if clusterMetadataConfig == nil {
		clusterMetadataConfig = cluster.NewTestClusterMetadataConfig(false, false)
	}
	if s.PersistenceHealthSignals == nil {
		s.PersistenceHealthSignals = persistence.NoopHealthSignalAggregator
	}

	clusterName := clusterMetadataConfig.CurrentClusterName

	// The test database is set up first; its persistence config is read afterwards.
	s.DefaultTestCluster.SetupTestDatabase()

	cfg := s.DefaultTestCluster.Config()
	dataStoreFactory, faultInjection := client.DataStoreFactoryProvider(
		client.ClusterName(clusterName),
		resolver.NewNoopResolver(),
		&cfg,
		s.AbstractDataStoreFactory,
		s.Logger,
		metrics.NoopMetricsHandler,
	)
	factory := client.NewFactory(dataStoreFactory, &cfg, s.PersistenceRateLimiter, serialization.NewSerializer(), nil, clusterName, metrics.NoopMetricsHandler, s.Logger, s.PersistenceHealthSignals)

	s.TaskMgr, err = factory.NewTaskManager()
	s.fatalOnError("NewTaskManager", err)

	s.ClusterMetadataManager, err = factory.NewClusterMetadataManager()
	s.fatalOnError("NewClusterMetadataManager", err)

	// Both of these are built on top of the ClusterMetadataManager created just above.
	s.ClusterMetadata = cluster.NewMetadataFromConfig(clusterMetadataConfig, s.ClusterMetadataManager, dynamicconfig.NewNoopCollection(), s.Logger)
	s.SearchAttributesManager = searchattribute.NewManager(clock.NewRealTimeSource(), s.ClusterMetadataManager, dynamicconfig.GetBoolPropertyFn(true))

	s.MetadataManager, err = factory.NewMetadataManager()
	s.fatalOnError("NewMetadataManager", err)

	s.ShardMgr, err = factory.NewShardManager()
	s.fatalOnError("NewShardManager", err)

	s.ExecutionManager, err = factory.NewExecutionManager()
	s.fatalOnError("NewExecutionManager", err)

	s.Factory = factory
	s.FaultInjection = faultInjection

	s.ReadLevel = 0
	s.ReplicationReadLevel = 0
	s.ShardInfo = &persistencespb.ShardInfo{
		ShardId: shardID,
		RangeId: 0,
	}

	s.TaskIDGenerator = &TestTransferTaskIDGenerator{}
	// s.ShardInfo is supplied as the initial shard info; presumably an existing
	// shard with this ID is returned instead of being recreated — confirm
	// against ShardManager.GetOrCreateShard.
	_, err = s.ShardMgr.GetOrCreateShard(context.Background(), &persistence.GetOrCreateShardRequest{
		ShardID:          shardID,
		InitialShardInfo: s.ShardInfo,
	})
	s.fatalOnError("CreateShard", err)

	queue, err := factory.NewNamespaceReplicationQueue()
	s.fatalOnError("Create NamespaceReplicationQueue", err)
	s.NamespaceReplicationQueue = queue
}
   262  
   263  func (s *TestBase) fatalOnError(msg string, err error) {
   264  	if err != nil {
   265  		s.Logger.Fatal(msg, tag.Error(err))
   266  	}
   267  }
   268  
   269  // TearDownWorkflowStore to cleanup
   270  func (s *TestBase) TearDownWorkflowStore() {
   271  	s.TaskMgr.Close()
   272  	s.ClusterMetadataManager.Close()
   273  	s.MetadataManager.Close()
   274  	s.ExecutionManager.Close()
   275  	s.ShardMgr.Close()
   276  	s.ExecutionManager.Close()
   277  	s.NamespaceReplicationQueue.Stop()
   278  	s.Factory.Close()
   279  	s.DefaultTestCluster.TearDownTestDatabase()
   280  }
   281  
   282  // EqualTimesWithPrecision assertion that two times are equal within precision
   283  func (s *TestBase) EqualTimesWithPrecision(t1, t2 time.Time, precision time.Duration) {
   284  	s.True(timeComparator(t1, t2, precision),
   285  		"Not equal: \n"+
   286  			"expected: %s\n"+
   287  			"actual  : %s%s", t1, t2,
   288  	)
   289  }
   290  
   291  // EqualTimes assertion that two times are equal within two millisecond precision
   292  func (s *TestBase) EqualTimes(t1, t2 time.Time) {
   293  	s.EqualTimesWithPrecision(t1, t2, TimePrecision)
   294  }
   295  
   296  // GenerateTransferTaskID helper
   297  func (g *TestTransferTaskIDGenerator) GenerateTransferTaskID() (int64, error) {
   298  	return atomic.AddInt64(&g.seqNum, 1), nil
   299  }
   300  
   301  // Publish is a utility method to add messages to the queue
   302  func (s *TestBase) Publish(ctx context.Context, task *replicationspb.ReplicationTask) error {
   303  	retryPolicy := backoff.NewExponentialRetryPolicy(100 * time.Millisecond).
   304  		WithBackoffCoefficient(1.5).
   305  		WithMaximumAttempts(5)
   306  
   307  	return backoff.ThrottleRetry(
   308  		func() error {
   309  			return s.NamespaceReplicationQueue.Publish(ctx, task)
   310  		},
   311  		retryPolicy,
   312  		func(e error) bool {
   313  			return common.IsPersistenceTransientError(e) || isMessageIDConflictError(e)
   314  		})
   315  }
   316  
   317  func isMessageIDConflictError(err error) bool {
   318  	_, ok := err.(*persistence.ConditionFailedError)
   319  	return ok
   320  }
   321  
   322  // GetReplicationMessages is a utility method to get messages from the queue
   323  func (s *TestBase) GetReplicationMessages(
   324  	ctx context.Context,
   325  	lastMessageID int64,
   326  	pageSize int,
   327  ) ([]*replicationspb.ReplicationTask, int64, error) {
   328  	return s.NamespaceReplicationQueue.GetReplicationMessages(ctx, lastMessageID, pageSize)
   329  }
   330  
   331  // UpdateAckLevel updates replication queue ack level
   332  func (s *TestBase) UpdateAckLevel(
   333  	ctx context.Context,
   334  	lastProcessedMessageID int64,
   335  	clusterName string,
   336  ) error {
   337  	return s.NamespaceReplicationQueue.UpdateAckLevel(ctx, lastProcessedMessageID, clusterName)
   338  }
   339  
   340  // GetAckLevels returns replication queue ack levels
   341  func (s *TestBase) GetAckLevels(
   342  	ctx context.Context,
   343  ) (map[string]int64, error) {
   344  	return s.NamespaceReplicationQueue.GetAckLevels(ctx)
   345  }
   346  
   347  // PublishToNamespaceDLQ is a utility method to add messages to the namespace DLQ
   348  func (s *TestBase) PublishToNamespaceDLQ(ctx context.Context, task *replicationspb.ReplicationTask) error {
   349  	retryPolicy := backoff.NewExponentialRetryPolicy(100 * time.Millisecond).
   350  		WithBackoffCoefficient(1.5).
   351  		WithMaximumAttempts(5)
   352  
   353  	return backoff.ThrottleRetryContext(
   354  		ctx,
   355  		func(ctx context.Context) error {
   356  			return s.NamespaceReplicationQueue.PublishToDLQ(ctx, task)
   357  		},
   358  		retryPolicy,
   359  		func(e error) bool {
   360  			return common.IsPersistenceTransientError(e) || isMessageIDConflictError(e)
   361  		})
   362  }
   363  
   364  // GetMessagesFromNamespaceDLQ is a utility method to get messages from the namespace DLQ
   365  func (s *TestBase) GetMessagesFromNamespaceDLQ(
   366  	ctx context.Context,
   367  	firstMessageID int64,
   368  	lastMessageID int64,
   369  	pageSize int,
   370  	pageToken []byte,
   371  ) ([]*replicationspb.ReplicationTask, []byte, error) {
   372  	return s.NamespaceReplicationQueue.GetMessagesFromDLQ(
   373  		ctx,
   374  		firstMessageID,
   375  		lastMessageID,
   376  		pageSize,
   377  		pageToken,
   378  	)
   379  }
   380  
   381  // UpdateNamespaceDLQAckLevel updates namespace dlq ack level
   382  func (s *TestBase) UpdateNamespaceDLQAckLevel(
   383  	ctx context.Context,
   384  	lastProcessedMessageID int64,
   385  ) error {
   386  	return s.NamespaceReplicationQueue.UpdateDLQAckLevel(ctx, lastProcessedMessageID)
   387  }
   388  
   389  // GetNamespaceDLQAckLevel returns namespace dlq ack level
   390  func (s *TestBase) GetNamespaceDLQAckLevel(
   391  	ctx context.Context,
   392  ) (int64, error) {
   393  	return s.NamespaceReplicationQueue.GetDLQAckLevel(ctx)
   394  }
   395  
   396  // DeleteMessageFromNamespaceDLQ deletes one message from namespace DLQ
   397  func (s *TestBase) DeleteMessageFromNamespaceDLQ(
   398  	ctx context.Context,
   399  	messageID int64,
   400  ) error {
   401  	return s.NamespaceReplicationQueue.DeleteMessageFromDLQ(ctx, messageID)
   402  }
   403  
   404  // RangeDeleteMessagesFromNamespaceDLQ deletes messages from namespace DLQ
   405  func (s *TestBase) RangeDeleteMessagesFromNamespaceDLQ(
   406  	ctx context.Context,
   407  	firstMessageID int64,
   408  	lastMessageID int64,
   409  ) error {
   410  	return s.NamespaceReplicationQueue.RangeDeleteMessagesFromDLQ(ctx, firstMessageID, lastMessageID)
   411  }
   412  
   413  func randString(length int) string {
   414  	const lowercaseSet = "abcdefghijklmnopqrstuvwxyz"
   415  	b := make([]byte, length)
   416  	for i := range b {
   417  		b[i] = lowercaseSet[rand.Int63()%int64(len(lowercaseSet))]
   418  	}
   419  	return string(b)
   420  }
   421  
   422  // GenerateRandomDBName helper
   423  // Format: MMDDHHMMSS_abc
   424  func GenerateRandomDBName(n int) string {
   425  	var prefix strings.Builder
   426  	prefix.WriteString(time.Now().UTC().Format("0102150405"))
   427  	prefix.WriteRune('_')
   428  	prefix.WriteString(randString(n))
   429  	return prefix.String()
   430  }
   431  
   432  func timeComparator(t1, t2 time.Time, timeTolerance time.Duration) bool {
   433  	diff := t2.Sub(t1)
   434  	return diff.Nanoseconds() <= timeTolerance.Nanoseconds()
   435  }