github.com/erda-project/erda-infra@v1.0.9/providers/mysqlxorm/sqlite3/interface_test.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlite3
    16  
    17  import (
    18  	"errors"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/xormplus/core"
    21  	"os"
    22  	"path/filepath"
    23  	"testing"
    24  
    25  	"github.com/erda-project/erda-infra/providers/mysqlxorm"
    26  )
    27  
    28  const dbSourceName = "test1.sqlite3"
    29  
    30  type Server struct {
    31  	mysql mysqlxorm.Interface
    32  }
    33  
    34  type User struct {
    35  	ID   uint64 `json:"id" xorm:"pk autoincr"`
    36  	Name string `json:"name"`
    37  }
    38  
    39  func (u *User) TableName() string {
    40  	return "user"
    41  }
    42  
    43  func (s *Server) GetUserByID(id uint64, ops ...mysqlxorm.SessionOption) (*User, error) {
    44  	session := s.mysql.NewSession(ops...)
    45  	defer session.Close()
    46  
    47  	var user User
    48  	_, err := session.ID(id).Get(&user)
    49  
    50  	return &user, err
    51  }
    52  
    53  func (s *Server) CreateUser(user *User, ops ...mysqlxorm.SessionOption) (err error) {
    54  	session := s.mysql.NewSession(ops...)
    55  	defer session.Close()
    56  
    57  	_, err = session.Insert(user)
    58  	return err
    59  }
    60  
    61  func (s *Server) TestTx(err error, ops ...mysqlxorm.SessionOption) error {
    62  	session := s.mysql.NewSession(ops...)
    63  	defer session.Close()
    64  	return err
    65  }
    66  
    67  func TestNewSqlite3(t *testing.T) {
    68  	dbname := filepath.Join(os.TempDir(), dbSourceName)
    69  	defer func() {
    70  		os.Remove(dbname)
    71  	}()
    72  	engine, err := NewSqlite3(dbname)
    73  	if err != nil {
    74  		t.Fatalf("new sqlite3 err : %s", err)
    75  	}
    76  
    77  	server := Server{
    78  		mysql: engine,
    79  	}
    80  
    81  	server.mysql.DB().SetMapper(core.GonicMapper{})
    82  	server.mysql.DB().Sync2(&User{})
    83  
    84  	testCase := []struct {
    85  		name       string
    86  		insertUser []User
    87  		txErr      error
    88  		want       []User
    89  	}{
    90  		{
    91  			name:  "test tx",
    92  			txErr: errors.New("tx error"),
    93  			insertUser: []User{
    94  				{ID: 4, Name: "Alice"},
    95  				{ID: 5, Name: "Bob"},
    96  				{ID: 6, Name: "Cat"},
    97  			},
    98  			want: []User{},
    99  		},
   100  		{
   101  			name: "sqlite3 use for xorm",
   102  			insertUser: []User{
   103  				{ID: 1, Name: "Alice"},
   104  				{ID: 2, Name: "Bob"},
   105  				{ID: 3, Name: "Cat"},
   106  			},
   107  			txErr: nil,
   108  			want: []User{
   109  				{ID: 1, Name: "Alice"},
   110  				{ID: 2, Name: "Bob"},
   111  				{ID: 3, Name: "Cat"},
   112  			},
   113  		},
   114  	}
   115  
   116  	for _, test := range testCase {
   117  		t.Run(test.name, func(t *testing.T) {
   118  			tx := server.mysql.NewSession()
   119  			defer tx.Close()
   120  			if err = tx.Begin(); err != nil {
   121  				t.Fatalf("tx begin err : %s", err)
   122  			}
   123  
   124  			ops := mysqlxorm.WithSession(tx)
   125  			// insert sql
   126  			for _, user := range test.insertUser {
   127  				err = server.CreateUser(&user, ops)
   128  				if err != nil {
   129  					tx.Rollback()
   130  					t.Fatalf("create user err : %s", err)
   131  				}
   132  			}
   133  
   134  			err = server.TestTx(test.txErr, ops)
   135  			if err != nil {
   136  				tx.Rollback()
   137  			} else {
   138  				tx.Commit()
   139  			}
   140  
   141  			if len(test.want) <= 0 {
   142  				for _, user := range test.insertUser {
   143  					u, err := server.GetUserByID(user.ID)
   144  					if err != nil {
   145  						t.Fatalf("get user err : %s", err)
   146  					}
   147  					assert.Equal(t, &User{}, u)
   148  				}
   149  				return
   150  			}
   151  
   152  			for _, user := range test.want {
   153  				u, err := server.GetUserByID(user.ID)
   154  				if err != nil {
   155  					t.Fatalf("get user err : %s", err)
   156  				}
   157  				assert.Equal(t, user.Name, u.Name)
   158  			}
   159  		})
   160  	}
   161  }