github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/dao_test.go (about)

     1  package sqx_test
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"strconv"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/bingoohuang/gg/pkg/sqx"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
// person mirrors a row of the person table (columns id and age).
// Field order matters: the tests build values with positional literals such
// as person{"100", 100}.
type person struct {
	ID  string
	Age int
}
    20  
// personDao declares every operation TestDao performs on the person table.
// sqx.CreateDao fills in each func field from the SQL in its sql tag: :id and
// :age bind the fields of a person argument, :1 binds the first positional
// argument. As exercised by TestDao, Find returns the zero person and Find2
// returns nil when no row matches, Delete returns the number of deleted rows,
// and the integer results of AddAll2/AddAll3/AddAll4 are treated as the last
// insert id.
type personDao struct {
	CreateTable func()                         `sql:"create table person(id varchar(100), age int)"`
	Add         func(person)                   `sql:"insert into person(id, age) values(:id, :age)"`
	AddAll      func(...person)                `sql:"insert into person(id, age) values(:id, :age)"`
	AddAll2     func(...person) uint64         `sql:"insert into person(id, age) values(:id, :age)"`
	AddAll3     func(...person) int            `sql:"insert into person(id, age) values(:id, :age)"`
	AddAll4     func(...person) (int, error)   `sql:"insert into person(id, age) values(:id, :age)"`
	Find        func(id string) person         `sql:"select id, age from person where id=:1"`
	Find2       func(id string) *person        `sql:"select id, age from person where id=:1"`
	ListAll     func() []person                `sql:"select id, age from person"`
	ListByID    func(string) []person          `sql:"select id, age from person where id=:1"`
	Delete      func(string) int               `sql:"delete from person where id=:1"`
	GetAge      func(string) struct{ Age int } `sql:"select age from person where id=:1"`
}
    36  
    37  func TestDao(t *testing.T) {
    38  	that := assert.New(t)
    39  
    40  	// 生成DAO,自动创建dao结构体中的函数字段
    41  	dao := &personDao{}
    42  	rawDB, _ := openDB(t)
    43  	sqx.DB = rawDB
    44  	that.Nil(sqx.CreateDao(dao))
    45  
    46  	// 建表
    47  	dao.CreateTable()
    48  	// 插入
    49  	dao.Add(person{"100", 100})
    50  	// 查找
    51  	that.Equal(person{"100", 100}, dao.Find("100"))
    52  	that.Equal(person{"100", 100}, *dao.Find2("100"))
    53  	// 刪除
    54  	that.Equal(1, dao.Delete("100"))
    55  	// 再找,找不到,返回零值
    56  	that.Zero(dao.Find("100"))
    57  	// 再找,找不到,返回零值
    58  	v := dao.Find2("100")
    59  	that.Nil(v)
    60  	// 插入
    61  	dao.Add(person{"200", 200})
    62  	// 多值插入
    63  	dao.AddAll(person{"300", 300}, person{"400", 400})
    64  	// 列表
    65  	that.Equal([]person{{"200", 200}, {"300", 300}, {"400", 400}}, dao.ListAll())
    66  	// 条件列表
    67  	that.Equal([]person{{"200", 200}}, dao.ListByID("200"))
    68  	// 匿名结构
    69  	that.Equal(struct{ Age int }{Age: 200}, dao.GetAge("200"))
    70  
    71  	// 多值插入
    72  	lastInsertID2 := dao.AddAll2(person{"500", 500}, person{"600", 600})
    73  	assert.True(t, lastInsertID2 > 0)
    74  
    75  	lastInsertID3 := dao.AddAll3(person{"700", 500}, person{"701", 600})
    76  	assert.True(t, lastInsertID3 > 0)
    77  
    78  	lastInsertID4, err := dao.AddAll4(person{"900", 500}, person{"702", 600})
    79  	assert.True(t, lastInsertID4 > 0)
    80  	assert.Nil(t, err)
    81  }
    82  
    83  func openDB(t *testing.T) (*sql.DB, *sqx.Sqx) {
    84  	// 创建数据库连接池
    85  	rawDb, db, err := sqx.Open("sqlite3", ":memory:")
    86  	assert.Nil(t, err)
    87  
    88  	return rawDb, db
    89  }
    90  
