github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/valuer/unsafe_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package valuer
    16  
    17  import (
    18  	"database/sql"
    19  	"database/sql/driver"
    20  	"errors"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/DATA-DOG/go-sqlmock"
    25  	"github.com/ecodeclub/eorm/internal/errs"
    26  	"github.com/ecodeclub/eorm/internal/model"
    27  	"github.com/ecodeclub/eorm/internal/test"
    28  	"github.com/stretchr/testify/assert"
    29  )
    30  
    31  func Test_unsafeValue_Field(t *testing.T) {
    32  	testValueField(t, NewUnsafeValue)
    33  }
    34  
    35  func Test_unsafeValue_SetColumn(t *testing.T) {
    36  	testSetColumn(t, NewUnsafeValue)
    37  }
    38  
    39  func testSetColumn(t *testing.T, creator Creator) {
    40  	r := model.NewMetaRegistry()
    41  	t.Run("types", func(t *testing.T) {
    42  		testCases := []struct {
    43  			name    string
    44  			cs      map[string][]byte
    45  			val     *test.SimpleStruct
    46  			wantVal *test.SimpleStruct
    47  			wantErr error
    48  		}{
    49  			{
    50  				name: "normal value",
    51  				cs: map[string][]byte{
    52  					"id":               []byte("1"),
    53  					"bool":             []byte("true"),
    54  					"bool_ptr":         []byte("false"),
    55  					"int":              []byte("12"),
    56  					"int_ptr":          []byte("13"),
    57  					"int8":             []byte("8"),
    58  					"int8_ptr":         []byte("-8"),
    59  					"int16":            []byte("16"),
    60  					"int16_ptr":        []byte("-16"),
    61  					"int32":            []byte("32"),
    62  					"int32_ptr":        []byte("-32"),
    63  					"int64":            []byte("64"),
    64  					"int64_ptr":        []byte("-64"),
    65  					"uint":             []byte("14"),
    66  					"uint_ptr":         []byte("15"),
    67  					"uint8":            []byte("8"),
    68  					"uint8_ptr":        []byte("18"),
    69  					"uint16":           []byte("16"),
    70  					"uint16_ptr":       []byte("116"),
    71  					"uint32":           []byte("32"),
    72  					"uint32_ptr":       []byte("132"),
    73  					"uint64":           []byte("64"),
    74  					"uint64_ptr":       []byte("164"),
    75  					"float32":          []byte("3.2"),
    76  					"float32_ptr":      []byte("-3.2"),
    77  					"float64":          []byte("6.4"),
    78  					"float64_ptr":      []byte("-6.4"),
    79  					"byte_array":       []byte("hello"),
    80  					"string":           []byte("world"),
    81  					"null_string_ptr":  []byte("null string"),
    82  					"null_int16_ptr":   []byte("16"),
    83  					"null_int32_ptr":   []byte("32"),
    84  					"null_int64_ptr":   []byte("64"),
    85  					"null_bool_ptr":    []byte("true"),
    86  					"null_float64_ptr": []byte("6.4"),
    87  					"json_column":      []byte(`{"name": "Tom"}`),
    88  				},
    89  				val:     &test.SimpleStruct{},
    90  				wantVal: test.NewSimpleStruct(1),
    91  			},
    92  			{
    93  				name: "invalid field",
    94  				cs: map[string][]byte{
    95  					"invalid_column": nil,
    96  				},
    97  				wantErr: errs.NewInvalidColumnError("invalid_column"),
    98  			},
    99  		}
   100  
   101  		meta, err := r.Get(&test.SimpleStruct{})
   102  		if err != nil {
   103  			t.Fatal(err)
   104  		}
   105  
   106  		for _, tc := range testCases {
   107  			t.Run(tc.name, func(t *testing.T) {
   108  				db, mock, err := sqlmock.New()
   109  				if err != nil {
   110  					t.Fatal(err)
   111  				}
   112  				defer func() { _ = db.Close() }()
   113  				val := creator(tc.val, meta)
   114  				cols := make([]string, 0, len(tc.cs))
   115  				colVals := make([]driver.Value, 0, len(tc.cs))
   116  				for k, v := range tc.cs {
   117  					cols = append(cols, k)
   118  					colVals = append(colVals, v)
   119  				}
   120  				mock.ExpectQuery("SELECT *").
   121  					WillReturnRows(sqlmock.NewRows(cols).
   122  						AddRow(colVals...))
   123  				rows, _ := db.Query("SELECT *")
   124  				rows.Next()
   125  				err = val.SetColumns(rows)
   126  				if err != nil {
   127  					assert.Equal(t, tc.wantErr, err)
   128  					return
   129  				}
   130  				if tc.wantErr != nil {
   131  					t.Fatalf("期望得到错误,但是并没有得到 %v", tc.wantErr)
   132  				}
   133  				assert.Equal(t, tc.wantVal, tc.val)
   134  			})
   135  		}
   136  	})
   137  
   138  	type User struct {
   139  		Name string
   140  	}
   141  
   142  	// 测试行数匹配不上,或者无法读取数据的场景
   143  	t.Run("invalid rows", func(t *testing.T) {
   144  		db, mock, err := sqlmock.New()
   145  		if err != nil {
   146  			t.Fatal(err)
   147  		}
   148  		defer func() { _ = db.Close() }()
   149  
   150  		u := &User{}
   151  		meta, err := r.Get(u)
   152  		if err != nil {
   153  			t.Fatal(err)
   154  		}
   155  		val := creator(u, meta)
   156  		// 多了一个列
   157  		mock.ExpectQuery("SELECT *").
   158  			WillReturnRows(sqlmock.NewRows([]string{"ID", "Name"}).
   159  				AddRow(123, "Tom"))
   160  		rows, _ := db.Query("SELECT *")
   161  		rows.Next()
   162  		err = val.SetColumns(rows)
   163  		assert.Equal(t, errs.ErrTooManyColumns, err)
   164  
   165  		// 读取列错误
   166  		mock.ExpectQuery("SELECT *").
   167  			WillReturnRows(sqlmock.NewRows([]string{"Name"}))
   168  		rows, _ = db.Query("SELECT *")
   169  		rows.Next()
   170  		err = val.SetColumns(rows)
   171  		assert.Equal(t, errors.New("sql: Rows are closed"), err)
   172  	})
   173  
   174  	type BaseEntity struct {
   175  		Id         int64 `eorm:"primary_key"`
   176  		CreateTime uint64
   177  	}
   178  	type CombinedUser struct {
   179  		BaseEntity
   180  		FirstName string
   181  	}
   182  
   183  	// 测试使用组合的场景
   184  	t.Run("combination", func(t *testing.T) {
   185  		db, mock, err := sqlmock.New()
   186  		if err != nil {
   187  			t.Fatal(err)
   188  		}
   189  		defer func() { _ = db.Close() }()
   190  
   191  		u := &CombinedUser{}
   192  		meta, err := r.Get(u)
   193  		if err != nil {
   194  			t.Fatal(err)
   195  		}
   196  		val := creator(u, meta)
   197  		// 多了一个列
   198  		mock.ExpectQuery("SELECT *").
   199  			WillReturnRows(sqlmock.NewRows([]string{"id", "create_time", "first_name"}).
   200  				AddRow(123, 100000, "Tom"))
   201  		rows, _ := db.Query("SELECT *")
   202  		rows.Next()
   203  		err = val.SetColumns(rows)
   204  		assert.NoError(t, err)
   205  		wantUser := &CombinedUser{
   206  			BaseEntity: BaseEntity{
   207  				Id:         123,
   208  				CreateTime: 100000,
   209  			},
   210  			FirstName: "Tom",
   211  		}
   212  		assert.Equal(t, wantUser, u)
   213  	})
   214  }
   215  
   216  func testValueField(t *testing.T, creator Creator) {
   217  	meta, err := model.NewMetaRegistry().Get(&test.SimpleStruct{})
   218  	if err != nil {
   219  		t.Fatal(err)
   220  	}
   221  	t.Run("zero value", func(t *testing.T) {
   222  		entity := &test.SimpleStruct{}
   223  		testCases := newValueFieldTestCases(entity)
   224  		val := creator(entity, meta)
   225  		for _, tc := range testCases {
   226  			t.Run(tc.name, func(t *testing.T) {
   227  				v, err := val.Field(tc.field)
   228  				assert.Equal(t, tc.wantError, err)
   229  				if err != nil {
   230  					return
   231  				}
   232  				assert.Equal(t, tc.wantVal, v.Interface())
   233  			})
   234  		}
   235  	})
   236  	t.Run("normal value", func(t *testing.T) {
   237  		entity := test.NewSimpleStruct(1)
   238  		testCases := newValueFieldTestCases(entity)
   239  		val := NewUnsafeValue(entity, meta)
   240  		for _, tc := range testCases {
   241  			t.Run(tc.name, func(t *testing.T) {
   242  				v, err := val.Field(tc.field)
   243  				assert.Equal(t, tc.wantError, err)
   244  				if err != nil {
   245  					return
   246  				}
   247  				assert.Equal(t, tc.wantVal, v.Interface())
   248  			})
   249  		}
   250  	})
   251  
   252  	type User struct {
   253  		CreateTime time.Time
   254  		Name       *string
   255  	}
   256  
   257  	invalidCases := []valueFieldTestCase{
   258  		{
   259  			// 不存在的字段
   260  			name:      "invalid field",
   261  			field:     "UpdateTime",
   262  			wantError: errs.NewInvalidFieldError("UpdateTime"),
   263  		},
   264  	}
   265  	t.Run("invalid cases", func(t *testing.T) {
   266  		meta, err := model.NewMetaRegistry().Get(&User{})
   267  		if err != nil {
   268  			t.Fatal(err)
   269  		}
   270  
   271  		val := NewUnsafeValue(&User{}, meta)
   272  		for _, tc := range invalidCases {
   273  			t.Run(tc.name, func(t *testing.T) {
   274  				v, err := val.Field(tc.field)
   275  				assert.Equal(t, tc.wantError, err)
   276  				if err != nil {
   277  					return
   278  				}
   279  				assert.Equal(t, tc.wantVal, v.Interface())
   280  			})
   281  		}
   282  	})
   283  
   284  	type BaseEntity struct {
   285  		Id         int64 `eorm:"primary_key"`
   286  		CreateTime uint64
   287  	}
   288  	type CombinedUser struct {
   289  		BaseEntity
   290  		FirstName string
   291  	}
   292  
   293  	cUser := &CombinedUser{}
   294  	testCases := []valueFieldTestCase{
   295  		{
   296  			name:    "id",
   297  			field:   "Id",
   298  			wantVal: cUser.Id,
   299  		},
   300  		{
   301  			name:    "CreateTime",
   302  			field:   "CreateTime",
   303  			wantVal: cUser.CreateTime,
   304  		},
   305  		{
   306  			name:    "FirstName",
   307  			field:   "FirstName",
   308  			wantVal: cUser.FirstName,
   309  		},
   310  	}
   311  	// 测试使用组合的场景
   312  	t.Run("combination cases", func(t *testing.T) {
   313  		meta, err = model.NewMetaRegistry().Get(cUser)
   314  		if err != nil {
   315  			t.Fatal(err)
   316  		}
   317  
   318  		val := NewUnsafeValue(cUser, meta)
   319  		for _, tc := range testCases {
   320  			t.Run(tc.name, func(t *testing.T) {
   321  				v, err := val.Field(tc.field)
   322  				assert.Equal(t, tc.wantError, err)
   323  				if err != nil {
   324  					return
   325  				}
   326  				assert.Equal(t, tc.wantVal, v.Interface())
   327  			})
   328  		}
   329  	})
   330  }
   331  
   332  func newValueFieldTestCases(entity *test.SimpleStruct) []valueFieldTestCase {
   333  	return []valueFieldTestCase{
   334  		{
   335  			name:    "bool",
   336  			field:   "Bool",
   337  			wantVal: entity.Bool,
   338  		},
   339  		{
   340  			// bool 指针类型
   341  			name:    "bool pointer",
   342  			field:   "BoolPtr",
   343  			wantVal: entity.BoolPtr,
   344  		},
   345  		{
   346  			name:    "int",
   347  			field:   "Int",
   348  			wantVal: entity.Int,
   349  		},
   350  		{
   351  			// int 指针类型
   352  			name:    "int pointer",
   353  			field:   "IntPtr",
   354  			wantVal: entity.IntPtr,
   355  		},
   356  		{
   357  			name:    "int8",
   358  			field:   "Int8",
   359  			wantVal: entity.Int8,
   360  		},
   361  		{
   362  			name:    "int8 pointer",
   363  			field:   "Int8Ptr",
   364  			wantVal: entity.Int8Ptr,
   365  		},
   366  		{
   367  			name:    "int16",
   368  			field:   "Int16",
   369  			wantVal: entity.Int16,
   370  		},
   371  		{
   372  			name:    "int16 pointer",
   373  			field:   "Int16Ptr",
   374  			wantVal: entity.Int16Ptr,
   375  		},
   376  		{
   377  			name:    "int32",
   378  			field:   "Int32",
   379  			wantVal: entity.Int32,
   380  		},
   381  		{
   382  			name:    "int32 pointer",
   383  			field:   "Int32Ptr",
   384  			wantVal: entity.Int32Ptr,
   385  		},
   386  		{
   387  			name:    "int64",
   388  			field:   "Int64",
   389  			wantVal: entity.Int64,
   390  		},
   391  		{
   392  			name:    "int64 pointer",
   393  			field:   "Int64Ptr",
   394  			wantVal: entity.Int64Ptr,
   395  		},
   396  		{
   397  			name:    "uint",
   398  			field:   "Uint",
   399  			wantVal: entity.Uint,
   400  		},
   401  		{
   402  			name:    "uint pointer",
   403  			field:   "UintPtr",
   404  			wantVal: entity.UintPtr,
   405  		},
   406  		{
   407  			name:    "uint8",
   408  			field:   "Uint8",
   409  			wantVal: entity.Uint8,
   410  		},
   411  		{
   412  			name:    "uint8 pointer",
   413  			field:   "Uint8Ptr",
   414  			wantVal: entity.Uint8Ptr,
   415  		},
   416  		{
   417  			name:    "uint16",
   418  			field:   "Uint16",
   419  			wantVal: entity.Uint16,
   420  		},
   421  		{
   422  			name:    "uint16 pointer",
   423  			field:   "Uint16Ptr",
   424  			wantVal: entity.Uint16Ptr,
   425  		},
   426  		{
   427  			name:    "uint32",
   428  			field:   "Uint32",
   429  			wantVal: entity.Uint32,
   430  		},
   431  		{
   432  			name:    "uint32 pointer",
   433  			field:   "Uint32Ptr",
   434  			wantVal: entity.Uint32Ptr,
   435  		},
   436  		{
   437  			name:    "uint64",
   438  			field:   "Uint64",
   439  			wantVal: entity.Uint64,
   440  		},
   441  		{
   442  			name:    "uint64 pointer",
   443  			field:   "Uint64Ptr",
   444  			wantVal: entity.Uint64Ptr,
   445  		},
   446  		{
   447  			name:    "float32",
   448  			field:   "Float32",
   449  			wantVal: entity.Float32,
   450  		},
   451  		{
   452  			name:    "float32 pointer",
   453  			field:   "Float32Ptr",
   454  			wantVal: entity.Float32Ptr,
   455  		},
   456  		{
   457  			name:    "float64",
   458  			field:   "Float64",
   459  			wantVal: entity.Float64,
   460  		},
   461  		{
   462  			name:    "float64 pointer",
   463  			field:   "Float64Ptr",
   464  			wantVal: entity.Float64Ptr,
   465  		},
   466  		{
   467  			name:    "byte array",
   468  			field:   "ByteArray",
   469  			wantVal: entity.ByteArray,
   470  		},
   471  		{
   472  			name:    "string",
   473  			field:   "String",
   474  			wantVal: entity.String,
   475  		},
   476  		{
   477  			name:    "NullStringPtr",
   478  			field:   "NullStringPtr",
   479  			wantVal: entity.NullStringPtr,
   480  		},
   481  		{
   482  			name:    "NullInt16Ptr",
   483  			field:   "NullInt16Ptr",
   484  			wantVal: entity.NullInt16Ptr,
   485  		},
   486  		{
   487  			name:    "NullInt32Ptr",
   488  			field:   "NullInt32Ptr",
   489  			wantVal: entity.NullInt32Ptr,
   490  		},
   491  		{
   492  			name:    "NullInt64Ptr",
   493  			field:   "NullInt64Ptr",
   494  			wantVal: entity.NullInt64Ptr,
   495  		},
   496  		{
   497  			name:    "NullBoolPtr",
   498  			field:   "NullBoolPtr",
   499  			wantVal: entity.NullBoolPtr,
   500  		},
   501  		{
   502  			name:    "NullFloat64Ptr",
   503  			field:   "NullFloat64Ptr",
   504  			wantVal: entity.NullFloat64Ptr,
   505  		},
   506  		{
   507  			name:    "JsonColumn",
   508  			field:   "JsonColumn",
   509  			wantVal: entity.JsonColumn,
   510  		},
   511  	}
   512  }
   513  
   514  type valueFieldTestCase struct {
   515  	name      string
   516  	field     string
   517  	wantVal   interface{}
   518  	wantError error
   519  }
   520  
   521  func FuzzUnsafeValue_Field(f *testing.F) {
   522  	f.Fuzz(fuzzValueField(NewUnsafeValue))
   523  }
   524  
   525  func BenchmarkUnsafeValue_Field(b *testing.B) {
   526  	meta, _ := model.NewMetaRegistry().Get(&test.SimpleStruct{})
   527  	ins := NewReflectValue(&test.SimpleStruct{Int64: 13, NullStringPtr: &sql.NullString{}}, meta)
   528  	b.Run("int64", func(b *testing.B) {
   529  		for i := 0; i < b.N; i++ {
   530  			val, err := ins.Field("Int64")
   531  			assert.Nil(b, err)
   532  			assert.Equal(b, int64(13), val.Interface())
   533  		}
   534  	})
   535  	b.Run("holder", func(b *testing.B) {
   536  		for i := 0; i < b.N; i++ {
   537  			val, err := ins.Field("NullStringPtr")
   538  			assert.Nil(b, err)
   539  			assert.Equal(b, &sql.NullString{}, val.Interface())
   540  		}
   541  	})
   542  }