github.com/mattermosttest/mattermost-server/v5@v5.0.0-20200917143240-9dfa12e121f9/store/sqlstore/store_test.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package sqlstore
     5  
     6  import (
     7  	"os"
     8  	"sync"
     9  	"testing"
    10  
    11  	"github.com/mattermost/mattermost-server/v5/model"
    12  	"github.com/mattermost/mattermost-server/v5/store"
    13  	"github.com/mattermost/mattermost-server/v5/store/searchtest"
    14  	"github.com/mattermost/mattermost-server/v5/store/storetest"
    15  )
    16  
// storeType describes one database backend (MySQL or PostgreSQL) that the
// store tests are run against.
type storeType struct {
	// Name is used as the subtest name passed to t.Run for this backend.
	Name        string
	// SqlSettings holds the settings used to open this backend's supplier in
	// initStores and handed to storetest.CleanupSqlSettings on teardown.
	SqlSettings *model.SqlSettings
	// SqlSupplier is the supplier opened by initStores; nil until then.
	SqlSupplier *SqlSupplier
	// Store is the same supplier as SqlSupplier, exposed through the
	// store.Store interface; nil until initStores has run.
	Store       store.Store
}

// storeTypes lists the backends the helpers in this file iterate over. In this
// file it is populated by initStores and left empty in -short mode.
var storeTypes []*storeType
    25  
    26  func StoreTest(t *testing.T, f func(*testing.T, store.Store)) {
    27  	defer func() {
    28  		if err := recover(); err != nil {
    29  			tearDownStores()
    30  			panic(err)
    31  		}
    32  	}()
    33  	for _, st := range storeTypes {
    34  		st := st
    35  		t.Run(st.Name, func(t *testing.T) {
    36  			if testing.Short() {
    37  				t.SkipNow()
    38  			}
    39  			f(t, st.Store)
    40  		})
    41  	}
    42  }
    43  
    44  func StoreTestWithSearchTestEngine(t *testing.T, f func(*testing.T, store.Store, *searchtest.SearchTestEngine)) {
    45  	defer func() {
    46  		if err := recover(); err != nil {
    47  			tearDownStores()
    48  			panic(err)
    49  		}
    50  	}()
    51  
    52  	for _, st := range storeTypes {
    53  		st := st
    54  		searchTestEngine := &searchtest.SearchTestEngine{
    55  			Driver: *st.SqlSettings.DriverName,
    56  		}
    57  
    58  		t.Run(st.Name, func(t *testing.T) { f(t, st.Store, searchTestEngine) })
    59  	}
    60  }
    61  
    62  func StoreTestWithSqlSupplier(t *testing.T, f func(*testing.T, store.Store, storetest.SqlSupplier)) {
    63  	defer func() {
    64  		if err := recover(); err != nil {
    65  			tearDownStores()
    66  			panic(err)
    67  		}
    68  	}()
    69  	for _, st := range storeTypes {
    70  		st := st
    71  		t.Run(st.Name, func(t *testing.T) {
    72  			if testing.Short() {
    73  				t.SkipNow()
    74  			}
    75  			f(t, st.Store, st.SqlSupplier)
    76  		})
    77  	}
    78  }
    79  
    80  func initStores() {
    81  	if testing.Short() {
    82  		return
    83  	}
    84  	// In CI, we already run the entire test suite for both mysql and postgres in parallel.
    85  	// So we just run the tests for the current database set.
    86  	if os.Getenv("IS_CI") == "true" {
    87  		switch os.Getenv("MM_SQLSETTINGS_DRIVERNAME") {
    88  		case "mysql":
    89  			storeTypes = append(storeTypes, &storeType{
    90  				Name:        "MySQL",
    91  				SqlSettings: storetest.MakeSqlSettings(model.DATABASE_DRIVER_MYSQL),
    92  			})
    93  		case "postgres":
    94  			storeTypes = append(storeTypes, &storeType{
    95  				Name:        "PostgreSQL",
    96  				SqlSettings: storetest.MakeSqlSettings(model.DATABASE_DRIVER_POSTGRES),
    97  			})
    98  		}
    99  	} else {
   100  		storeTypes = append(storeTypes, &storeType{
   101  			Name:        "MySQL",
   102  			SqlSettings: storetest.MakeSqlSettings(model.DATABASE_DRIVER_MYSQL),
   103  		})
   104  		storeTypes = append(storeTypes, &storeType{
   105  			Name:        "PostgreSQL",
   106  			SqlSettings: storetest.MakeSqlSettings(model.DATABASE_DRIVER_POSTGRES),
   107  		})
   108  	}
   109  
   110  	defer func() {
   111  		if err := recover(); err != nil {
   112  			tearDownStores()
   113  			panic(err)
   114  		}
   115  	}()
   116  	var wg sync.WaitGroup
   117  	for _, st := range storeTypes {
   118  		st := st
   119  		wg.Add(1)
   120  		go func() {
   121  			defer wg.Done()
   122  			st.SqlSupplier = NewSqlSupplier(*st.SqlSettings, nil)
   123  			st.Store = st.SqlSupplier
   124  			st.Store.DropAllTables()
   125  			st.Store.MarkSystemRanUnitTests()
   126  		}()
   127  	}
   128  	wg.Wait()
   129  }
   130  
   131  var tearDownStoresOnce sync.Once
   132  
   133  func tearDownStores() {
   134  	if testing.Short() {
   135  		return
   136  	}
   137  	tearDownStoresOnce.Do(func() {
   138  		var wg sync.WaitGroup
   139  		wg.Add(len(storeTypes))
   140  		for _, st := range storeTypes {
   141  			st := st
   142  			go func() {
   143  				if st.Store != nil {
   144  					st.Store.Close()
   145  				}
   146  				if st.SqlSettings != nil {
   147  					storetest.CleanupSqlSettings(st.SqlSettings)
   148  				}
   149  				wg.Done()
   150  			}()
   151  		}
   152  		wg.Wait()
   153  	})
   154  }