// personDao2 declares the person-table DAO used by the error, context and
// row-scan-interceptor tests.
type personDao2 struct {
	CreateTable func()          `sql:"create table person(id varchar(100), age int)"`
	AddAll      func(...person) `sql:"insert into person(id, age) values(:id, :age)"`

	// GetAgeE and GetAgeX filter on xid, a column the table does not have, so
	// they always fail; TestDaoWithError relies on that.
	GetAgeE func(string) (struct{ Age int }, error) `sql:"select age from person where xid=:1"`
	GetAgeX func(string) person                     `sql:"select age from person where xid=:1"`

	Err error // receives call errors separately from the results; TestDaoWithError shows a successful call resets it

	ListAll func() []person `sql:"select id, age from person order by id"`
}
   103  
   104  func TestDaoWithError(t *testing.T) {
   105  	that := assert.New(t)
   106  
   107  	// 生成DAO,自动创建dao结构体中的函数字段
   108  	dao := &personDao2{}
   109  
   110  	var err error
   111  
   112  	sqx.DB, _ = openDB(t)
   113  	that.Nil(sqx.CreateDao(dao, sqx.WithError(&err)))
   114  
   115  	dao.CreateTable()
   116  
   117  	that.Nil(dao.Err)
   118  	ageX, err := dao.GetAgeE("200")
   119  	that.Error(err)
   120  	that.Zero(ageX)
   121  	that.Error(dao.Err)
   122  
   123  	// 条件列表
   124  	dao.AddAll(person{"200", 200})
   125  	that.Nil(dao.Err) // 验证Err字段是否重置
   126  
   127  	that.Zero(dao.GetAgeX("100"))
   128  	that.Error(err)
   129  }
   130  
   131  func TestDaoWithContext(t *testing.T) {
   132  	that := assert.New(t)
   133  
   134  	// 生成DAO,自动创建dao结构体中的函数字段
   135  	dao := &personDao2{}
   136  	// Pass a context with a timeout to tell a blocking function that it
   137  	// should abandon its work after the timeout elapses.
   138  	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
   139  	defer cancel()
   140  
   141  	sqx.DB, _ = openDB(t)
   142  	that.Nil(sqx.CreateDao(dao, sqx.WithCtx(ctx), sqx.WithLimit(1)))
   143  
   144  	dao.CreateTable()
   145  
   146  	// 多值插入
   147  	dao.AddAll(person{"300", 300}, person{"400", 400})
   148  
   149  	peoples := dao.ListAll()
   150  	that.Len(peoples, 1)
   151  }
   152  
   153  func TestDaoWithRowScanInterceptor(t *testing.T) {
   154  	that := assert.New(t)
   155  
   156  	// 生成DAO,自动创建dao结构体中的函数字段
   157  	dao := &personDao2{}
   158  	p := person{}
   159  	f := sqx.RowScanInterceptorFn(func(rowIndex int, v ...interface{}) (bool, error) {
   160  		p = v[0].(person)
   161  		return false, nil
   162  	})
   163  	sqx.DB, _ = openDB(t)
   164  	that.Nil(sqx.CreateDao(dao, sqx.WithRowScanInterceptor(f)))
   165  
   166  	dao.CreateTable()
   167  
   168  	// 多值插入
   169  	dao.AddAll(person{"300", 300}, person{"400", 400})
   170  
   171  	peoples := dao.ListAll()
   172  	that.Len(peoples, 0)
   173  	that.Equal(person{"300", 300}, p)
   174  }
   175  
