github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/builder_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "fmt" 22 "testing" 23 24 "github.com/ecodeclub/eorm/internal/datasource/single" 25 26 "github.com/DATA-DOG/go-sqlmock" 27 "github.com/ecodeclub/eorm/internal/errs" 28 "github.com/ecodeclub/eorm/internal/valuer" 29 "github.com/stretchr/testify/assert" 30 ) 31 32 func ExampleRawQuery() { 33 db := memoryDB() 34 q := RawQuery[any](db, `SELECT * FROM user_tab WHERE id = ?;`, 1) 35 fmt.Printf(` 36 SQL: %s 37 Args: %v 38 `, q.qc.q.SQL, q.qc.q.Args) 39 // Output: 40 // SQL: SELECT * FROM user_tab WHERE id = ?; 41 // Args: [1] 42 } 43 44 func ExampleQuerier_Exec() { 45 db := memoryDB() 46 // 在 Exec 的时候,泛型参数可以是任意的 47 q := RawQuery[any](db, `CREATE TABLE IF NOT EXISTS groups ( 48 group_id INTEGER PRIMARY KEY, 49 name TEXT NOT NULL 50 )`) 51 res := q.Exec(context.Background()) 52 if res.Err() == nil { 53 fmt.Print("SUCCESS") 54 } 55 // Output: 56 // SUCCESS 57 } 58 59 func TestQuerier_Get(t *testing.T) { 60 t.Run("unsafe", func(t *testing.T) { 61 testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue}) 62 }) 63 64 t.Run("reflect", func(t *testing.T) { 65 testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewReflectValue}) 66 }) 67 } 68 69 func testQuerierGet(t *testing.T, creator valuer.PrimitiveCreator) { 70 db, mock, err := sqlmock.New() 71 if err != nil { 72 t.Fatal(err) 73 } 74 defer func() { _ = db.Close() }() 75 76 orm, err := OpenDS("mysql", single.NewDB(db)) 77 if err != nil { 78 t.Fatal(err) 79 } 80 testCases := []struct { 81 name string 82 query string 83 mockErr error 84 mockRows *sqlmock.Rows 85 wantErr error 86 wantVal *TestModel 87 }{ 88 { 89 // 查询返回错误 90 name: "query error", 91 mockErr: errors.New("invalid query"), 92 wantErr: errors.New("invalid query"), 93 query: "invalid query", 94 }, 95 { 96 name: "no row", 97 wantErr: ErrNoRows, 98 query: "no row", 99 mockRows: sqlmock.NewRows([]string{"id"}), 100 }, 101 { 102 name: "too many column", 103 wantErr: errs.ErrTooManyColumns, 104 query: "too many column", 105 mockRows: func() *sqlmock.Rows { 106 res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name", "extra_column"}) 107 res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming"), []byte("nothing")) 108 return res 109 }(), 110 }, 111 { 112 name: "get data", 113 query: "SELECT xx FROM `test_model`", 114 mockRows: func() *sqlmock.Rows { 115 res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"}) 116 res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming")) 117 return res 118 }(), 119 wantVal: &TestModel{ 120 Id: 1, 121 FirstName: "Da", 122 Age: 18, 123 LastName: &sql.NullString{String: "Ming", Valid: true}, 124 }, 125 }, 126 } 127 128 for _, tc := range testCases { 129 exp := mock.ExpectQuery(tc.query) 130 if tc.mockErr != nil { 131 exp.WillReturnError(tc.mockErr) 132 } else { 133 exp.WillReturnRows(tc.mockRows) 134 } 135 } 136 orm.valCreator = creator 137 for _, tc := range testCases { 138 t.Run(tc.name, func(t *testing.T) { 139 res, err := RawQuery[TestModel](orm, tc.query).Get(context.Background()) 140 assert.Equal(t, tc.wantErr, err) 141 if err != nil { 142 return 143 } 144 assert.Equal(t, tc.wantVal, res) 145 }) 146 } 147 } 148 149 func TestQuerierGetMulti(t *testing.T) { 150 t.Run("unsafe", func(t *testing.T) { 151 testQuerier_GetMulti(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue}) 152 }) 153 t.Run("reflect", func(t *testing.T) { 154 testQuerier_GetMulti(t, valuer.PrimitiveCreator{Creator: valuer.NewReflectValue}) 155 }) 156 } 157 158 func testQuerier_GetMulti(t *testing.T, creator valuer.PrimitiveCreator) { 159 db, mock, err := sqlmock.New() 160 if err != nil { 161 t.Fatal(err) 162 } 163 defer func() { 164 _ = db.Close() 165 }() 166 orm, err := OpenDS("mysql", single.NewDB(db)) 167 if err != nil { 168 t.Fatal(err) 169 } 170 testCases := []struct { 171 name string 172 query string 173 mockErr error 174 mockRows *sqlmock.Rows 175 wantErr error 176 wantVal []*TestModel 177 }{ 178 { 179 name: "query error", 180 mockErr: errors.New("invalid query"), 181 wantErr: errors.New("invalid query"), 182 query: "invalid query", 183 }, 184 { 185 name: "no row", 186 query: "no row", 187 mockRows: sqlmock.NewRows([]string{"id"}), 188 wantVal: []*TestModel{}, 189 }, 190 { 191 name: "too many column", 192 wantErr: errs.ErrTooManyColumns, 193 query: "too many column", 194 mockRows: func() *sqlmock.Rows { 195 res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name", "extra_column"}) 196 res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming"), []byte("nothing")) 197 return res 198 }(), 199 }, 200 { 201 name: "get data", 202 query: "SELECT xx FROM `test_model`", 203 mockRows: func() *sqlmock.Rows { 204 res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"}) 205 res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming")) 206 res.AddRow([]byte("2"), []byte("Xiao"), []byte("28"), []byte("Hong")) 207 return res 208 }(), 209 wantVal: []*TestModel{&TestModel{ 210 Id: 1, 211 FirstName: "Da", 212 Age: 18, 213 LastName: &sql.NullString{String: "Ming", Valid: true}, 214 }, 215 { 216 Id: 2, 217 FirstName: "Xiao", 218 Age: 28, 219 LastName: &sql.NullString{String: "Hong", Valid: true}, 220 }, 221 }, 222 }, 223 } 224 for _, tc := range testCases { 225 exp := mock.ExpectQuery(tc.query) 226 if tc.mockErr != nil { 227 exp.WillReturnError(tc.mockErr) 228 } else { 229 exp.WillReturnRows(tc.mockRows) 230 } 231 } 232 orm.valCreator = creator 233 for _, tc := range testCases { 234 t.Run(tc.name, func(t *testing.T) { 235 res, err := RawQuery[TestModel](orm, tc.query).GetMulti(context.Background()) 236 assert.Equal(t, tc.wantErr, err) 237 if err != nil { 238 return 239 } 240 assert.Equal(t, tc.wantVal, res) 241 }) 242 } 243 244 }