github.com/blend/go-sdk@v1.20220411.3/db/statement_interceptor_chain_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"testing"
    14  
    15  	"github.com/blend/go-sdk/assert"
    16  )
    17  
    18  func Test_StatementInterceptorChain(t *testing.T) {
    19  	its := assert.New(t)
    20  
    21  	var calls []string
    22  	a := func(_ context.Context, label, statement string) (string, error) {
    23  		calls = append(calls, "a")
    24  		return statement + "a", nil
    25  	}
    26  
    27  	b := func(_ context.Context, label, statement string) (string, error) {
    28  		calls = append(calls, "b")
    29  		return statement + "b", nil
    30  	}
    31  
    32  	c := func(_ context.Context, label, statement string) (string, error) {
    33  		calls = append(calls, "c")
    34  		return statement + "c", nil
    35  	}
    36  
    37  	chain := StatementInterceptorChain(a, b, c)
    38  	statement, err := chain(context.TODO(), "foo", "bar")
    39  	its.Nil(err)
    40  	its.Equal("barabc", statement)
    41  	its.Equal([]string{"a", "b", "c"}, calls)
    42  }
    43  
    44  func Test_StatementInterceptorChain_Errors(t *testing.T) {
    45  	its := assert.New(t)
    46  
    47  	var calls []string
    48  	a := func(_ context.Context, label, statement string) (string, error) {
    49  		calls = append(calls, "a")
    50  		return statement + "a", nil
    51  	}
    52  
    53  	b := func(_ context.Context, label, statement string) (string, error) {
    54  		calls = append(calls, "b")
    55  		return statement + "b", fmt.Errorf("this is just a test")
    56  	}
    57  
    58  	c := func(_ context.Context, label, statement string) (string, error) {
    59  		calls = append(calls, "c")
    60  		return statement + "c", nil
    61  	}
    62  
    63  	chain := StatementInterceptorChain(a, b, c)
    64  	statement, err := chain(context.TODO(), "foo", "bar")
    65  	its.NotNil(err)
    66  	its.Equal("barab", statement)
    67  	its.Equal([]string{"a", "b"}, calls)
    68  }
    69  
    70  func Test_StatementInterceptorChain_Empty(t *testing.T) {
    71  	its := assert.New(t)
    72  
    73  	chain := StatementInterceptorChain()
    74  	statement, err := chain(context.TODO(), "foo", "bar")
    75  	its.Nil(err)
    76  	its.Equal("bar", statement)
    77  }
    78  
    79  func Test_StatementInterceptorChain_Single(t *testing.T) {
    80  	its := assert.New(t)
    81  
    82  	var calls []string
    83  	a := func(_ context.Context, label, statement string) (string, error) {
    84  		calls = append(calls, "a")
    85  		return statement + "a", nil
    86  	}
    87  
    88  	chain := StatementInterceptorChain(a)
    89  	statement, err := chain(context.TODO(), "foo", "bar")
    90  	its.Nil(err)
    91  	its.Equal("bara", statement)
    92  	its.Equal([]string{"a"}, calls)
    93  }