github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/errors/clean_up_test.go (about)

     1  package errors
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strconv"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  type errCallable struct{ error }
    13  
    14  func (e errCallable) Func() error                   { return e.error }
    15  func (e errCallable) FuncCtx(context.Context) error { return e.error }
    16  
    17  func TestCleanUp(t *testing.T) {
    18  	const (
    19  		closeMsg  = "close [seuozr]"
    20  		returnMsg = "return [mntbnb]"
    21  	)
    22  
    23  	for callIdx, call := range []func(errCallable, *error){
    24  		func(e errCallable, err *error) { CleanUp(e.Func, err) },
    25  		func(e errCallable, err *error) { CleanUpCtx(context.Background(), e.FuncCtx, err) },
    26  	} {
    27  		t.Run(strconv.Itoa(callIdx), func(t *testing.T) {
    28  			// No return error, no close error.
    29  			gotErr := func() (err error) {
    30  				e := errCallable{}
    31  				defer call(e, &err)
    32  				return nil
    33  			}()
    34  			assert.NoError(t, gotErr)
    35  
    36  			// No return error, close error.
    37  			gotErr = func() (err error) {
    38  				e := errCallable{errors.New(closeMsg)}
    39  				defer call(e, &err)
    40  				return nil
    41  			}()
    42  			assert.Equal(t, gotErr.Error(), closeMsg)
    43  
    44  			// Return error, no close error.
    45  			gotErr = func() (err error) {
    46  				e := errCallable{}
    47  				defer call(e, &err)
    48  				return errors.New(returnMsg)
    49  			}()
    50  			assert.Equal(t, gotErr.Error(), returnMsg)
    51  
    52  			// Return error, close error.
    53  			gotErr = func() (err error) {
    54  				e := errCallable{errors.New(closeMsg)}
    55  				defer call(e, &err)
    56  				return errors.New(returnMsg)
    57  			}()
    58  			assert.Contains(t, gotErr.Error(), returnMsg)
    59  			assert.Contains(t, gotErr.Error(), closeMsg)
    60  		})
    61  	}
    62  }