github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/valuer/unsafe_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package valuer 16 17 import ( 18 "database/sql" 19 "database/sql/driver" 20 "errors" 21 "testing" 22 "time" 23 24 "github.com/DATA-DOG/go-sqlmock" 25 "github.com/ecodeclub/eorm/internal/errs" 26 "github.com/ecodeclub/eorm/internal/model" 27 "github.com/ecodeclub/eorm/internal/test" 28 "github.com/stretchr/testify/assert" 29 ) 30 31 func Test_unsafeValue_Field(t *testing.T) { 32 testValueField(t, NewUnsafeValue) 33 } 34 35 func Test_unsafeValue_SetColumn(t *testing.T) { 36 testSetColumn(t, NewUnsafeValue) 37 } 38 39 func testSetColumn(t *testing.T, creator Creator) { 40 r := model.NewMetaRegistry() 41 t.Run("types", func(t *testing.T) { 42 testCases := []struct { 43 name string 44 cs map[string][]byte 45 val *test.SimpleStruct 46 wantVal *test.SimpleStruct 47 wantErr error 48 }{ 49 { 50 name: "normal value", 51 cs: map[string][]byte{ 52 "id": []byte("1"), 53 "bool": []byte("true"), 54 "bool_ptr": []byte("false"), 55 "int": []byte("12"), 56 "int_ptr": []byte("13"), 57 "int8": []byte("8"), 58 "int8_ptr": []byte("-8"), 59 "int16": []byte("16"), 60 "int16_ptr": []byte("-16"), 61 "int32": []byte("32"), 62 "int32_ptr": []byte("-32"), 63 "int64": []byte("64"), 64 "int64_ptr": []byte("-64"), 65 "uint": []byte("14"), 66 "uint_ptr": []byte("15"), 67 "uint8": []byte("8"), 68 "uint8_ptr": []byte("18"), 69 "uint16": []byte("16"), 70 "uint16_ptr": []byte("116"), 71 "uint32": []byte("32"), 72 "uint32_ptr": []byte("132"), 73 "uint64": []byte("64"), 74 "uint64_ptr": []byte("164"), 75 "float32": []byte("3.2"), 76 "float32_ptr": []byte("-3.2"), 77 "float64": []byte("6.4"), 78 "float64_ptr": []byte("-6.4"), 79 "byte_array": []byte("hello"), 80 "string": []byte("world"), 81 "null_string_ptr": []byte("null string"), 82 "null_int16_ptr": []byte("16"), 83 "null_int32_ptr": []byte("32"), 84 "null_int64_ptr": []byte("64"), 85 "null_bool_ptr": []byte("true"), 86 "null_float64_ptr": []byte("6.4"), 87 "json_column": []byte(`{"name": "Tom"}`), 88 }, 89 val: &test.SimpleStruct{}, 90 wantVal: test.NewSimpleStruct(1), 91 }, 92 { 93 name: "invalid field", 94 cs: map[string][]byte{ 95 "invalid_column": nil, 96 }, 97 wantErr: errs.NewInvalidColumnError("invalid_column"), 98 }, 99 } 100 101 meta, err := r.Get(&test.SimpleStruct{}) 102 if err != nil { 103 t.Fatal(err) 104 } 105 106 for _, tc := range testCases { 107 t.Run(tc.name, func(t *testing.T) { 108 db, mock, err := sqlmock.New() 109 if err != nil { 110 t.Fatal(err) 111 } 112 defer func() { _ = db.Close() }() 113 val := creator(tc.val, meta) 114 cols := make([]string, 0, len(tc.cs)) 115 colVals := make([]driver.Value, 0, len(tc.cs)) 116 for k, v := range tc.cs { 117 cols = append(cols, k) 118 colVals = append(colVals, v) 119 } 120 mock.ExpectQuery("SELECT *"). 121 WillReturnRows(sqlmock.NewRows(cols). 122 AddRow(colVals...)) 123 rows, _ := db.Query("SELECT *") 124 rows.Next() 125 err = val.SetColumns(rows) 126 if err != nil { 127 assert.Equal(t, tc.wantErr, err) 128 return 129 } 130 if tc.wantErr != nil { 131 t.Fatalf("期望得到错误,但是并没有得到 %v", tc.wantErr) 132 } 133 assert.Equal(t, tc.wantVal, tc.val) 134 }) 135 } 136 }) 137 138 type User struct { 139 Name string 140 } 141 142 // 测试行数匹配不上,或者无法读取数据的场景 143 t.Run("invalid rows", func(t *testing.T) { 144 db, mock, err := sqlmock.New() 145 if err != nil { 146 t.Fatal(err) 147 } 148 defer func() { _ = db.Close() }() 149 150 u := &User{} 151 meta, err := r.Get(u) 152 if err != nil { 153 t.Fatal(err) 154 } 155 val := creator(u, meta) 156 // 多了一个列 157 mock.ExpectQuery("SELECT *"). 158 WillReturnRows(sqlmock.NewRows([]string{"ID", "Name"}). 159 AddRow(123, "Tom")) 160 rows, _ := db.Query("SELECT *") 161 rows.Next() 162 err = val.SetColumns(rows) 163 assert.Equal(t, errs.ErrTooManyColumns, err) 164 165 // 读取列错误 166 mock.ExpectQuery("SELECT *"). 167 WillReturnRows(sqlmock.NewRows([]string{"Name"})) 168 rows, _ = db.Query("SELECT *") 169 rows.Next() 170 err = val.SetColumns(rows) 171 assert.Equal(t, errors.New("sql: Rows are closed"), err) 172 }) 173 174 type BaseEntity struct { 175 Id int64 `eorm:"primary_key"` 176 CreateTime uint64 177 } 178 type CombinedUser struct { 179 BaseEntity 180 FirstName string 181 } 182 183 // 测试使用组合的场景 184 t.Run("combination", func(t *testing.T) { 185 db, mock, err := sqlmock.New() 186 if err != nil { 187 t.Fatal(err) 188 } 189 defer func() { _ = db.Close() }() 190 191 u := &CombinedUser{} 192 meta, err := r.Get(u) 193 if err != nil { 194 t.Fatal(err) 195 } 196 val := creator(u, meta) 197 // 多了一个列 198 mock.ExpectQuery("SELECT *"). 199 WillReturnRows(sqlmock.NewRows([]string{"id", "create_time", "first_name"}). 200 AddRow(123, 100000, "Tom")) 201 rows, _ := db.Query("SELECT *") 202 rows.Next() 203 err = val.SetColumns(rows) 204 assert.NoError(t, err) 205 wantUser := &CombinedUser{ 206 BaseEntity: BaseEntity{ 207 Id: 123, 208 CreateTime: 100000, 209 }, 210 FirstName: "Tom", 211 } 212 assert.Equal(t, wantUser, u) 213 }) 214 } 215 216 func testValueField(t *testing.T, creator Creator) { 217 meta, err := model.NewMetaRegistry().Get(&test.SimpleStruct{}) 218 if err != nil { 219 t.Fatal(err) 220 } 221 t.Run("zero value", func(t *testing.T) { 222 entity := &test.SimpleStruct{} 223 testCases := newValueFieldTestCases(entity) 224 val := creator(entity, meta) 225 for _, tc := range testCases { 226 t.Run(tc.name, func(t *testing.T) { 227 v, err := val.Field(tc.field) 228 assert.Equal(t, tc.wantError, err) 229 if err != nil { 230 return 231 } 232 assert.Equal(t, tc.wantVal, v.Interface()) 233 }) 234 } 235 }) 236 t.Run("normal value", func(t *testing.T) { 237 entity := test.NewSimpleStruct(1) 238 testCases := newValueFieldTestCases(entity) 239 val := NewUnsafeValue(entity, meta) 240 for _, tc := range testCases { 241 t.Run(tc.name, func(t *testing.T) { 242 v, err := val.Field(tc.field) 243 assert.Equal(t, tc.wantError, err) 244 if err != nil { 245 return 246 } 247 assert.Equal(t, tc.wantVal, v.Interface()) 248 }) 249 } 250 }) 251 252 type User struct { 253 CreateTime time.Time 254 Name *string 255 } 256 257 invalidCases := []valueFieldTestCase{ 258 { 259 // 不存在的字段 260 name: "invalid field", 261 field: "UpdateTime", 262 wantError: errs.NewInvalidFieldError("UpdateTime"), 263 }, 264 } 265 t.Run("invalid cases", func(t *testing.T) { 266 meta, err := model.NewMetaRegistry().Get(&User{}) 267 if err != nil { 268 t.Fatal(err) 269 } 270 271 val := NewUnsafeValue(&User{}, meta) 272 for _, tc := range invalidCases { 273 t.Run(tc.name, func(t *testing.T) { 274 v, err := val.Field(tc.field) 275 assert.Equal(t, tc.wantError, err) 276 if err != nil { 277 return 278 } 279 assert.Equal(t, tc.wantVal, v.Interface()) 280 }) 281 } 282 }) 283 284 type BaseEntity struct { 285 Id int64 `eorm:"primary_key"` 286 CreateTime uint64 287 } 288 type CombinedUser struct { 289 BaseEntity 290 FirstName string 291 } 292 293 cUser := &CombinedUser{} 294 testCases := []valueFieldTestCase{ 295 { 296 name: "id", 297 field: "Id", 298 wantVal: cUser.Id, 299 }, 300 { 301 name: "CreateTime", 302 field: "CreateTime", 303 wantVal: cUser.CreateTime, 304 }, 305 { 306 name: "FirstName", 307 field: "FirstName", 308 wantVal: cUser.FirstName, 309 }, 310 } 311 // 测试使用组合的场景 312 t.Run("combination cases", func(t *testing.T) { 313 meta, err = model.NewMetaRegistry().Get(cUser) 314 if err != nil { 315 t.Fatal(err) 316 } 317 318 val := NewUnsafeValue(cUser, meta) 319 for _, tc := range testCases { 320 t.Run(tc.name, func(t *testing.T) { 321 v, err := val.Field(tc.field) 322 assert.Equal(t, tc.wantError, err) 323 if err != nil { 324 return 325 } 326 assert.Equal(t, tc.wantVal, v.Interface()) 327 }) 328 } 329 }) 330 } 331 332 func newValueFieldTestCases(entity *test.SimpleStruct) []valueFieldTestCase { 333 return []valueFieldTestCase{ 334 { 335 name: "bool", 336 field: "Bool", 337 wantVal: entity.Bool, 338 }, 339 { 340 // bool 指针类型 341 name: "bool pointer", 342 field: "BoolPtr", 343 wantVal: entity.BoolPtr, 344 }, 345 { 346 name: "int", 347 field: "Int", 348 wantVal: entity.Int, 349 }, 350 { 351 // int 指针类型 352 name: "int pointer", 353 field: "IntPtr", 354 wantVal: entity.IntPtr, 355 }, 356 { 357 name: "int8", 358 field: "Int8", 359 wantVal: entity.Int8, 360 }, 361 { 362 name: "int8 pointer", 363 field: "Int8Ptr", 364 wantVal: entity.Int8Ptr, 365 }, 366 { 367 name: "int16", 368 field: "Int16", 369 wantVal: entity.Int16, 370 }, 371 { 372 name: "int16 pointer", 373 field: "Int16Ptr", 374 wantVal: entity.Int16Ptr, 375 }, 376 { 377 name: "int32", 378 field: "Int32", 379 wantVal: entity.Int32, 380 }, 381 { 382 name: "int32 pointer", 383 field: "Int32Ptr", 384 wantVal: entity.Int32Ptr, 385 }, 386 { 387 name: "int64", 388 field: "Int64", 389 wantVal: entity.Int64, 390 }, 391 { 392 name: "int64 pointer", 393 field: "Int64Ptr", 394 wantVal: entity.Int64Ptr, 395 }, 396 { 397 name: "uint", 398 field: "Uint", 399 wantVal: entity.Uint, 400 }, 401 { 402 name: "uint pointer", 403 field: "UintPtr", 404 wantVal: entity.UintPtr, 405 }, 406 { 407 name: "uint8", 408 field: "Uint8", 409 wantVal: entity.Uint8, 410 }, 411 { 412 name: "uint8 pointer", 413 field: "Uint8Ptr", 414 wantVal: entity.Uint8Ptr, 415 }, 416 { 417 name: "uint16", 418 field: "Uint16", 419 wantVal: entity.Uint16, 420 }, 421 { 422 name: "uint16 pointer", 423 field: "Uint16Ptr", 424 wantVal: entity.Uint16Ptr, 425 }, 426 { 427 name: "uint32", 428 field: "Uint32", 429 wantVal: entity.Uint32, 430 }, 431 { 432 name: "uint32 pointer", 433 field: "Uint32Ptr", 434 wantVal: entity.Uint32Ptr, 435 }, 436 { 437 name: "uint64", 438 field: "Uint64", 439 wantVal: entity.Uint64, 440 }, 441 { 442 name: "uint64 pointer", 443 field: "Uint64Ptr", 444 wantVal: entity.Uint64Ptr, 445 }, 446 { 447 name: "float32", 448 field: "Float32", 449 wantVal: entity.Float32, 450 }, 451 { 452 name: "float32 pointer", 453 field: "Float32Ptr", 454 wantVal: entity.Float32Ptr, 455 }, 456 { 457 name: "float64", 458 field: "Float64", 459 wantVal: entity.Float64, 460 }, 461 { 462 name: "float64 pointer", 463 field: "Float64Ptr", 464 wantVal: entity.Float64Ptr, 465 }, 466 { 467 name: "byte array", 468 field: "ByteArray", 469 wantVal: entity.ByteArray, 470 }, 471 { 472 name: "string", 473 field: "String", 474 wantVal: entity.String, 475 }, 476 { 477 name: "NullStringPtr", 478 field: "NullStringPtr", 479 wantVal: entity.NullStringPtr, 480 }, 481 { 482 name: "NullInt16Ptr", 483 field: "NullInt16Ptr", 484 wantVal: entity.NullInt16Ptr, 485 }, 486 { 487 name: "NullInt32Ptr", 488 field: "NullInt32Ptr", 489 wantVal: entity.NullInt32Ptr, 490 }, 491 { 492 name: "NullInt64Ptr", 493 field: "NullInt64Ptr", 494 wantVal: entity.NullInt64Ptr, 495 }, 496 { 497 name: "NullBoolPtr", 498 field: "NullBoolPtr", 499 wantVal: entity.NullBoolPtr, 500 }, 501 { 502 name: "NullFloat64Ptr", 503 field: "NullFloat64Ptr", 504 wantVal: entity.NullFloat64Ptr, 505 }, 506 { 507 name: "JsonColumn", 508 field: "JsonColumn", 509 wantVal: entity.JsonColumn, 510 }, 511 } 512 } 513 514 type valueFieldTestCase struct { 515 name string 516 field string 517 wantVal interface{} 518 wantError error 519 } 520 521 func FuzzUnsafeValue_Field(f *testing.F) { 522 f.Fuzz(fuzzValueField(NewUnsafeValue)) 523 } 524 525 func BenchmarkUnsafeValue_Field(b *testing.B) { 526 meta, _ := model.NewMetaRegistry().Get(&test.SimpleStruct{}) 527 ins := NewReflectValue(&test.SimpleStruct{Int64: 13, NullStringPtr: &sql.NullString{}}, meta) 528 b.Run("int64", func(b *testing.B) { 529 for i := 0; i < b.N; i++ { 530 val, err := ins.Field("Int64") 531 assert.Nil(b, err) 532 assert.Equal(b, int64(13), val.Interface()) 533 } 534 }) 535 b.Run("holder", func(b *testing.B) { 536 for i := 0; i < b.N; i++ { 537 val, err := ins.Field("NullStringPtr") 538 assert.Nil(b, err) 539 assert.Equal(b, &sql.NullString{}, val.Interface()) 540 } 541 }) 542 }