// personDao3 has no sql tags. Its queries come from a dot-SQL source (dotSQL
// via sqx.WithSQLStr, or testdata/d3.sql via sqx.WithSQLFile) whose
// "-- name:" sections match the field names. Logger is optional: one test sets
// it to &sqx.DaoLog{}, the other leaves it nil.
type personDao3 struct {
	CreateTable func()
	AddAll      func(...person)
	ListAll     func() []person
	Logger      sqx.DaoLogger
}
   183  
// dotSQL holds the named queries for personDao3; each "-- name: X" header
// introduces the SQL for func field X. The ListAll entry declares
// "delimiter: /", and its statement is written as /.../ followed by an extra
// '/' instead of ending in ';'.
// NOTE(review): confirm the exact delimiter semantics in sqx before editing.
const dotSQL = `
-- name: CreateTable
create table person(id varchar(100), age int);

-- name: AddAll
insert into person(id, age) values(:id, :age);

-- name: ListAll delimiter: /
/select id, age from person order by id//
`
   194  
   195  func TestDaoWithDotSQLString(t *testing.T) {
   196  	that := assert.New(t)
   197  
   198  	// 生成DAO,自动创建dao结构体中的函数字段
   199  	dao := &personDao3{Logger: &sqx.DaoLog{}}
   200  	sqx.DB, _ = openDB(t)
   201  	that.Nil(sqx.CreateDao(dao, sqx.WithSQLStr(dotSQL)))
   202  
   203  	dao.CreateTable()
   204  
   205  	// 多值插入
   206  	dao.AddAll(person{"300", 300}, person{"400", 400})
   207  
   208  	// 列表
   209  	that.Equal([]person{{"300", 300}, {"400", 400}}, dao.ListAll())
   210  }
   211  
   212  func TestDaoWithDotSQLFile(t *testing.T) {
   213  	that := assert.New(t)
   214  
   215  	// 生成DAO,自动创建dao结构体中的函数字段
   216  	dao := &personDao3{}
   217  	sqx.DB, _ = openDB(t)
   218  	that.Nil(sqx.CreateDao(dao, sqx.WithSQLFile(`testdata/d3.sql`)))
   219  
   220  	dao.CreateTable()
   221  
   222  	// 多值插入
   223  	dao.AddAll(person{"300", 300}, person{"400", 400})
   224  
   225  	// 列表
   226  	that.Equal([]person{{"300", 300}, {"400", 400}}, dao.ListAll())
   227  }
   228  
// person4 scans a person row with ID as sql.NullString, so a NULL id shows up
// as Valid == false (see TestNullString).
type person4 struct {
	ID  sql.NullString
	Age int
}

// person5 scans the same row with a plain string ID (TestNullString expects ""
// for the NULL id), an Other field that no selected column fills, and Addr as
// sql.NullString.
type person5 struct {
	ID    string
	Age   int
	Other string
	Addr  sql.NullString
}

// Addr is a named string type, used to check that string-kinded named types
// can be scanned into and returned.
type Addr string

// person6 receives only the addr column of the row, through the named type Addr.
type person6 struct {
	Addr Addr
}
   249  
// personDao4 declares the DAO used by TestNullString. AddID inserts a row
// without an id (so id is NULL) and with addr 'zags'. FindByAge2 and
// FindByAge3 reuse FindByAge1's query via the sqlName tag, scanning the same
// row into different struct shapes.
type personDao4 struct {
	CreateTable func()            `sql:"create table person(id varchar(100), age int, addr varchar(10) default 'bj')"`
	AddID       func(int)         `sql:"insert into person(age, addr) values(:1, 'zags')"`
	Add         func(M)           `sql:"insert into person(id, age, addr) values(:id, :age, :addr)"`
	FindByAge1  func(int) person4 `sql:"select id, age, addr from person where age = :1"`
	FindByAge2  func(int) person5 `sqlName:"FindByAge1"`
	FindByAge3  func(int) person6 `sqlName:"FindByAge1"`
}
   259  
   260  func TestNullString(t *testing.T) {
   261  	that := assert.New(t)
   262  
   263  	// 生成DAO,自动创建dao结构体中的函数字段
   264  	dao := &personDao4{}
   265  	sqx.DB, _ = openDB(t)
   266  	that.Nil(sqx.CreateDao(dao))
   267  
   268  	dao.CreateTable()
   269  	dao.AddID(100)
   270  	p1 := dao.FindByAge1(100)
   271  	that.Equal(person4{ID: sql.NullString{Valid: false}, Age: 100}, p1)
   272  
   273  	p2 := dao.FindByAge2(100)
   274  	that.Equal(person5{ID: "", Age: 100, Addr: sql.NullString{String: "zags", Valid: true}}, p2)
   275  
   276  	p3 := dao.FindByAge3(100)
   277  	that.Equal(person6{Addr: "zags"}, p3)
   278  }
   279  
