github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/builder_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/ecodeclub/eorm/internal/datasource/single"
    25  
    26  	"github.com/DATA-DOG/go-sqlmock"
    27  	"github.com/ecodeclub/eorm/internal/errs"
    28  	"github.com/ecodeclub/eorm/internal/valuer"
    29  	"github.com/stretchr/testify/assert"
    30  )
    31  
    32  func ExampleRawQuery() {
    33  	db := memoryDB()
    34  	q := RawQuery[any](db, `SELECT * FROM user_tab WHERE id = ?;`, 1)
    35  	fmt.Printf(`
    36  SQL: %s
    37  Args: %v
    38  `, q.qc.q.SQL, q.qc.q.Args)
    39  	// Output:
    40  	// SQL: SELECT * FROM user_tab WHERE id = ?;
    41  	// Args: [1]
    42  }
    43  
    44  func ExampleQuerier_Exec() {
    45  	db := memoryDB()
    46  	// 在 Exec 的时候,泛型参数可以是任意的
    47  	q := RawQuery[any](db, `CREATE TABLE IF NOT EXISTS groups (
    48     group_id INTEGER PRIMARY KEY,
    49     name TEXT NOT NULL
    50  )`)
    51  	res := q.Exec(context.Background())
    52  	if res.Err() == nil {
    53  		fmt.Print("SUCCESS")
    54  	}
    55  	// Output:
    56  	// SUCCESS
    57  }
    58  
    59  func TestQuerier_Get(t *testing.T) {
    60  	t.Run("unsafe", func(t *testing.T) {
    61  		testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue})
    62  	})
    63  
    64  	t.Run("reflect", func(t *testing.T) {
    65  		testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewReflectValue})
    66  	})
    67  }
    68  
    69  func testQuerierGet(t *testing.T, creator valuer.PrimitiveCreator) {
    70  	db, mock, err := sqlmock.New()
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	defer func() { _ = db.Close() }()
    75  
    76  	orm, err := OpenDS("mysql", single.NewDB(db))
    77  	if err != nil {
    78  		t.Fatal(err)
    79  	}
    80  	testCases := []struct {
    81  		name     string
    82  		query    string
    83  		mockErr  error
    84  		mockRows *sqlmock.Rows
    85  		wantErr  error
    86  		wantVal  *TestModel
    87  	}{
    88  		{
    89  			// 查询返回错误
    90  			name:    "query error",
    91  			mockErr: errors.New("invalid query"),
    92  			wantErr: errors.New("invalid query"),
    93  			query:   "invalid query",
    94  		},
    95  		{
    96  			name:     "no row",
    97  			wantErr:  ErrNoRows,
    98  			query:    "no row",
    99  			mockRows: sqlmock.NewRows([]string{"id"}),
   100  		},
   101  		{
   102  			name:    "too many column",
   103  			wantErr: errs.ErrTooManyColumns,
   104  			query:   "too many column",
   105  			mockRows: func() *sqlmock.Rows {
   106  				res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name", "extra_column"})
   107  				res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming"), []byte("nothing"))
   108  				return res
   109  			}(),
   110  		},
   111  		{
   112  			name:  "get data",
   113  			query: "SELECT xx FROM `test_model`",
   114  			mockRows: func() *sqlmock.Rows {
   115  				res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"})
   116  				res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming"))
   117  				return res
   118  			}(),
   119  			wantVal: &TestModel{
   120  				Id:        1,
   121  				FirstName: "Da",
   122  				Age:       18,
   123  				LastName:  &sql.NullString{String: "Ming", Valid: true},
   124  			},
   125  		},
   126  	}
   127  
   128  	for _, tc := range testCases {
   129  		exp := mock.ExpectQuery(tc.query)
   130  		if tc.mockErr != nil {
   131  			exp.WillReturnError(tc.mockErr)
   132  		} else {
   133  			exp.WillReturnRows(tc.mockRows)
   134  		}
   135  	}
   136  	orm.valCreator = creator
   137  	for _, tc := range testCases {
   138  		t.Run(tc.name, func(t *testing.T) {
   139  			res, err := RawQuery[TestModel](orm, tc.query).Get(context.Background())
   140  			assert.Equal(t, tc.wantErr, err)
   141  			if err != nil {
   142  				return
   143  			}
   144  			assert.Equal(t, tc.wantVal, res)
   145  		})
   146  	}
   147  }
   148  
   149  func TestQuerierGetMulti(t *testing.T) {
   150  	t.Run("unsafe", func(t *testing.T) {
   151  		testQuerier_GetMulti(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue})
   152  	})
   153  	t.Run("reflect", func(t *testing.T) {
   154  		testQuerier_GetMulti(t, valuer.PrimitiveCreator{Creator: valuer.NewReflectValue})
   155  	})
   156  }
   157  
   158  func testQuerier_GetMulti(t *testing.T, creator valuer.PrimitiveCreator) {
   159  	db, mock, err := sqlmock.New()
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	defer func() {
   164  		_ = db.Close()
   165  	}()
   166  	orm, err := OpenDS("mysql", single.NewDB(db))
   167  	if err != nil {
   168  		t.Fatal(err)
   169  	}
   170  	testCases := []struct {
   171  		name     string
   172  		query    string
   173  		mockErr  error
   174  		mockRows *sqlmock.Rows
   175  		wantErr  error
   176  		wantVal  []*TestModel
   177  	}{
   178  		{
   179  			name:    "query error",
   180  			mockErr: errors.New("invalid query"),
   181  			wantErr: errors.New("invalid query"),
   182  			query:   "invalid query",
   183  		},
   184  		{
   185  			name:     "no row",
   186  			query:    "no row",
   187  			mockRows: sqlmock.NewRows([]string{"id"}),
   188  			wantVal:  []*TestModel{},
   189  		},
   190  		{
   191  			name:    "too many column",
   192  			wantErr: errs.ErrTooManyColumns,
   193  			query:   "too many column",
   194  			mockRows: func() *sqlmock.Rows {
   195  				res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name", "extra_column"})
   196  				res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming"), []byte("nothing"))
   197  				return res
   198  			}(),
   199  		},
   200  		{
   201  			name:  "get data",
   202  			query: "SELECT xx FROM `test_model`",
   203  			mockRows: func() *sqlmock.Rows {
   204  				res := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"})
   205  				res.AddRow([]byte("1"), []byte("Da"), []byte("18"), []byte("Ming"))
   206  				res.AddRow([]byte("2"), []byte("Xiao"), []byte("28"), []byte("Hong"))
   207  				return res
   208  			}(),
   209  			wantVal: []*TestModel{&TestModel{
   210  				Id:        1,
   211  				FirstName: "Da",
   212  				Age:       18,
   213  				LastName:  &sql.NullString{String: "Ming", Valid: true},
   214  			},
   215  				{
   216  					Id:        2,
   217  					FirstName: "Xiao",
   218  					Age:       28,
   219  					LastName:  &sql.NullString{String: "Hong", Valid: true},
   220  				},
   221  			},
   222  		},
   223  	}
   224  	for _, tc := range testCases {
   225  		exp := mock.ExpectQuery(tc.query)
   226  		if tc.mockErr != nil {
   227  			exp.WillReturnError(tc.mockErr)
   228  		} else {
   229  			exp.WillReturnRows(tc.mockRows)
   230  		}
   231  	}
   232  	orm.valCreator = creator
   233  	for _, tc := range testCases {
   234  		t.Run(tc.name, func(t *testing.T) {
   235  			res, err := RawQuery[TestModel](orm, tc.query).GetMulti(context.Background())
   236  			assert.Equal(t, tc.wantErr, err)
   237  			if err != nil {
   238  				return
   239  			}
   240  			assert.Equal(t, tc.wantVal, res)
   241  		})
   242  	}
   243  
   244  }