github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/middleware_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  )
    26  
    27  func Test_Middleware(t *testing.T) {
    28  	testCases := []struct {
    29  		name    string
    30  		wantErr error
    31  		mdls    []Middleware
    32  	}{
    33  		{
    34  			name: "one middleware",
    35  			mdls: func() []Middleware {
    36  				var mdl Middleware = func(next HandleFunc) HandleFunc {
    37  					return func(ctx context.Context, queryContext *QueryContext) *QueryResult {
    38  						return &QueryResult{}
    39  					}
    40  				}
    41  				return []Middleware{mdl}
    42  			}(),
    43  		},
    44  		{
    45  			name: "many middleware",
    46  			mdls: func() []Middleware {
    47  				mdl1 := func(next HandleFunc) HandleFunc {
    48  					return func(ctx context.Context, queryContext *QueryContext) *QueryResult {
    49  						return &QueryResult{Result: "mdl1"}
    50  					}
    51  				}
    52  				mdl2 := func(next HandleFunc) HandleFunc {
    53  					return func(ctx context.Context, queryContext *QueryContext) *QueryResult {
    54  						return &QueryResult{Result: "mdl2"}
    55  					}
    56  				}
    57  				return []Middleware{mdl1, mdl2}
    58  			}(),
    59  		},
    60  	}
    61  	for _, tc := range testCases {
    62  		t.Run(tc.name, func(t *testing.T) {
    63  			db, err := Open("sqlite3", "file:test.db?cache=shared&mode=memory",
    64  				DBWithMiddlewares(tc.mdls...))
    65  			if err != nil {
    66  				t.Error(err)
    67  			}
    68  			defer func() {
    69  				_ = db.Close()
    70  			}()
    71  			assert.EqualValues(t, tc.mdls, db.ms)
    72  		})
    73  	}
    74  }
    75  
    76  func Test_Middleware_order(t *testing.T) {
    77  	var res []byte
    78  	var mdl1 Middleware = func(next HandleFunc) HandleFunc {
    79  		return func(ctx context.Context, qc *QueryContext) *QueryResult {
    80  			res = append(res, '1')
    81  			return next(ctx, qc)
    82  		}
    83  	}
    84  	var mdl2 Middleware = func(next HandleFunc) HandleFunc {
    85  		return func(ctx context.Context, qc *QueryContext) *QueryResult {
    86  			res = append(res, '2')
    87  			return next(ctx, qc)
    88  		}
    89  	}
    90  
    91  	var mdl3 Middleware = func(next HandleFunc) HandleFunc {
    92  		return func(ctx context.Context, qc *QueryContext) *QueryResult {
    93  			res = append(res, '3')
    94  			return next(ctx, qc)
    95  		}
    96  	}
    97  	var last Middleware = func(next HandleFunc) HandleFunc {
    98  		return func(ctx context.Context, qc *QueryContext) *QueryResult {
    99  			return &QueryResult{
   100  				Err: errors.New("mock error"),
   101  			}
   102  		}
   103  	}
   104  	db, err := Open("sqlite3", "file:test.db?cache=shared&mode=memory",
   105  		DBWithMiddlewares(mdl1, mdl2, mdl3, last))
   106  	require.NoError(t, err)
   107  
   108  	_, err = NewSelector[TestModel](db).Get(context.Background())
   109  	assert.Equal(t, errors.New("mock error"), err)
   110  	assert.Equal(t, "123", string(res))
   111  
   112  }
   113  
   114  func TestQueryContext(t *testing.T) {
   115  	testCases := []struct {
   116  		name    string
   117  		wantErr error
   118  		q       Query
   119  		qc      *QueryContext
   120  	}{
   121  		{
   122  			name: "one middleware",
   123  			q: Query{
   124  				SQL:  `SELECT * FROM user_tab WHERE id = ?;`,
   125  				Args: []any{1},
   126  			},
   127  			qc: &QueryContext{
   128  				q: Query{
   129  					SQL:  `SELECT * FROM user_tab WHERE id = ?;`,
   130  					Args: []any{1},
   131  				},
   132  			},
   133  		},
   134  	}
   135  	for _, tc := range testCases {
   136  		t.Run(tc.name, func(t *testing.T) {
   137  			assert.EqualValues(t, tc.q, tc.qc.GetQuery())
   138  		})
   139  	}
   140  }