// personMap is a person row including addr. TestMapDynamic also passes it as a
// struct argument whose fields supply :id and :age.
type personMap struct {
	ID   string
	Age  int
	Addr string
}
   286  
// dotSQLMap holds the named queries for personDaoMap. Count and Get2 omit the
// trailing ';', which TestMapArg shows is accepted.
const dotSQLMap = `
-- name: CreateTable
create table person(id varchar(100), age int, addr varchar(10));

-- name: Find
select id, age, addr from person where id = :1;

-- name: Add
insert into person(id, age, addr) values(:id, :age, :addr);

-- name: Count
select count(*) from person

-- name: Get2
select id, addr from person where id=:1
`
   303  
// personDaoMap declares the DAO used by TestMapArg. Untagged fields take their
// SQL from dotSQLMap by name, GetAddr has its own sql tag, and fields tagged
// sqlName reuse the SQL of the named query. It covers map arguments plus
// struct, map, scalar and multi-value results. Error receives call errors;
// TestMapArg expects sql.ErrNoRows there after a lookup that matches no row.
type personDaoMap struct {
	CreateTable func()
	Add         func(m M)
	Find        func(id string) personMap
	FindAsMap1  func(id string) map[string]string `sqlName:"Find"`
	FindAsMap2  func(id string) M                 `sqlName:"Find"`

	Count    func() (count int)
	GetAddr  func(id string) (addr string) `sql:"select addr from person where id=:1"`
	GetAddr2 func(id string) (addr Addr)   `sqlName:"GetAddr"`
	Get2     func(id string) (pid, addr string)
	Get3     func(id string) (pid string)                     `sqlName:"Get2"`
	Get4     func(id string) (pid, addr, nos string, noi int) `sqlName:"Get2"`

	Logger sqx.DaoLogger
	Error  error
}
   322  
// M is a loosely typed map used both as an argument, where its keys supply
// :name placeholders such as :id, and as a row result (see TestMapArg).
type M map[string]interface{}
   324  
   325  func TestMapArg(t *testing.T) {
   326  	that := assert.New(t)
   327  
   328  	// 生成DAO,自动创建dao结构体中的函数字段
   329  	dao := &personDaoMap{Logger: &sqx.DaoLog{}}
   330  	sqx.DB, _ = openDB(t)
   331  	that.Nil(sqx.CreateDao(dao, sqx.WithSQLStr(dotSQLMap)))
   332  
   333  	dao.CreateTable()
   334  	dao.Add(M{"id": "40685", "age": 500, "addr": "bjca"})
   335  	dao.Add(M{"id": "40686", "age": 600, "addr": nil})
   336  
   337  	that.Equal(personMap{ID: "40685", Age: 500, Addr: "bjca"}, dao.Find("40685"))
   338  	that.Equal(personMap{ID: "40686", Age: 600, Addr: ""}, dao.Find("40686"))
   339  
   340  	that.Equal(map[string]string{"id": "40685", "age": "500", "addr": "bjca"}, dao.FindAsMap1("40685"))
   341  	that.Equal(M{"id": "40685", "age": int64(500), "addr": "bjca"}, dao.FindAsMap2("40685"))
   342  
   343  	that.Equal(2, dao.Count())
   344  	that.Equal("bjca", dao.GetAddr("40685"))
   345  	that.Equal(Addr("bjca"), dao.GetAddr2("40685"))
   346  
   347  	id, addr := dao.Get2("40685")
   348  	that.Equal("40685", id)
   349  	that.Equal("bjca", addr)
   350  	that.Equal("40685", dao.Get3("40685"))
   351  
   352  	{
   353  		id, addr, other, otherInt := dao.Get4("40685")
   354  		that.Equal("40685", id)
   355  		that.Equal("bjca", addr)
   356  		that.Empty(other)
   357  		that.Zero(otherInt)
   358  	}
   359  
   360  	{
   361  		id, addr, other, otherInt := dao.Get4("50685")
   362  		that.Empty(id)
   363  		that.Empty(addr)
   364  		that.Empty(other)
   365  		that.Zero(otherInt)
   366  		that.Equal(sql.ErrNoRows, dao.Error)
   367  	}
   368  }
   369  
