github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/rows/convert_assign_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rows 16 17 import ( 18 "database/sql" 19 "testing" 20 "time" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 ) 25 26 func TestConvertNullable(t *testing.T) { 27 testcases := []struct { 28 name string 29 src any 30 dest any 31 wantVal any 32 hasErr bool 33 }{ 34 { 35 name: "sql.NUllbool", 36 src: sql.NullBool{Valid: true, Bool: true}, 37 dest: &sql.NullBool{Valid: false, Bool: false}, 38 wantVal: &sql.NullBool{Valid: true, Bool: true}, 39 }, 40 { 41 name: "sql.NUllbool的valid为false", 42 src: sql.NullBool{Valid: false, Bool: true}, 43 dest: &sql.NullBool{Valid: false, Bool: false}, 44 wantVal: &sql.NullBool{Valid: false, Bool: false}, 45 }, 46 { 47 name: "sql.NUllString", 48 src: sql.NullString{Valid: true, String: "xx"}, 49 dest: &sql.NullString{Valid: false, String: ""}, 50 wantVal: &sql.NullString{Valid: true, String: "xx"}, 51 }, 52 { 53 name: "sql.NUllString的valid为false", 54 src: sql.NullString{Valid: false, String: "xx"}, 55 dest: &sql.NullString{Valid: false, String: ""}, 56 wantVal: &sql.NullString{Valid: false, String: ""}, 57 }, 58 { 59 name: "sql.NUllByte", 60 src: sql.NullByte{Valid: true, Byte: 'a'}, 61 dest: &sql.NullByte{Valid: false, Byte: ' '}, 62 wantVal: &sql.NullByte{Valid: true, Byte: 'a'}, 63 }, 64 { 65 name: "sql.NUllByte的valid的false", 66 src: sql.NullByte{Valid: false, Byte: 'a'}, 67 dest: &sql.NullByte{Valid: false, Byte: 0}, 68 wantVal: &sql.NullByte{Valid: false, Byte: 0}, 69 }, 70 { 71 name: "sql.NUllInt32", 72 src: sql.NullInt32{Valid: true, Int32: 5}, 73 dest: &sql.NullInt32{Valid: false, Int32: 0}, 74 wantVal: &sql.NullInt32{Valid: true, Int32: 5}, 75 }, 76 { 77 name: "sql.NUllInt32的valid的false", 78 src: sql.NullInt32{Valid: false, Int32: 0}, 79 dest: &sql.NullInt32{Valid: false, Int32: 0}, 80 wantVal: &sql.NullInt32{Valid: false, Int32: 0}, 81 }, 82 { 83 name: "sql.NUllInt64", 84 src: sql.NullInt64{Valid: true, Int64: 5}, 85 dest: &sql.NullInt64{Valid: false, Int64: 0}, 86 wantVal: &sql.NullInt64{Valid: true, Int64: 5}, 87 }, 88 { 89 name: "sql.NUllInt64的valid的false", 90 src: sql.NullInt64{Valid: false, Int64: 0}, 91 dest: &sql.NullInt64{Valid: false, Int64: 0}, 92 wantVal: &sql.NullInt64{Valid: false, Int64: 0}, 93 }, 94 { 95 name: "sql.NUllInt16", 96 src: sql.NullInt16{Valid: true, Int16: 5}, 97 dest: &sql.NullInt16{Valid: false, Int16: 0}, 98 wantVal: &sql.NullInt16{Valid: true, Int16: 5}, 99 }, 100 { 101 name: "sql.NUllInt16的valid的false", 102 src: sql.NullInt16{Valid: false, Int16: 0}, 103 dest: &sql.NullInt16{Valid: false, Int16: 0}, 104 wantVal: &sql.NullInt16{Valid: false, Int16: 0}, 105 }, 106 { 107 name: "sql.NUllFloat64", 108 src: sql.NullFloat64{Valid: true, Float64: 5}, 109 dest: &sql.NullFloat64{Valid: false, Float64: 0}, 110 wantVal: &sql.NullFloat64{Valid: true, Float64: 5}, 111 }, 112 { 113 name: "sql.NUllfloat64的valid的false", 114 src: sql.NullFloat64{Valid: false, Float64: 0}, 115 dest: &sql.NullFloat64{Valid: false, Float64: 0}, 116 wantVal: &sql.NullFloat64{Valid: false, Float64: 0}, 117 }, 118 { 119 name: "sql.NUllTime", 120 src: sql.NullTime{Valid: true, Time: func() time.Time { 121 val, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) 122 require.NoError(t, err) 123 return val 124 }()}, 125 dest: &sql.NullTime{Valid: false, Time: time.Time{}}, 126 wantVal: &sql.NullTime{Valid: true, Time: func() time.Time { 127 val, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) 128 require.NoError(t, err) 129 return val 130 }()}, 131 }, 132 { 133 name: "sql.NUllTime的valid的false", 134 src: sql.NullTime{Valid: false, Time: time.Time{}}, 135 dest: &sql.NullTime{Valid: false, Time: time.Time{}}, 136 wantVal: &sql.NullTime{Valid: false, Time: time.Time{}}, 137 }, 138 { 139 name: "使用sql.NullInt32接收sql.NullInt64", 140 src: sql.NullInt64{Valid: true, Int64: 5}, 141 dest: &sql.NullInt32{Valid: false, Int32: 0}, 142 wantVal: &sql.NullInt32{Valid: true, Int32: 5}, 143 }, 144 { 145 name: "使用sql.NullInt16接收sql.NullInt64", 146 src: sql.NullInt64{Valid: true, Int64: 5}, 147 dest: &sql.NullInt16{Valid: false, Int16: 0}, 148 wantVal: &sql.NullInt16{Valid: true, Int16: 5}, 149 }, 150 { 151 name: "使用sql.NullInt32接收sql.NullInt64,Valid为false", 152 src: sql.NullInt64{Valid: false, Int64: 0}, 153 dest: &sql.NullInt32{Valid: false, Int32: 0}, 154 wantVal: &sql.NullInt32{Valid: false, Int32: 0}, 155 }, 156 { 157 name: "使用sql.NullInt16接收sql.NullInt64,Valid为false", 158 src: sql.NullInt64{Valid: false, Int64: 0}, 159 dest: &sql.NullInt16{Valid: false, Int16: 0}, 160 wantVal: &sql.NullInt16{Valid: false, Int16: 0}, 161 }, 162 { 163 name: "使用int32接收sql.NullInt64", 164 src: sql.NullInt64{Valid: true, Int64: 5}, 165 dest: func() *int32 { 166 var val int32 167 return &val 168 }(), 169 wantVal: func() *int32 { 170 val := int32(5) 171 return &val 172 }(), 173 }, 174 { 175 name: "使用int16接收sql.NullInt64", 176 src: sql.NullInt64{Valid: true, Int64: 5}, 177 dest: func() *int16 { 178 var val int16 179 return &val 180 }(), 181 wantVal: func() *int16 { 182 val := int16(5) 183 return &val 184 }(), 185 }, 186 { 187 name: "使用float32接收sql.Nullfloat64", 188 src: sql.NullFloat64{Valid: true, Float64: 5}, 189 dest: func() *float32 { 190 var val float32 191 return &val 192 }(), 193 wantVal: func() *float32 { 194 val := float32(5) 195 return &val 196 }(), 197 }, 198 { 199 name: "使用int32接收sql.NullInt64,Valid为false", 200 src: sql.NullInt64{Valid: false, Int64: 0}, 201 dest: func() *int32 { 202 var val int32 203 return &val 204 }(), 205 hasErr: true, 206 }, 207 { 208 name: "使用int16接收sql.NullInt64,valid为false", 209 src: sql.NullInt64{Valid: false, Int64: 0}, 210 dest: func() *int16 { 211 var val int16 212 return &val 213 }(), 214 hasErr: true, 215 }, 216 { 217 name: "使用float32接收sql.Nullfloat64", 218 src: sql.NullFloat64{Valid: false, Float64: 0}, 219 dest: func() *float32 { 220 var val float32 221 return &val 222 }(), 223 hasErr: true, 224 }, 225 } 226 for _, tc := range testcases { 227 t.Run(tc.name, func(t *testing.T) { 228 err := ConvertAssign(tc.dest, tc.src) 229 if tc.hasErr { 230 require.Error(t, err) 231 return 232 } 233 require.NoError(t, err) 234 assert.Equal(t, tc.dest, tc.wantVal) 235 }) 236 } 237 }