github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/middleware_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "errors" 20 "testing" 21 22 "github.com/stretchr/testify/require" 23 24 "github.com/stretchr/testify/assert" 25 ) 26 27 func Test_Middleware(t *testing.T) { 28 testCases := []struct { 29 name string 30 wantErr error 31 mdls []Middleware 32 }{ 33 { 34 name: "one middleware", 35 mdls: func() []Middleware { 36 var mdl Middleware = func(next HandleFunc) HandleFunc { 37 return func(ctx context.Context, queryContext *QueryContext) *QueryResult { 38 return &QueryResult{} 39 } 40 } 41 return []Middleware{mdl} 42 }(), 43 }, 44 { 45 name: "many middleware", 46 mdls: func() []Middleware { 47 mdl1 := func(next HandleFunc) HandleFunc { 48 return func(ctx context.Context, queryContext *QueryContext) *QueryResult { 49 return &QueryResult{Result: "mdl1"} 50 } 51 } 52 mdl2 := func(next HandleFunc) HandleFunc { 53 return func(ctx context.Context, queryContext *QueryContext) *QueryResult { 54 return &QueryResult{Result: "mdl2"} 55 } 56 } 57 return []Middleware{mdl1, mdl2} 58 }(), 59 }, 60 } 61 for _, tc := range testCases { 62 t.Run(tc.name, func(t *testing.T) { 63 db, err := Open("sqlite3", "file:test.db?cache=shared&mode=memory", 64 DBWithMiddlewares(tc.mdls...)) 65 if err != nil { 66 t.Error(err) 67 } 68 defer func() { 69 _ = db.Close() 70 }() 71 assert.EqualValues(t, tc.mdls, db.ms) 72 }) 73 } 74 } 75 76 func Test_Middleware_order(t *testing.T) { 77 var res []byte 78 var mdl1 Middleware = func(next HandleFunc) HandleFunc { 79 return func(ctx context.Context, qc *QueryContext) *QueryResult { 80 res = append(res, '1') 81 return next(ctx, qc) 82 } 83 } 84 var mdl2 Middleware = func(next HandleFunc) HandleFunc { 85 return func(ctx context.Context, qc *QueryContext) *QueryResult { 86 res = append(res, '2') 87 return next(ctx, qc) 88 } 89 } 90 91 var mdl3 Middleware = func(next HandleFunc) HandleFunc { 92 return func(ctx context.Context, qc *QueryContext) *QueryResult { 93 res = append(res, '3') 94 return next(ctx, qc) 95 } 96 } 97 var last Middleware = func(next HandleFunc) HandleFunc { 98 return func(ctx context.Context, qc *QueryContext) *QueryResult { 99 return &QueryResult{ 100 Err: errors.New("mock error"), 101 } 102 } 103 } 104 db, err := Open("sqlite3", "file:test.db?cache=shared&mode=memory", 105 DBWithMiddlewares(mdl1, mdl2, mdl3, last)) 106 require.NoError(t, err) 107 108 _, err = NewSelector[TestModel](db).Get(context.Background()) 109 assert.Equal(t, errors.New("mock error"), err) 110 assert.Equal(t, "123", string(res)) 111 112 } 113 114 func TestQueryContext(t *testing.T) { 115 testCases := []struct { 116 name string 117 wantErr error 118 q Query 119 qc *QueryContext 120 }{ 121 { 122 name: "one middleware", 123 q: Query{ 124 SQL: `SELECT * FROM user_tab WHERE id = ?;`, 125 Args: []any{1}, 126 }, 127 qc: &QueryContext{ 128 q: Query{ 129 SQL: `SELECT * FROM user_tab WHERE id = ?;`, 130 Args: []any{1}, 131 }, 132 }, 133 }, 134 } 135 for _, tc := range testCases { 136 t.Run(tc.name, func(t *testing.T) { 137 assert.EqualValues(t, tc.q, tc.qc.GetQuery()) 138 }) 139 } 140 }