// dotSQLDynamic holds conditional queries for personDaoDynamic. A fragment
// wrapped in /* if COND */ ... /* end */ (inline) or "-- if COND" ... "-- end"
// (separate lines) is applied only when COND holds. TestMapDynamic shows the
// age filter is used for age 600 and dropped for age 0.
// NOTE(review): _2 presumably names the second positional argument (it guards
// :2); confirm in sqx. Trailing spaces inside the string are left as-is.
const dotSQLDynamic = `
-- name: CreateTable
create table person(id varchar(100), age int, addr varchar(10));

-- name: Add
insert into person(id, age, addr) values(:id, :age, :addr);

-- name: GetAddr
select addr from person where id = :1 /* if _2 > 0 */ and age = :2 /* end */;

-- name: GetAddr2
select addr from person where id = :1
-- if _2 > 0 
and age = :2 
-- end
;

-- name: GetAddrStruct
select addr from person where id = :id /* if age > 0 */ and age = :age /* end */;

-- name: GetAddrStruct2
select addr from person where id = :id
-- if age > 0 
and age = :age
-- end
;
`
   397  
// personDaoDynamic declares the DAO used by TestMapDynamic; untagged fields
// take their SQL from dotSQLDynamic by name. The GetAddr* variants take
// positional arguments and the GetAddrStruct* variants take a personMap.
// GetAddrStruct3 reuses GetAddrStruct2's query but returns every match as a slice.
type personDaoDynamic struct {
	CreateTable func()
	Add         func(m M)
	GetAddr     func(id string, age int) (addr string)
	GetAddr2    func(id string, age int) (addr string)

	GetAddrStruct  func(p personMap) (addr string)
	GetAddrStruct2 func(p personMap) (addr string)
	GetAddrStruct3 func(p personMap) (addresses []string) `sqlName:"GetAddrStruct2"`

	Logger sqx.DaoLogger
	Error  error
}
   412  
   413  func TestMapDynamic(t *testing.T) {
   414  	that := assert.New(t)
   415  
   416  	// 生成DAO,自动创建dao结构体中的函数字段
   417  	dao := &personDaoDynamic{Logger: &sqx.DaoLog{}}
   418  	sqx.DB, _ = openDB(t)
   419  	that.Nil(sqx.CreateDao(dao, sqx.WithSQLStr(dotSQLDynamic)))
   420  
   421  	dao.CreateTable()
   422  	dao.Add(M{"id": "40685", "age": 500, "addr": "bjca"})
   423  	dao.Add(M{"id": "40685", "age": 600, "addr": "acjb"})
   424  
   425  	that.Equal("bjca", dao.GetAddr("40685", 0))
   426  	that.Equal("acjb", dao.GetAddr("40685", 600))
   427  
   428  	that.Equal("bjca", dao.GetAddrStruct(personMap{ID: "40685"}))
   429  	that.Equal("acjb", dao.GetAddrStruct(personMap{ID: "40685", Age: 600}))
   430  
   431  	that.Equal("bjca", dao.GetAddr2("40685", 0))
   432  	that.Equal("acjb", dao.GetAddr2("40685", 600))
   433  
   434  	that.Equal("bjca", dao.GetAddrStruct2(personMap{ID: "40685"}))
   435  	that.Equal("acjb", dao.GetAddrStruct2(personMap{ID: "40685", Age: 600}))
   436  
   437  	that.Equal([]string{"bjca", "acjb"}, dao.GetAddrStruct3(personMap{ID: "40685"}))
   438  	that.Equal([]string{"acjb"}, dao.GetAddrStruct3(personMap{ID: "40685", Age: 600}))
   439  }
   440  
