code.gitea.io/gitea@v1.22.3/models/unittest/fixtures.go (about)

     1  // Copyright 2021 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  //nolint:forbidigo
     5  package unittest
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"time"
    11  
    12  	"code.gitea.io/gitea/models/db"
    13  	"code.gitea.io/gitea/modules/auth/password/hash"
    14  	"code.gitea.io/gitea/modules/setting"
    15  
    16  	"github.com/go-testfixtures/testfixtures/v3"
    17  	"xorm.io/xorm"
    18  	"xorm.io/xorm/schemas"
    19  )
    20  
    21  var fixturesLoader *testfixtures.Loader
    22  
    23  // GetXORMEngine gets the XORM engine
    24  func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
    25  	if len(engine) == 1 {
    26  		return engine[0]
    27  	}
    28  	return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
    29  }
    30  
    31  // InitFixtures initialize test fixtures for a test database
    32  func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
    33  	e := GetXORMEngine(engine...)
    34  	var fixtureOptionFiles func(*testfixtures.Loader) error
    35  	if opts.Dir != "" {
    36  		fixtureOptionFiles = testfixtures.Directory(opts.Dir)
    37  	} else {
    38  		fixtureOptionFiles = testfixtures.Files(opts.Files...)
    39  	}
    40  	dialect := "unknown"
    41  	switch e.Dialect().URI().DBType {
    42  	case schemas.POSTGRES:
    43  		dialect = "postgres"
    44  	case schemas.MYSQL:
    45  		dialect = "mysql"
    46  	case schemas.MSSQL:
    47  		dialect = "mssql"
    48  	case schemas.SQLITE:
    49  		dialect = "sqlite3"
    50  	default:
    51  		fmt.Println("Unsupported RDBMS for integration tests")
    52  		os.Exit(1)
    53  	}
    54  	loaderOptions := []func(loader *testfixtures.Loader) error{
    55  		testfixtures.Database(e.DB().DB),
    56  		testfixtures.Dialect(dialect),
    57  		testfixtures.DangerousSkipTestDatabaseCheck(),
    58  		fixtureOptionFiles,
    59  	}
    60  
    61  	if e.Dialect().URI().DBType == schemas.POSTGRES {
    62  		loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
    63  	}
    64  
    65  	fixturesLoader, err = testfixtures.New(loaderOptions...)
    66  	if err != nil {
    67  		return err
    68  	}
    69  
    70  	// register the dummy hash algorithm function used in the test fixtures
    71  	_ = hash.Register("dummy", hash.NewDummyHasher)
    72  
    73  	setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
    74  
    75  	return err
    76  }
    77  
    78  // LoadFixtures load fixtures for a test database
    79  func LoadFixtures(engine ...*xorm.Engine) error {
    80  	e := GetXORMEngine(engine...)
    81  	var err error
    82  	// (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times.
    83  	for i := 0; i < 5; i++ {
    84  		if err = fixturesLoader.Load(); err == nil {
    85  			break
    86  		}
    87  		time.Sleep(200 * time.Millisecond)
    88  	}
    89  	if err != nil {
    90  		fmt.Printf("LoadFixtures failed after retries: %v\n", err)
    91  	}
    92  	// Now if we're running postgres we need to tell it to update the sequences
    93  	if e.Dialect().URI().DBType == schemas.POSTGRES {
    94  		results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
    95  		quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) ||
    96  		', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' ||
    97  		quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';'
    98  	 FROM pg_class AS S,
    99  	      pg_depend AS D,
   100  	      pg_class AS T,
   101  	      pg_attribute AS C,
   102  	      pg_tables AS PGT
   103  	 WHERE S.relkind = 'S'
   104  	     AND S.oid = D.objid
   105  	     AND D.refobjid = T.oid
   106  	     AND D.refobjid = C.attrelid
   107  	     AND D.refobjsubid = C.attnum
   108  	     AND T.relname = PGT.tablename
   109  	 ORDER BY S.relname;`)
   110  		if err != nil {
   111  			fmt.Printf("Failed to generate sequence update: %v\n", err)
   112  			return err
   113  		}
   114  		for _, r := range results {
   115  			for _, value := range r {
   116  				_, err = e.Exec(value)
   117  				if err != nil {
   118  					fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err)
   119  					return err
   120  				}
   121  			}
   122  		}
   123  	}
   124  	_ = hash.Register("dummy", hash.NewDummyHasher)
   125  	setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
   126  
   127  	return err
   128  }