github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/valuer/primitive_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package valuer
    16  
    17  import (
    18  	"database/sql"
    19  	"database/sql/driver"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/DATA-DOG/go-sqlmock"
    24  	"github.com/ecodeclub/eorm/internal/errs"
    25  	"github.com/ecodeclub/eorm/internal/model"
    26  	"github.com/ecodeclub/eorm/internal/test"
    27  	"github.com/stretchr/testify/assert"
    28  )
    29  
    30  func Test_primitiveValue_Field(t *testing.T) {
    31  	testPrimitiveValueField(t, PrimitiveCreator{Creator: NewUnsafeValue})
    32  	testPrimitiveValueField(t, PrimitiveCreator{Creator: NewReflectValue})
    33  }
    34  
    35  func testPrimitiveValueField(t *testing.T, creator PrimitiveCreator) {
    36  	meta, err := model.NewMetaRegistry().Get(&test.SimpleStruct{})
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  	t.Run("zero value", func(t *testing.T) {
    41  		entity := &test.SimpleStruct{}
    42  		testCases := newValueFieldTestCases(entity)
    43  		val := creator.NewPrimitiveValue(entity, meta)
    44  		for _, tc := range testCases {
    45  			t.Run(tc.name, func(t *testing.T) {
    46  				v, err := val.Field(tc.field)
    47  				assert.Equal(t, tc.wantError, err)
    48  				if err != nil {
    49  					return
    50  				}
    51  				assert.Equal(t, tc.wantVal, v.Interface())
    52  			})
    53  		}
    54  	})
    55  	t.Run("normal value", func(t *testing.T) {
    56  		entity := test.NewSimpleStruct(1)
    57  		testCases := newValueFieldTestCases(entity)
    58  		val := creator.NewPrimitiveValue(entity, meta)
    59  		for _, tc := range testCases {
    60  			t.Run(tc.name, func(t *testing.T) {
    61  				v, err := val.Field(tc.field)
    62  				assert.Equal(t, tc.wantError, err)
    63  				if err != nil {
    64  					return
    65  				}
    66  				assert.Equal(t, tc.wantVal, v.Interface())
    67  			})
    68  		}
    69  	})
    70  
    71  	type User struct {
    72  		CreateTime time.Time
    73  		Name       *string
    74  	}
    75  
    76  	invalidCases := []valueFieldTestCase{
    77  		{
    78  			// 不存在的字段
    79  			name:      "invalid field",
    80  			field:     "UpdateTime",
    81  			wantError: errs.NewInvalidFieldError("UpdateTime"),
    82  		},
    83  	}
    84  	t.Run("invalid cases", func(t *testing.T) {
    85  		meta, err := model.NewMetaRegistry().Get(&User{})
    86  		if err != nil {
    87  			t.Fatal(err)
    88  		}
    89  
    90  		val := creator.NewPrimitiveValue(&User{}, meta)
    91  		for _, tc := range invalidCases {
    92  			t.Run(tc.name, func(t *testing.T) {
    93  				v, err := val.Field(tc.field)
    94  				assert.Equal(t, tc.wantError, err)
    95  				if err != nil {
    96  					return
    97  				}
    98  				assert.Equal(t, tc.wantVal, v.Interface())
    99  			})
   100  		}
   101  	})
   102  
   103  	type BaseEntity struct {
   104  		Id         int64 `eorm:"auto_increment,primary_key"`
   105  		CreateTime uint64
   106  	}
   107  	type CombinedUser struct {
   108  		BaseEntity
   109  		FirstName string
   110  	}
   111  
   112  	cUser := &CombinedUser{}
   113  	testCases := []valueFieldTestCase{
   114  		{
   115  			name:    "id",
   116  			field:   "Id",
   117  			wantVal: cUser.Id,
   118  		},
   119  		{
   120  			name:    "CreateTime",
   121  			field:   "CreateTime",
   122  			wantVal: cUser.CreateTime,
   123  		},
   124  		{
   125  			name:    "FirstName",
   126  			field:   "FirstName",
   127  			wantVal: cUser.FirstName,
   128  		},
   129  	}
   130  	// 测试使用组合的场景
   131  	t.Run("combination cases", func(t *testing.T) {
   132  		meta, err = model.NewMetaRegistry().Get(cUser)
   133  		if err != nil {
   134  			t.Fatal(err)
   135  		}
   136  
   137  		val := creator.NewPrimitiveValue(cUser, meta)
   138  		for _, tc := range testCases {
   139  			t.Run(tc.name, func(t *testing.T) {
   140  				v, err := val.Field(tc.field)
   141  				assert.Equal(t, tc.wantError, err)
   142  				if err != nil {
   143  					return
   144  				}
   145  				assert.Equal(t, tc.wantVal, v.Interface())
   146  			})
   147  		}
   148  	})
   149  }
   150  
   151  func Test_primitiveValue_SetColumn(t *testing.T) {
   152  	testCases := []struct {
   153  		name       string
   154  		cs         map[string][]byte
   155  		val        any
   156  		valCreator Creator
   157  		wantVal    any
   158  		wantErr    error
   159  	}{
   160  		{
   161  			name:       "int",
   162  			valCreator: NewUnsafeValue,
   163  			cs: map[string][]byte{
   164  				"id": []byte("10"),
   165  			},
   166  			val: new(int),
   167  			wantVal: func() *int {
   168  				val := 10
   169  				return &val
   170  			}(),
   171  		},
   172  		{
   173  			name:       "string",
   174  			valCreator: NewUnsafeValue,
   175  			cs: map[string][]byte{
   176  				"string": []byte("word"),
   177  			},
   178  			val: new(string),
   179  			wantVal: func() *string {
   180  				val := "word"
   181  				return &val
   182  			}(),
   183  		},
   184  		{
   185  			name:       "bytes",
   186  			valCreator: NewUnsafeValue,
   187  			cs: map[string][]byte{
   188  				"byte": []byte("word"),
   189  			},
   190  			val: new([]byte),
   191  			wantVal: func() *[]byte {
   192  				val := []byte("word")
   193  				return &val
   194  			}(),
   195  		},
   196  		{
   197  			name:       "bool",
   198  			valCreator: NewUnsafeValue,
   199  			cs: map[string][]byte{
   200  				"bool": []byte("true"),
   201  			},
   202  			val: new(bool),
   203  			wantVal: func() *bool {
   204  				val := true
   205  				return &val
   206  			}(),
   207  		},
   208  		{
   209  			name:       "sql.NullString",
   210  			valCreator: NewUnsafeValue,
   211  			cs: map[string][]byte{
   212  				"string": []byte("word"),
   213  			},
   214  			val:     &sql.NullString{},
   215  			wantVal: &sql.NullString{String: "word", Valid: true},
   216  		},
   217  		{
   218  			name:       "struct ptr value",
   219  			valCreator: NewUnsafeValue,
   220  			cs: map[string][]byte{
   221  				"id":               []byte("1"),
   222  				"bool":             []byte("true"),
   223  				"bool_ptr":         []byte("false"),
   224  				"int":              []byte("12"),
   225  				"int_ptr":          []byte("13"),
   226  				"int8":             []byte("8"),
   227  				"int8_ptr":         []byte("-8"),
   228  				"int16":            []byte("16"),
   229  				"int16_ptr":        []byte("-16"),
   230  				"int32":            []byte("32"),
   231  				"int32_ptr":        []byte("-32"),
   232  				"int64":            []byte("64"),
   233  				"int64_ptr":        []byte("-64"),
   234  				"uint":             []byte("14"),
   235  				"uint_ptr":         []byte("15"),
   236  				"uint8":            []byte("8"),
   237  				"uint8_ptr":        []byte("18"),
   238  				"uint16":           []byte("16"),
   239  				"uint16_ptr":       []byte("116"),
   240  				"uint32":           []byte("32"),
   241  				"uint32_ptr":       []byte("132"),
   242  				"uint64":           []byte("64"),
   243  				"uint64_ptr":       []byte("164"),
   244  				"float32":          []byte("3.2"),
   245  				"float32_ptr":      []byte("-3.2"),
   246  				"float64":          []byte("6.4"),
   247  				"float64_ptr":      []byte("-6.4"),
   248  				"byte_array":       []byte("hello"),
   249  				"string":           []byte("world"),
   250  				"null_string_ptr":  []byte("null string"),
   251  				"null_int16_ptr":   []byte("16"),
   252  				"null_int32_ptr":   []byte("32"),
   253  				"null_int64_ptr":   []byte("64"),
   254  				"null_bool_ptr":    []byte("true"),
   255  				"null_float64_ptr": []byte("6.4"),
   256  				"json_column":      []byte(`{"name": "Tom"}`),
   257  			},
   258  			val:     &test.SimpleStruct{},
   259  			wantVal: test.NewSimpleStruct(1),
   260  		},
   261  		{
   262  			name: "invalid field",
   263  			cs: map[string][]byte{
   264  				"invalid_column": nil,
   265  			},
   266  			valCreator: NewReflectValue,
   267  			val:        &test.SimpleStruct{},
   268  			wantErr:    errs.NewInvalidColumnError("invalid_column"),
   269  		},
   270  	}
   271  
   272  	r := model.NewMetaRegistry()
   273  	meta, err := r.Get(&test.SimpleStruct{})
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  	for _, tc := range testCases {
   278  		t.Run(tc.name, func(t *testing.T) {
   279  			db, mock, err := sqlmock.New()
   280  			if err != nil {
   281  				t.Fatal(err)
   282  			}
   283  			defer func() { _ = db.Close() }()
   284  			basicCreator := PrimitiveCreator{Creator: tc.valCreator}
   285  			val := basicCreator.NewPrimitiveValue(tc.val, meta)
   286  			cols := make([]string, 0, len(tc.cs))
   287  			colVals := make([]driver.Value, 0, len(tc.cs))
   288  			for k, v := range tc.cs {
   289  				cols = append(cols, k)
   290  				colVals = append(colVals, v)
   291  			}
   292  			mock.ExpectQuery("SELECT *").
   293  				WillReturnRows(sqlmock.NewRows(cols).
   294  					AddRow(colVals...))
   295  			rows, _ := db.Query("SELECT *")
   296  			rows.Next()
   297  			err = val.SetColumns(rows)
   298  			if err != nil {
   299  				assert.Equal(t, tc.wantErr, err)
   300  				return
   301  			}
   302  			if tc.wantErr != nil {
   303  				t.Fatalf("期望得到错误,但是并没有得到 %v", tc.wantErr)
   304  			}
   305  			assert.Equal(t, tc.wantVal, tc.val)
   306  		})
   307  	}
   308  
   309  }