github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/dao_test.go (about) 1 package sqx_test 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "strconv" 8 "testing" 9 "time" 10 11 "github.com/bingoohuang/gg/pkg/sqx" 12 "github.com/stretchr/testify/assert" 13 ) 14 15 // person 结构体,对应到person表字段. 16 type person struct { 17 ID string 18 Age int 19 } 20 21 // personDao 定义对person表操作的所有方法. 22 type personDao struct { 23 CreateTable func() `sql:"create table person(id varchar(100), age int)"` 24 Add func(person) `sql:"insert into person(id, age) values(:id, :age)"` 25 AddAll func(...person) `sql:"insert into person(id, age) values(:id, :age)"` 26 AddAll2 func(...person) uint64 `sql:"insert into person(id, age) values(:id, :age)"` 27 AddAll3 func(...person) int `sql:"insert into person(id, age) values(:id, :age)"` 28 AddAll4 func(...person) (int, error) `sql:"insert into person(id, age) values(:id, :age)"` 29 Find func(id string) person `sql:"select id, age from person where id=:1"` 30 Find2 func(id string) *person `sql:"select id, age from person where id=:1"` 31 ListAll func() []person `sql:"select id, age from person"` 32 ListByID func(string) []person `sql:"select id, age from person where id=:1"` 33 Delete func(string) int `sql:"delete from person where id=:1"` 34 GetAge func(string) struct{ Age int } `sql:"select age from person where id=:1"` 35 } 36 37 func TestDao(t *testing.T) { 38 that := assert.New(t) 39 40 // 生成DAO,自动创建dao结构体中的函数字段 41 dao := &personDao{} 42 rawDB, _ := openDB(t) 43 sqx.DB = rawDB 44 that.Nil(sqx.CreateDao(dao)) 45 46 // 建表 47 dao.CreateTable() 48 // 插入 49 dao.Add(person{"100", 100}) 50 // 查找 51 that.Equal(person{"100", 100}, dao.Find("100")) 52 that.Equal(person{"100", 100}, *dao.Find2("100")) 53 // 刪除 54 that.Equal(1, dao.Delete("100")) 55 // 再找,找不到,返回零值 56 that.Zero(dao.Find("100")) 57 // 再找,找不到,返回零值 58 v := dao.Find2("100") 59 that.Nil(v) 60 // 插入 61 dao.Add(person{"200", 200}) 62 // 多值插入 63 dao.AddAll(person{"300", 300}, person{"400", 400}) 64 // 列表 65 that.Equal([]person{{"200", 200}, {"300", 300}, {"400", 400}}, dao.ListAll()) 66 // 条件列表 67 that.Equal([]person{{"200", 200}}, dao.ListByID("200")) 68 // 匿名结构 69 that.Equal(struct{ Age int }{Age: 200}, dao.GetAge("200")) 70 71 // 多值插入 72 lastInsertID2 := dao.AddAll2(person{"500", 500}, person{"600", 600}) 73 assert.True(t, lastInsertID2 > 0) 74 75 lastInsertID3 := dao.AddAll3(person{"700", 500}, person{"701", 600}) 76 assert.True(t, lastInsertID3 > 0) 77 78 lastInsertID4, err := dao.AddAll4(person{"900", 500}, person{"702", 600}) 79 assert.True(t, lastInsertID4 > 0) 80 assert.Nil(t, err) 81 } 82 83 func openDB(t *testing.T) (*sql.DB, *sqx.Sqx) { 84 // 创建数据库连接池 85 rawDb, db, err := sqx.Open("sqlite3", ":memory:") 86 assert.Nil(t, err) 87 88 return rawDb, db 89 } 90 91 // personDao2 定义对person表操作的所有方法. 92 type personDao2 struct { 93 CreateTable func() `sql:"create table person(id varchar(100), age int)"` 94 AddAll func(...person) `sql:"insert into person(id, age) values(:id, :age)"` 95 96 GetAgeE func(string) (struct{ Age int }, error) `sql:"select age from person where xid=:1"` 97 GetAgeX func(string) person `sql:"select age from person where xid=:1"` 98 99 Err error // 添加这个字段,可以用来单独接收error信息 100 101 ListAll func() []person `sql:"select id, age from person order by id"` 102 } 103 104 func TestDaoWithError(t *testing.T) { 105 that := assert.New(t) 106 107 // 生成DAO,自动创建dao结构体中的函数字段 108 dao := &personDao2{} 109 110 var err error 111 112 sqx.DB, _ = openDB(t) 113 that.Nil(sqx.CreateDao(dao, sqx.WithError(&err))) 114 115 dao.CreateTable() 116 117 that.Nil(dao.Err) 118 ageX, err := dao.GetAgeE("200") 119 that.Error(err) 120 that.Zero(ageX) 121 that.Error(dao.Err) 122 123 // 条件列表 124 dao.AddAll(person{"200", 200}) 125 that.Nil(dao.Err) // 验证Err字段是否重置 126 127 that.Zero(dao.GetAgeX("100")) 128 that.Error(err) 129 } 130 131 func TestDaoWithContext(t *testing.T) { 132 that := assert.New(t) 133 134 // 生成DAO,自动创建dao结构体中的函数字段 135 dao := &personDao2{} 136 // Pass a context with a timeout to tell a blocking function that it 137 // should abandon its work after the timeout elapses. 138 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 139 defer cancel() 140 141 sqx.DB, _ = openDB(t) 142 that.Nil(sqx.CreateDao(dao, sqx.WithCtx(ctx), sqx.WithLimit(1))) 143 144 dao.CreateTable() 145 146 // 多值插入 147 dao.AddAll(person{"300", 300}, person{"400", 400}) 148 149 peoples := dao.ListAll() 150 that.Len(peoples, 1) 151 } 152 153 func TestDaoWithRowScanInterceptor(t *testing.T) { 154 that := assert.New(t) 155 156 // 生成DAO,自动创建dao结构体中的函数字段 157 dao := &personDao2{} 158 p := person{} 159 f := sqx.RowScanInterceptorFn(func(rowIndex int, v ...interface{}) (bool, error) { 160 p = v[0].(person) 161 return false, nil 162 }) 163 sqx.DB, _ = openDB(t) 164 that.Nil(sqx.CreateDao(dao, sqx.WithRowScanInterceptor(f))) 165 166 dao.CreateTable() 167 168 // 多值插入 169 dao.AddAll(person{"300", 300}, person{"400", 400}) 170 171 peoples := dao.ListAll() 172 that.Len(peoples, 0) 173 that.Equal(person{"300", 300}, p) 174 } 175 176 // personDao3 定义对person表操作的所有方法. 177 type personDao3 struct { 178 CreateTable func() 179 AddAll func(...person) 180 ListAll func() []person 181 Logger sqx.DaoLogger 182 } 183 184 const dotSQL = ` 185 -- name: CreateTable 186 create table person(id varchar(100), age int); 187 188 -- name: AddAll 189 insert into person(id, age) values(:id, :age); 190 191 -- name: ListAll delimiter: / 192 /select id, age from person order by id// 193 ` 194 195 func TestDaoWithDotSQLString(t *testing.T) { 196 that := assert.New(t) 197 198 // 生成DAO,自动创建dao结构体中的函数字段 199 dao := &personDao3{Logger: &sqx.DaoLog{}} 200 sqx.DB, _ = openDB(t) 201 that.Nil(sqx.CreateDao(dao, sqx.WithSQLStr(dotSQL))) 202 203 dao.CreateTable() 204 205 // 多值插入 206 dao.AddAll(person{"300", 300}, person{"400", 400}) 207 208 // 列表 209 that.Equal([]person{{"300", 300}, {"400", 400}}, dao.ListAll()) 210 } 211 212 func TestDaoWithDotSQLFile(t *testing.T) { 213 that := assert.New(t) 214 215 // 生成DAO,自动创建dao结构体中的函数字段 216 dao := &personDao3{} 217 sqx.DB, _ = openDB(t) 218 that.Nil(sqx.CreateDao(dao, sqx.WithSQLFile(`testdata/d3.sql`))) 219 220 dao.CreateTable() 221 222 // 多值插入 223 dao.AddAll(person{"300", 300}, person{"400", 400}) 224 225 // 列表 226 that.Equal([]person{{"300", 300}, {"400", 400}}, dao.ListAll()) 227 } 228 229 // person 结构体,对应到person表字段. 230 type person4 struct { 231 ID sql.NullString 232 Age int 233 } 234 235 // person 结构体,对应到person表字段. 236 type person5 struct { 237 ID string 238 Age int 239 Other string 240 Addr sql.NullString 241 } 242 243 type Addr string 244 245 // person 结构体,对应到person表字段. 246 type person6 struct { 247 Addr Addr 248 } 249 250 // personDao4 定义对person表操作的所有方法. 251 type personDao4 struct { 252 CreateTable func() `sql:"create table person(id varchar(100), age int, addr varchar(10) default 'bj')"` 253 AddID func(int) `sql:"insert into person(age, addr) values(:1, 'zags')"` 254 Add func(M) `sql:"insert into person(id, age, addr) values(:id, :age, :addr)"` 255 FindByAge1 func(int) person4 `sql:"select id, age, addr from person where age = :1"` 256 FindByAge2 func(int) person5 `sqlName:"FindByAge1"` 257 FindByAge3 func(int) person6 `sqlName:"FindByAge1"` 258 } 259 260 func TestNullString(t *testing.T) { 261 that := assert.New(t) 262 263 // 生成DAO,自动创建dao结构体中的函数字段 264 dao := &personDao4{} 265 sqx.DB, _ = openDB(t) 266 that.Nil(sqx.CreateDao(dao)) 267 268 dao.CreateTable() 269 dao.AddID(100) 270 p1 := dao.FindByAge1(100) 271 that.Equal(person4{ID: sql.NullString{Valid: false}, Age: 100}, p1) 272 273 p2 := dao.FindByAge2(100) 274 that.Equal(person5{ID: "", Age: 100, Addr: sql.NullString{String: "zags", Valid: true}}, p2) 275 276 p3 := dao.FindByAge3(100) 277 that.Equal(person6{Addr: "zags"}, p3) 278 } 279 280 // personMap 结构体,对应到person表字段. 281 type personMap struct { 282 ID string 283 Age int 284 Addr string 285 } 286 287 const dotSQLMap = ` 288 -- name: CreateTable 289 create table person(id varchar(100), age int, addr varchar(10)); 290 291 -- name: Find 292 select id, age, addr from person where id = :1; 293 294 -- name: Add 295 insert into person(id, age, addr) values(:id, :age, :addr); 296 297 -- name: Count 298 select count(*) from person 299 300 -- name: Get2 301 select id, addr from person where id=:1 302 ` 303 304 // personDaoMap 定义对person表操作的所有方法. 305 type personDaoMap struct { 306 CreateTable func() 307 Add func(m M) 308 Find func(id string) personMap 309 FindAsMap1 func(id string) map[string]string `sqlName:"Find"` 310 FindAsMap2 func(id string) M `sqlName:"Find"` 311 312 Count func() (count int) 313 GetAddr func(id string) (addr string) `sql:"select addr from person where id=:1"` 314 GetAddr2 func(id string) (addr Addr) `sqlName:"GetAddr"` 315 Get2 func(id string) (pid, addr string) 316 Get3 func(id string) (pid string) `sqlName:"Get2"` 317 Get4 func(id string) (pid, addr, nos string, noi int) `sqlName:"Get2"` 318 319 Logger sqx.DaoLogger 320 Error error 321 } 322 323 type M map[string]interface{} 324 325 func TestMapArg(t *testing.T) { 326 that := assert.New(t) 327 328 // 生成DAO,自动创建dao结构体中的函数字段 329 dao := &personDaoMap{Logger: &sqx.DaoLog{}} 330 sqx.DB, _ = openDB(t) 331 that.Nil(sqx.CreateDao(dao, sqx.WithSQLStr(dotSQLMap))) 332 333 dao.CreateTable() 334 dao.Add(M{"id": "40685", "age": 500, "addr": "bjca"}) 335 dao.Add(M{"id": "40686", "age": 600, "addr": nil}) 336 337 that.Equal(personMap{ID: "40685", Age: 500, Addr: "bjca"}, dao.Find("40685")) 338 that.Equal(personMap{ID: "40686", Age: 600, Addr: ""}, dao.Find("40686")) 339 340 that.Equal(map[string]string{"id": "40685", "age": "500", "addr": "bjca"}, dao.FindAsMap1("40685")) 341 that.Equal(M{"id": "40685", "age": int64(500), "addr": "bjca"}, dao.FindAsMap2("40685")) 342 343 that.Equal(2, dao.Count()) 344 that.Equal("bjca", dao.GetAddr("40685")) 345 that.Equal(Addr("bjca"), dao.GetAddr2("40685")) 346 347 id, addr := dao.Get2("40685") 348 that.Equal("40685", id) 349 that.Equal("bjca", addr) 350 that.Equal("40685", dao.Get3("40685")) 351 352 { 353 id, addr, other, otherInt := dao.Get4("40685") 354 that.Equal("40685", id) 355 that.Equal("bjca", addr) 356 that.Empty(other) 357 that.Zero(otherInt) 358 } 359 360 { 361 id, addr, other, otherInt := dao.Get4("50685") 362 that.Empty(id) 363 that.Empty(addr) 364 that.Empty(other) 365 that.Zero(otherInt) 366 that.Equal(sql.ErrNoRows, dao.Error) 367 } 368 } 369 370 const dotSQLDynamic = ` 371 -- name: CreateTable 372 create table person(id varchar(100), age int, addr varchar(10)); 373 374 -- name: Add 375 insert into person(id, age, addr) values(:id, :age, :addr); 376 377 -- name: GetAddr 378 select addr from person where id = :1 /* if _2 > 0 */ and age = :2 /* end */; 379 380 -- name: GetAddr2 381 select addr from person where id = :1 382 -- if _2 > 0 383 and age = :2 384 -- end 385 ; 386 387 -- name: GetAddrStruct 388 select addr from person where id = :id /* if age > 0 */ and age = :age /* end */; 389 390 -- name: GetAddrStruct2 391 select addr from person where id = :id 392 -- if age > 0 393 and age = :age 394 -- end 395 ; 396 ` 397 398 // personDaoDynamic 定义对person表操作的所有方法. 399 type personDaoDynamic struct { 400 CreateTable func() 401 Add func(m M) 402 GetAddr func(id string, age int) (addr string) 403 GetAddr2 func(id string, age int) (addr string) 404 405 GetAddrStruct func(p personMap) (addr string) 406 GetAddrStruct2 func(p personMap) (addr string) 407 GetAddrStruct3 func(p personMap) (addresses []string) `sqlName:"GetAddrStruct2"` 408 409 Logger sqx.DaoLogger 410 Error error 411 } 412 413 func TestMapDynamic(t *testing.T) { 414 that := assert.New(t) 415 416 // 生成DAO,自动创建dao结构体中的函数字段 417 dao := &personDaoDynamic{Logger: &sqx.DaoLog{}} 418 sqx.DB, _ = openDB(t) 419 that.Nil(sqx.CreateDao(dao, sqx.WithSQLStr(dotSQLDynamic))) 420 421 dao.CreateTable() 422 dao.Add(M{"id": "40685", "age": 500, "addr": "bjca"}) 423 dao.Add(M{"id": "40685", "age": 600, "addr": "acjb"}) 424 425 that.Equal("bjca", dao.GetAddr("40685", 0)) 426 that.Equal("acjb", dao.GetAddr("40685", 600)) 427 428 that.Equal("bjca", dao.GetAddrStruct(personMap{ID: "40685"})) 429 that.Equal("acjb", dao.GetAddrStruct(personMap{ID: "40685", Age: 600})) 430 431 that.Equal("bjca", dao.GetAddr2("40685", 0)) 432 that.Equal("acjb", dao.GetAddr2("40685", 600)) 433 434 that.Equal("bjca", dao.GetAddrStruct2(personMap{ID: "40685"})) 435 that.Equal("acjb", dao.GetAddrStruct2(personMap{ID: "40685", Age: 600})) 436 437 that.Equal([]string{"bjca", "acjb"}, dao.GetAddrStruct3(personMap{ID: "40685"})) 438 that.Equal([]string{"acjb"}, dao.GetAddrStruct3(personMap{ID: "40685", Age: 600})) 439 } 440 441 const dotSQLLastInsertID = ` 442 -- name: CreateTable 443 create table person(id INTEGER PRIMARY KEY AUTOINCREMENT, age int, addr varchar(10)); 444 445 -- name: Add 446 insert into person( age, addr) values( :age, :addr); 447 448 -- name: Update 449 update person set age = :age, addr = :addr where id = :id; 450 ; 451 ` 452 453 type personLastInsertID struct { 454 CreateTable func() 455 Add func(m M) (lastInsertID int) 456 Update func(m M) (effectedRows int) 457 Update2 func(m M) (effectedRows, lastInsertID int) `sqlName:"Update"` 458 Logger sqx.DaoLogger 459 Error error 460 } 461 462 func TestLastInsertID(t *testing.T) { 463 that := assert.New(t) 464 465 // 生成DAO,自动创建dao结构体中的函数字段 466 dao := &personLastInsertID{Logger: &sqx.DaoLog{}} 467 db, _ := openDB(t) 468 that.Nil(sqx.CreateDao(dao, sqx.WithDB(db), sqx.WithSQLStr(dotSQLLastInsertID))) 469 470 dao.CreateTable() 471 that.True(dao.Add(M{"age": 500, "addr": "bjca"}) > 0) 472 lastInsertID := dao.Add(M{"age": 600, "addr": "acjb"}) 473 that.True(lastInsertID > 0) 474 that.Equal(1, dao.Update(M{"id": lastInsertID, "age": 601, "addr": "xx"})) 475 effectedRows, insertID := dao.Update2(M{"id": lastInsertID, "age": 602, "addr": "yy"}) 476 that.Equal(1, effectedRows) 477 that.Equal(lastInsertID, insertID) 478 } 479 480 type MyTime time.Time 481 482 // Marshaler is the interface implemented by types that 483 // can marshal themselves into valid JSON. 484 func (m MyTime) MarshalJSON() ([]byte, error) { 485 s := (time.Time(m)).Format("2006-01-02 15:04:05") 486 return []byte(strconv.Quote(s)), nil 487 } 488 489 // Value - Implementation of valuer for database/sql. 490 func (m MyTime) Value() (driver.Value, error) { 491 return driver.Value(time.Time(m)), nil 492 } 493 494 const dotSQLTime = ` 495 -- name: CreateTable 496 create table person(id INTEGER PRIMARY KEY AUTOINCREMENT, age int, addr varchar(10), birthday timestamp); 497 498 -- name: Add 499 insert into person( age, addr, birthday ) values( :age, :addr, :birthday); 500 501 -- name: Find 502 select id, age, addr, birthday from person where id = :1; 503 504 -- name: Delete 505 delete from person where id = ':';; 506 ` 507 508 type personTime struct { 509 ID uint64 510 Age int 511 Addr string 512 Birthday MyTime 513 } 514 515 type queryCond struct { 516 Addr string `sql:"addr like ?"` 517 } 518 type queryCond2 struct { 519 Addr string `sql:"addr like ?"` 520 Limit sqx.Limit `sql:"limit ?,?"` 521 } 522 523 type personTimeDao struct { 524 CreateTable func() 525 Add func(m personTime) (lastInsertID uint64) 526 Find func(uint64) (personTime, error) 527 Delete func(uint64) (effectedRows int) 528 529 // 根据查询条件,查询 530 Query func(queryCond) ([]personTime, error) `sql:"select id, age, addr, birthday from person"` 531 Query2 func(queryCond2) ([]personTime, sqx.Count, error) `sql:"select id, age, addr, birthday from person"` 532 533 Logger sqx.DaoLogger 534 Error error 535 } 536 537 func TestTime(t *testing.T) { 538 that := assert.New(t) 539 540 // 生成DAO,自动创建dao结构体中的函数字段 541 dao := &personTimeDao{Logger: &sqx.DaoLog{}} 542 db, _ := openDB(t) 543 that.Nil(sqx.CreateDao(dao, sqx.WithDB(db), sqx.WithSQLStr(dotSQLTime))) 544 545 now, _ := time.Parse("2006-01-02 15:04:05", "2020-03-31 17:17:40") 546 dao.CreateTable() 547 548 p := personTime{ 549 Age: 10, 550 Addr: "aa", 551 Birthday: MyTime(now), 552 } 553 554 lastInsertID := dao.Add(p) 555 p.ID = lastInsertID 556 557 p2, err := dao.Find(lastInsertID) 558 that.Equal(p, p2) 559 that.Nil(err) 560 561 persons, err := dao.Query(queryCond{}) 562 that.Nil(err) 563 that.Equal([]personTime{p}, persons) 564 565 persons, err = dao.Query(queryCond{Addr: "%a%"}) 566 that.Nil(err) 567 that.Equal([]personTime{p}, persons) 568 569 var count sqx.Count 570 persons, count, err = dao.Query2(queryCond2{Addr: "%a%", Limit: sqx.Limit{Length: 10}}) 571 that.Nil(err) 572 that.Equal([]personTime{p}, persons) 573 that.Equal(sqx.Count(1), count) 574 575 effectedRows := dao.Delete(lastInsertID) 576 that.Equal(1, effectedRows) 577 }