github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/valuer/primitive_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package valuer 16 17 import ( 18 "database/sql" 19 "database/sql/driver" 20 "testing" 21 "time" 22 23 "github.com/DATA-DOG/go-sqlmock" 24 "github.com/ecodeclub/eorm/internal/errs" 25 "github.com/ecodeclub/eorm/internal/model" 26 "github.com/ecodeclub/eorm/internal/test" 27 "github.com/stretchr/testify/assert" 28 ) 29 30 func Test_primitiveValue_Field(t *testing.T) { 31 testPrimitiveValueField(t, PrimitiveCreator{Creator: NewUnsafeValue}) 32 testPrimitiveValueField(t, PrimitiveCreator{Creator: NewReflectValue}) 33 } 34 35 func testPrimitiveValueField(t *testing.T, creator PrimitiveCreator) { 36 meta, err := model.NewMetaRegistry().Get(&test.SimpleStruct{}) 37 if err != nil { 38 t.Fatal(err) 39 } 40 t.Run("zero value", func(t *testing.T) { 41 entity := &test.SimpleStruct{} 42 testCases := newValueFieldTestCases(entity) 43 val := creator.NewPrimitiveValue(entity, meta) 44 for _, tc := range testCases { 45 t.Run(tc.name, func(t *testing.T) { 46 v, err := val.Field(tc.field) 47 assert.Equal(t, tc.wantError, err) 48 if err != nil { 49 return 50 } 51 assert.Equal(t, tc.wantVal, v.Interface()) 52 }) 53 } 54 }) 55 t.Run("normal value", func(t *testing.T) { 56 entity := test.NewSimpleStruct(1) 57 testCases := newValueFieldTestCases(entity) 58 val := creator.NewPrimitiveValue(entity, meta) 59 for _, tc := range testCases { 60 t.Run(tc.name, func(t *testing.T) { 61 v, err := val.Field(tc.field) 62 assert.Equal(t, tc.wantError, err) 63 if err != nil { 64 return 65 } 66 assert.Equal(t, tc.wantVal, v.Interface()) 67 }) 68 } 69 }) 70 71 type User struct { 72 CreateTime time.Time 73 Name *string 74 } 75 76 invalidCases := []valueFieldTestCase{ 77 { 78 // 不存在的字段 79 name: "invalid field", 80 field: "UpdateTime", 81 wantError: errs.NewInvalidFieldError("UpdateTime"), 82 }, 83 } 84 t.Run("invalid cases", func(t *testing.T) { 85 meta, err := model.NewMetaRegistry().Get(&User{}) 86 if err != nil { 87 t.Fatal(err) 88 } 89 90 val := creator.NewPrimitiveValue(&User{}, meta) 91 for _, tc := range invalidCases { 92 t.Run(tc.name, func(t *testing.T) { 93 v, err := val.Field(tc.field) 94 assert.Equal(t, tc.wantError, err) 95 if err != nil { 96 return 97 } 98 assert.Equal(t, tc.wantVal, v.Interface()) 99 }) 100 } 101 }) 102 103 type BaseEntity struct { 104 Id int64 `eorm:"auto_increment,primary_key"` 105 CreateTime uint64 106 } 107 type CombinedUser struct { 108 BaseEntity 109 FirstName string 110 } 111 112 cUser := &CombinedUser{} 113 testCases := []valueFieldTestCase{ 114 { 115 name: "id", 116 field: "Id", 117 wantVal: cUser.Id, 118 }, 119 { 120 name: "CreateTime", 121 field: "CreateTime", 122 wantVal: cUser.CreateTime, 123 }, 124 { 125 name: "FirstName", 126 field: "FirstName", 127 wantVal: cUser.FirstName, 128 }, 129 } 130 // 测试使用组合的场景 131 t.Run("combination cases", func(t *testing.T) { 132 meta, err = model.NewMetaRegistry().Get(cUser) 133 if err != nil { 134 t.Fatal(err) 135 } 136 137 val := creator.NewPrimitiveValue(cUser, meta) 138 for _, tc := range testCases { 139 t.Run(tc.name, func(t *testing.T) { 140 v, err := val.Field(tc.field) 141 assert.Equal(t, tc.wantError, err) 142 if err != nil { 143 return 144 } 145 assert.Equal(t, tc.wantVal, v.Interface()) 146 }) 147 } 148 }) 149 } 150 151 func Test_primitiveValue_SetColumn(t *testing.T) { 152 testCases := []struct { 153 name string 154 cs map[string][]byte 155 val any 156 valCreator Creator 157 wantVal any 158 wantErr error 159 }{ 160 { 161 name: "int", 162 valCreator: NewUnsafeValue, 163 cs: map[string][]byte{ 164 "id": []byte("10"), 165 }, 166 val: new(int), 167 wantVal: func() *int { 168 val := 10 169 return &val 170 }(), 171 }, 172 { 173 name: "string", 174 valCreator: NewUnsafeValue, 175 cs: map[string][]byte{ 176 "string": []byte("word"), 177 }, 178 val: new(string), 179 wantVal: func() *string { 180 val := "word" 181 return &val 182 }(), 183 }, 184 { 185 name: "bytes", 186 valCreator: NewUnsafeValue, 187 cs: map[string][]byte{ 188 "byte": []byte("word"), 189 }, 190 val: new([]byte), 191 wantVal: func() *[]byte { 192 val := []byte("word") 193 return &val 194 }(), 195 }, 196 { 197 name: "bool", 198 valCreator: NewUnsafeValue, 199 cs: map[string][]byte{ 200 "bool": []byte("true"), 201 }, 202 val: new(bool), 203 wantVal: func() *bool { 204 val := true 205 return &val 206 }(), 207 }, 208 { 209 name: "sql.NullString", 210 valCreator: NewUnsafeValue, 211 cs: map[string][]byte{ 212 "string": []byte("word"), 213 }, 214 val: &sql.NullString{}, 215 wantVal: &sql.NullString{String: "word", Valid: true}, 216 }, 217 { 218 name: "struct ptr value", 219 valCreator: NewUnsafeValue, 220 cs: map[string][]byte{ 221 "id": []byte("1"), 222 "bool": []byte("true"), 223 "bool_ptr": []byte("false"), 224 "int": []byte("12"), 225 "int_ptr": []byte("13"), 226 "int8": []byte("8"), 227 "int8_ptr": []byte("-8"), 228 "int16": []byte("16"), 229 "int16_ptr": []byte("-16"), 230 "int32": []byte("32"), 231 "int32_ptr": []byte("-32"), 232 "int64": []byte("64"), 233 "int64_ptr": []byte("-64"), 234 "uint": []byte("14"), 235 "uint_ptr": []byte("15"), 236 "uint8": []byte("8"), 237 "uint8_ptr": []byte("18"), 238 "uint16": []byte("16"), 239 "uint16_ptr": []byte("116"), 240 "uint32": []byte("32"), 241 "uint32_ptr": []byte("132"), 242 "uint64": []byte("64"), 243 "uint64_ptr": []byte("164"), 244 "float32": []byte("3.2"), 245 "float32_ptr": []byte("-3.2"), 246 "float64": []byte("6.4"), 247 "float64_ptr": []byte("-6.4"), 248 "byte_array": []byte("hello"), 249 "string": []byte("world"), 250 "null_string_ptr": []byte("null string"), 251 "null_int16_ptr": []byte("16"), 252 "null_int32_ptr": []byte("32"), 253 "null_int64_ptr": []byte("64"), 254 "null_bool_ptr": []byte("true"), 255 "null_float64_ptr": []byte("6.4"), 256 "json_column": []byte(`{"name": "Tom"}`), 257 }, 258 val: &test.SimpleStruct{}, 259 wantVal: test.NewSimpleStruct(1), 260 }, 261 { 262 name: "invalid field", 263 cs: map[string][]byte{ 264 "invalid_column": nil, 265 }, 266 valCreator: NewReflectValue, 267 val: &test.SimpleStruct{}, 268 wantErr: errs.NewInvalidColumnError("invalid_column"), 269 }, 270 } 271 272 r := model.NewMetaRegistry() 273 meta, err := r.Get(&test.SimpleStruct{}) 274 if err != nil { 275 t.Fatal(err) 276 } 277 for _, tc := range testCases { 278 t.Run(tc.name, func(t *testing.T) { 279 db, mock, err := sqlmock.New() 280 if err != nil { 281 t.Fatal(err) 282 } 283 defer func() { _ = db.Close() }() 284 basicCreator := PrimitiveCreator{Creator: tc.valCreator} 285 val := basicCreator.NewPrimitiveValue(tc.val, meta) 286 cols := make([]string, 0, len(tc.cs)) 287 colVals := make([]driver.Value, 0, len(tc.cs)) 288 for k, v := range tc.cs { 289 cols = append(cols, k) 290 colVals = append(colVals, v) 291 } 292 mock.ExpectQuery("SELECT *"). 293 WillReturnRows(sqlmock.NewRows(cols). 294 AddRow(colVals...)) 295 rows, _ := db.Query("SELECT *") 296 rows.Next() 297 err = val.SetColumns(rows) 298 if err != nil { 299 assert.Equal(t, tc.wantErr, err) 300 return 301 } 302 if tc.wantErr != nil { 303 t.Fatalf("期望得到错误,但是并没有得到 %v", tc.wantErr) 304 } 305 assert.Equal(t, tc.wantVal, tc.val) 306 }) 307 } 308 309 }