github.com/erda-project/erda-infra@v1.0.10-0.20240327085753-f3a249292aeb/providers/mysqlxorm/sqlite3/interface_test.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlite3
    16  
    17  import (
    18  	"errors"
    19  	"os"
    20  	"path/filepath"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"xorm.io/xorm/names"
    26  
    27  	"github.com/erda-project/erda-infra/providers/mysqlxorm"
    28  )
    29  
    30  const dbSourceName = "test1-*.db"
    31  
    32  type Server struct {
    33  	mysql mysqlxorm.Interface
    34  }
    35  
    36  type User struct {
    37  	ID   uint64 `json:"id" xorm:"pk autoincr"`
    38  	Name string `json:"name"`
    39  }
    40  
    41  func (u *User) TableName() string {
    42  	return "user"
    43  }
    44  
    45  func (s *Server) GetUserByID(id uint64, ops ...mysqlxorm.SessionOption) (*User, error) {
    46  	session := s.mysql.NewSession(ops...)
    47  	defer session.Close()
    48  
    49  	var user User
    50  	_, err := session.ID(id).Get(&user)
    51  
    52  	return &user, err
    53  }
    54  
    55  func (s *Server) CreateUser(user *User, ops ...mysqlxorm.SessionOption) (err error) {
    56  	session := s.mysql.NewSession(ops...)
    57  	defer session.Close()
    58  
    59  	_, err = session.Insert(user)
    60  	return err
    61  }
    62  
    63  func (s *Server) TestTx(err error, ops ...mysqlxorm.SessionOption) error {
    64  	session := s.mysql.NewSession(ops...)
    65  	defer session.Close()
    66  	return err
    67  }
    68  
    69  func TestNewSqlite3(t *testing.T) {
    70  	dbname := filepath.Join(os.TempDir(), dbSourceName)
    71  	engine, err := NewSqlite3(dbname)
    72  	if err != nil {
    73  		t.Fatalf("new sqlite3 err : %s", err)
    74  	}
    75  
    76  	defer engine.Close()
    77  
    78  	server := Server{
    79  		mysql: engine,
    80  	}
    81  
    82  	server.mysql.DB().SetMapper(names.GonicMapper{})
    83  	server.mysql.DB().Sync2(&User{})
    84  
    85  	testCase := []struct {
    86  		name       string
    87  		insertUser []User
    88  		txErr      error
    89  		want       []User
    90  	}{
    91  		{
    92  			name:  "test tx",
    93  			txErr: errors.New("tx error"),
    94  			insertUser: []User{
    95  				{ID: 4, Name: "Alice"},
    96  				{ID: 5, Name: "Bob"},
    97  				{ID: 6, Name: "Cat"},
    98  			},
    99  			want: []User{},
   100  		},
   101  		{
   102  			name: "sqlite3 use for xorm",
   103  			insertUser: []User{
   104  				{ID: 1, Name: "Alice"},
   105  				{ID: 2, Name: "Bob"},
   106  				{ID: 3, Name: "Cat"},
   107  			},
   108  			txErr: nil,
   109  			want: []User{
   110  				{ID: 1, Name: "Alice"},
   111  				{ID: 2, Name: "Bob"},
   112  				{ID: 3, Name: "Cat"},
   113  			},
   114  		},
   115  	}
   116  
   117  	for _, test := range testCase {
   118  		t.Run(test.name, func(t *testing.T) {
   119  			tx := server.mysql.NewSession()
   120  			defer tx.Close()
   121  			if err = tx.Begin(); err != nil {
   122  				t.Fatalf("tx begin err : %s", err)
   123  			}
   124  
   125  			ops := mysqlxorm.WithSession(tx)
   126  			// insert sql
   127  			for _, user := range test.insertUser {
   128  				err = server.CreateUser(&user, ops)
   129  				if err != nil {
   130  					tx.Rollback()
   131  					t.Fatalf("create user err : %s", err)
   132  				}
   133  			}
   134  
   135  			err = server.TestTx(test.txErr, ops)
   136  			if err != nil {
   137  				tx.Rollback()
   138  			} else {
   139  				tx.Commit()
   140  			}
   141  
   142  			if len(test.want) <= 0 {
   143  				for _, user := range test.insertUser {
   144  					u, err := server.GetUserByID(user.ID)
   145  					if err != nil {
   146  						t.Fatalf("get user err : %s", err)
   147  					}
   148  					assert.Equal(t, &User{}, u)
   149  				}
   150  				return
   151  			}
   152  
   153  			for _, user := range test.want {
   154  				u, err := server.GetUserByID(user.ID)
   155  				if err != nil {
   156  					t.Fatalf("get user err : %s", err)
   157  				}
   158  				assert.Equal(t, user.Name, u.Name)
   159  			}
   160  		})
   161  	}
   162  }
   163  
   164  func TestJournalMode(t *testing.T) {
   165  	dbname := filepath.Join(os.TempDir(), dbSourceName)
   166  
   167  	want := []JournalMode{
   168  		MEMORY,
   169  		DELETE,
   170  		PERSIST,
   171  		OFF,
   172  		WAL,
   173  		TRUNCATE,
   174  	}
   175  	defer func() {
   176  		os.Remove(dbname)
   177  	}()
   178  
   179  	for _, w := range want {
   180  		engine, err := NewSqlite3(dbname, WithJournalMode(w))
   181  		if err != nil {
   182  			t.Fatalf("new sqlite3 err : %s", err)
   183  		}
   184  
   185  		// get journal in sqlite3
   186  		results, _ := engine.DB().Query("PRAGMA journal_mode;")
   187  		assert.Equal(t, string(w), string(results[0]["journal_mode"]))
   188  		engine.Close()
   189  	}
   190  
   191  }
   192  
   193  func TestRandomName(t *testing.T) {
   194  	path := filepath.Join(os.TempDir(), "sample-*.txt")
   195  	name1, err := randomName(path)
   196  	if err != nil {
   197  		t.Error(err)
   198  	}
   199  	name2, err := randomName(path)
   200  	if err != nil {
   201  		t.Error(err)
   202  	}
   203  
   204  	assert.True(t, strings.HasPrefix(name1, filepath.Join(os.TempDir(), "sample-")), "Random name does not start with original name")
   205  
   206  	assert.Equal(t, filepath.Ext(name1), ".txt", "Random name does not have original extension")
   207  
   208  	assert.NotEqual(t, name1, name2, "Random name generator produced the same result twice")
   209  }
   210  
   211  func TestWithRandomName(t *testing.T) {
   212  	dbname := filepath.Join(os.TempDir(), dbSourceName)
   213  	engine, err := NewSqlite3(dbname, WithRandomName(false))
   214  	defer func() {
   215  		if engine != nil {
   216  			defer engine.Close()
   217  		}
   218  	}()
   219  	if err != nil {
   220  		panic(err)
   221  	}
   222  
   223  	assert.Nil(t, err)
   224  	assert.Equal(t, dbname, engine.DataSourceName())
   225  	engine.Close()
   226  
   227  	engineRandom, err := NewSqlite3(dbname, WithRandomName(true))
   228  	assert.Nil(t, err)
   229  	assert.NotEqual(t, dbname, engineRandom, "Random name is not take effect")
   230  	assert.Equal(t, filepath.Ext(engineRandom.DataSourceName()), filepath.Ext(dbname), "Random names does not have original extension")
   231  	defer engineRandom.Close()
   232  }
   233  
   234  func TestClose(t *testing.T) {
   235  	// close file
   236  	dbname := filepath.Join(os.TempDir(), dbSourceName)
   237  	engine, err := NewSqlite3(dbname, WithRandomName(true))
   238  	if err != nil {
   239  		panic(err)
   240  	}
   241  	err = engine.Close()
   242  	assert.Nil(t, err)
   243  	// check if the file exists
   244  	_, err = os.Stat(engine.DataSourceName())
   245  	assert.True(t, os.IsNotExist(err))
   246  }