// dotSQLLastInsertID holds the queries for personLastInsertID. id is an
// INTEGER PRIMARY KEY AUTOINCREMENT column, so Add can report the generated id.
// NOTE(review): the Update entry is followed by a stray ';' line; it is
// presumably ignored by the parser, but confirm before tidying it up.
const dotSQLLastInsertID = `
-- name: CreateTable
create table person(id INTEGER PRIMARY KEY AUTOINCREMENT, age int, addr varchar(10));

-- name: Add
insert into person( age, addr) values( :age, :addr);

-- name: Update
update person set age = :age, addr = :addr where id = :id;
;
`
   452  
// personLastInsertID declares the DAO used by TestLastInsertID. The test
// expects Add to return the new row's id, Update to return the number of
// affected rows, and Update2 (reusing the Update query via sqlName) to return
// both the affected-row count and the last insert id.
type personLastInsertID struct {
	CreateTable func()
	Add         func(m M) (lastInsertID int)
	Update      func(m M) (effectedRows int)
	Update2     func(m M) (effectedRows, lastInsertID int) `sqlName:"Update"`
	Logger      sqx.DaoLogger
	Error       error
}
   461  
   462  func TestLastInsertID(t *testing.T) {
   463  	that := assert.New(t)
   464  
   465  	// 生成DAO,自动创建dao结构体中的函数字段
   466  	dao := &personLastInsertID{Logger: &sqx.DaoLog{}}
   467  	db, _ := openDB(t)
   468  	that.Nil(sqx.CreateDao(dao, sqx.WithDB(db), sqx.WithSQLStr(dotSQLLastInsertID)))
   469  
   470  	dao.CreateTable()
   471  	that.True(dao.Add(M{"age": 500, "addr": "bjca"}) > 0)
   472  	lastInsertID := dao.Add(M{"age": 600, "addr": "acjb"})
   473  	that.True(lastInsertID > 0)
   474  	that.Equal(1, dao.Update(M{"id": lastInsertID, "age": 601, "addr": "xx"}))
   475  	effectedRows, insertID := dao.Update2(M{"id": lastInsertID, "age": 602, "addr": "yy"})
   476  	that.Equal(1, effectedRows)
   477  	that.Equal(lastInsertID, insertID)
   478  }
   479  
   480  type MyTime time.Time
   481  
   482  // Marshaler is the interface implemented by types that
   483  // can marshal themselves into valid JSON.
   484  func (m MyTime) MarshalJSON() ([]byte, error) {
   485  	s := (time.Time(m)).Format("2006-01-02 15:04:05")
   486  	return []byte(strconv.Quote(s)), nil
   487  }
   488  
   489  // Value - Implementation of valuer for database/sql.
   490  func (m MyTime) Value() (driver.Value, error) {
   491  	return driver.Value(time.Time(m)), nil
   492  }
   493  
