go.temporal.io/server@v1.23.0/common/persistence/sql/test_sql_persistence.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package sql
    26  
    27  import (
    28  	"fmt"
    29  	"os"
    30  	"path"
    31  	"strings"
    32  	"time"
    33  
    34  	"go.temporal.io/server/common"
    35  	"go.temporal.io/server/common/backoff"
    36  	"go.temporal.io/server/common/config"
    37  	"go.temporal.io/server/common/dynamicconfig"
    38  	"go.temporal.io/server/common/log"
    39  	"go.temporal.io/server/common/log/tag"
    40  	p "go.temporal.io/server/common/persistence"
    41  	"go.temporal.io/server/common/persistence/sql/sqlplugin"
    42  	"go.temporal.io/server/common/resolver"
    43  	"go.temporal.io/server/tests/testutils"
    44  )
    45  
    46  // TestCluster allows executing cassandra operations in testing.
    47  type TestCluster struct {
    48  	dbName         string
    49  	schemaDir      string
    50  	cfg            config.SQL
    51  	faultInjection *config.FaultInjection
    52  	logger         log.Logger
    53  }
    54  
    55  // NewTestCluster returns a new SQL test cluster
    56  func NewTestCluster(
    57  	pluginName string,
    58  	dbName string,
    59  	username string,
    60  	password string,
    61  	host string,
    62  	port int,
    63  	connectAttributes map[string]string,
    64  	schemaDir string,
    65  	faultInjection *config.FaultInjection,
    66  	logger log.Logger,
    67  ) *TestCluster {
    68  	var result TestCluster
    69  	result.logger = logger
    70  	result.dbName = dbName
    71  
    72  	result.schemaDir = schemaDir
    73  	result.cfg = config.SQL{
    74  		User:               username,
    75  		Password:           password,
    76  		ConnectAddr:        fmt.Sprintf("%v:%v", host, port),
    77  		ConnectProtocol:    "tcp",
    78  		PluginName:         pluginName,
    79  		DatabaseName:       dbName,
    80  		TaskScanPartitions: 4,
    81  		ConnectAttributes:  connectAttributes,
    82  	}
    83  
    84  	result.faultInjection = faultInjection
    85  	return &result
    86  }
    87  
    88  // DatabaseName from PersistenceTestCluster interface
    89  func (s *TestCluster) DatabaseName() string {
    90  	return s.dbName
    91  }
    92  
    93  // SetupTestDatabase from PersistenceTestCluster interface
    94  func (s *TestCluster) SetupTestDatabase() {
    95  	s.CreateDatabase()
    96  
    97  	if s.schemaDir == "" {
    98  		s.logger.Info("No schema directory provided, skipping schema setup")
    99  		return
   100  	}
   101  
   102  	schemaDir := s.schemaDir + "/"
   103  	if !strings.HasPrefix(schemaDir, "/") && !strings.HasPrefix(schemaDir, "../") {
   104  		temporalPackageDir := testutils.GetRepoRootDirectory()
   105  		schemaDir = path.Join(temporalPackageDir, schemaDir)
   106  	}
   107  	s.LoadSchema(path.Join(schemaDir, "temporal", "schema.sql"))
   108  	s.LoadSchema(path.Join(schemaDir, "visibility", "schema.sql"))
   109  }
   110  
   111  // Config returns the persistence config for connecting to this test cluster
   112  func (s *TestCluster) Config() config.Persistence {
   113  	cfg := s.cfg
   114  	return config.Persistence{
   115  		DefaultStore:    "test",
   116  		VisibilityStore: "test",
   117  		DataStores: map[string]config.DataStore{
   118  			"test": {SQL: &cfg, FaultInjection: s.faultInjection},
   119  		},
   120  		TransactionSizeLimit: dynamicconfig.GetIntPropertyFn(common.DefaultTransactionSizeLimit),
   121  	}
   122  }
   123  
   124  // TearDownTestDatabase from PersistenceTestCluster interface
   125  func (s *TestCluster) TearDownTestDatabase() {
   126  	s.DropDatabase()
   127  }
   128  
   129  // CreateDatabase from PersistenceTestCluster interface
   130  func (s *TestCluster) CreateDatabase() {
   131  	cfg2 := s.cfg
   132  	// NOTE need to connect with empty name to create new database
   133  	if cfg2.PluginName != "sqlite" {
   134  		cfg2.DatabaseName = ""
   135  	}
   136  
   137  	var db sqlplugin.AdminDB
   138  	var err error
   139  	err = backoff.ThrottleRetry(
   140  		func() error {
   141  			db, err = NewSQLAdminDB(sqlplugin.DbKindUnknown, &cfg2, resolver.NewNoopResolver())
   142  			return err
   143  		},
   144  		backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute),
   145  		nil,
   146  	)
   147  	if err != nil {
   148  		panic(err)
   149  	}
   150  	defer func() {
   151  		err := db.Close()
   152  		if err != nil {
   153  			panic(err)
   154  		}
   155  	}()
   156  	err = db.CreateDatabase(s.cfg.DatabaseName)
   157  	if err != nil {
   158  		panic(err)
   159  	}
   160  	s.logger.Info("created database", tag.NewStringTag("database", s.cfg.DatabaseName))
   161  }
   162  
   163  // DropDatabase from PersistenceTestCluster interface
   164  func (s *TestCluster) DropDatabase() {
   165  	cfg2 := s.cfg
   166  
   167  	if cfg2.PluginName == "sqlite" && cfg2.DatabaseName != ":memory:" && cfg2.ConnectAttributes["mode"] != "memory" {
   168  		if len(cfg2.DatabaseName) > 3 { // 3 should mean not ., .., empty, or /
   169  			err := os.Remove(cfg2.DatabaseName)
   170  			if err != nil {
   171  				panic(err)
   172  			}
   173  		}
   174  		return
   175  	}
   176  
   177  	// NOTE need to connect with empty name to drop the database
   178  	cfg2.DatabaseName = ""
   179  	db, err := NewSQLAdminDB(sqlplugin.DbKindUnknown, &cfg2, resolver.NewNoopResolver())
   180  	if err != nil {
   181  		panic(err)
   182  	}
   183  	defer func() {
   184  		err := db.Close()
   185  		if err != nil {
   186  			panic(err)
   187  		}
   188  	}()
   189  	err = db.DropDatabase(s.cfg.DatabaseName)
   190  	if err != nil {
   191  		panic(err)
   192  	}
   193  	s.logger.Info("dropped database", tag.NewStringTag("database", s.cfg.DatabaseName))
   194  }
   195  
   196  // LoadSchema from PersistenceTestCluster interface
   197  func (s *TestCluster) LoadSchema(schemaFile string) {
   198  	statements, err := p.LoadAndSplitQuery([]string{schemaFile})
   199  	if err != nil {
   200  		s.logger.Fatal("LoadSchema", tag.Error(err))
   201  	}
   202  
   203  	var db sqlplugin.AdminDB
   204  	err = backoff.ThrottleRetry(
   205  		func() error {
   206  			db, err = NewSQLAdminDB(sqlplugin.DbKindUnknown, &s.cfg, resolver.NewNoopResolver())
   207  			return err
   208  		},
   209  		backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute),
   210  		nil,
   211  	)
   212  	if err != nil {
   213  		panic(err)
   214  	}
   215  	defer func() {
   216  		err := db.Close()
   217  		if err != nil {
   218  			panic(err)
   219  		}
   220  	}()
   221  
   222  	for _, stmt := range statements {
   223  		if err = db.Exec(stmt); err != nil {
   224  			s.logger.Fatal("LoadSchema", tag.Error(err))
   225  		}
   226  	}
   227  	s.logger.Info("loaded schema")
   228  }