// dotSQLTime holds the queries for personTimeDao, including a timestamp
// birthday column used to round-trip MyTime values.
// NOTE(review): Delete is written as id = ':' followed by ';;', yet TestTime
// expects it to delete exactly one row for the given id, so sqx presumably
// treats ':' specially here. This looks deliberate; confirm before changing it.
const dotSQLTime = `
-- name: CreateTable
create table person(id INTEGER PRIMARY KEY AUTOINCREMENT, age int, addr varchar(10), birthday timestamp);

-- name: Add
insert into person( age, addr, birthday ) values( :age, :addr, :birthday);

-- name: Find
select id, age, addr, birthday from person where id = :1;

-- name: Delete
delete from person  where id = ':';;
`
   507  
// personTime is a person row with an auto-increment ID and a MyTime birthday
// (see TestTime).
type personTime struct {
	ID       uint64
	Age      int
	Addr     string
	Birthday MyTime
}
   514  
// queryCond is a condition struct for personTimeDao.Query; the sql tag holds
// its SQL fragment. NOTE(review): presumably only non-zero fields contribute
// their fragment (TestTime queries with both an empty and a populated value);
// confirm in sqx.
type queryCond struct {
	Addr string `sql:"addr like ?"`
}
// queryCond2 extends queryCond with a sqx.Limit whose tag supplies the
// "limit ?,?" fragment (used by personTimeDao.Query2).
type queryCond2 struct {
	Addr  string    `sql:"addr like ?"`
	Limit sqx.Limit `sql:"limit ?,?"`
}
   522  
// personTimeDao declares the DAO used by TestTime; untagged fields take their
// SQL from dotSQLTime by name.
type personTimeDao struct {
	CreateTable func()
	Add         func(m personTime) (lastInsertID uint64)
	Find        func(uint64) (personTime, error)
	Delete      func(uint64) (effectedRows int)

	// Query by condition struct, on top of the base SELECT in the sql tag.
	// Query2 also returns a sqx.Count. NOTE(review): presumably the number of
	// matching rows (TestTime expects 1); confirm against sqx.
	Query  func(queryCond) ([]personTime, error)             `sql:"select id, age, addr, birthday from person"`
	Query2 func(queryCond2) ([]personTime, sqx.Count, error) `sql:"select id, age, addr, birthday from person"`

	Logger sqx.DaoLogger
	Error  error
}
   536  
   537  func TestTime(t *testing.T) {
   538  	that := assert.New(t)
   539  
   540  	// 生成DAO,自动创建dao结构体中的函数字段
   541  	dao := &personTimeDao{Logger: &sqx.DaoLog{}}
   542  	db, _ := openDB(t)
   543  	that.Nil(sqx.CreateDao(dao, sqx.WithDB(db), sqx.WithSQLStr(dotSQLTime)))
   544  
   545  	now, _ := time.Parse("2006-01-02 15:04:05", "2020-03-31 17:17:40")
   546  	dao.CreateTable()
   547  
   548  	p := personTime{
   549  		Age:      10,
   550  		Addr:     "aa",
   551  		Birthday: MyTime(now),
   552  	}
   553  
   554  	lastInsertID := dao.Add(p)
   555  	p.ID = lastInsertID
   556  
   557  	p2, err := dao.Find(lastInsertID)
   558  	that.Equal(p, p2)
   559  	that.Nil(err)
   560  
   561  	persons, err := dao.Query(queryCond{})
   562  	that.Nil(err)
   563  	that.Equal([]personTime{p}, persons)
   564  
   565  	persons, err = dao.Query(queryCond{Addr: "%a%"})
   566  	that.Nil(err)
   567  	that.Equal([]personTime{p}, persons)
   568  
   569  	var count sqx.Count
   570  	persons, count, err = dao.Query2(queryCond2{Addr: "%a%", Limit: sqx.Limit{Length: 10}})
   571  	that.Nil(err)
   572  	that.Equal([]personTime{p}, persons)
   573  	that.Equal(sqx.Count(1), count)
   574  
   575  	effectedRows := dao.Delete(lastInsertID)
   576  	that.Equal(1, effectedRows)
   577  }