go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/testing/assertions/error_tests.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package assertions
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"go.chromium.org/luci/common/errors"
    21  
    22  	"github.com/smarty/assertions"
    23  )
    24  
    25  // ShouldContainErr checks if an `errors.MultiError` on the left side contains
    26  // as one of its errors an `error` or `string` on the right side. If nothing is
    27  // provided on the right, checks that the left side contains at least one non-nil
    28  // error. If nil is provided on the right, checks that the left side contains
    29  // at least one nil, even if it contains other errors.
    30  //
    31  // Equivalent to calling ShouldErrLike on each `error` in an `errors.MultiError`
    32  // and succeeding as long as one of the ShouldErrLike calls succeeds.
    33  //
    34  // To avoid confusion, explicitly rejects the special case where the right side is
    35  // an `errors.MultiError`.
    36  func ShouldContainErr(actual any, expected ...any) string {
    37  	if len(expected) > 1 {
    38  		return fmt.Sprintf("ShouldContainErr requires 0 or 1 expected value, got %d", len(expected))
    39  	}
    40  
    41  	if actual == nil {
    42  		return assertions.ShouldNotBeNil(actual)
    43  	}
    44  
    45  	me, ok := actual.(errors.MultiError)
    46  	if !ok {
    47  		return assertions.ShouldHaveSameTypeAs(actual, errors.MultiError{})
    48  	}
    49  
    50  	if len(expected) == 0 {
    51  		return assertions.ShouldNotBeNil(me.First())
    52  	}
    53  
    54  	switch expected[0].(type) {
    55  	case string:
    56  	case error:
    57  	case errors.MultiError:
    58  		return fmt.Sprintf("expected value must not be a MultiError")
    59  	default:
    60  		if expected[0] != nil {
    61  			return fmt.Sprintf("unexpected argument type %T, expected string or error", expected[0])
    62  		}
    63  	}
    64  
    65  	for _, err := range me {
    66  		if ShouldErrLike(err, expected[0]) == "" {
    67  			return ""
    68  		}
    69  	}
    70  	return fmt.Sprintf("expected MultiError to contain %q", expected[0])
    71  }
    72  
    73  // ShouldErrLike compares an `error` or `string` on the left side, to `error`s
    74  // or `string`s on the right side.
    75  //
    76  // If multiple errors/strings are provided on the righthand side, they must all
    77  // be contained in the stringified error on the lefthand side.
    78  //
    79  // If the righthand side is the singluar `nil`, this expects the error to be
    80  // nil.
    81  //
    82  // Example:
    83  //
    84  //	// Usage                          Equivalent To
    85  //	So(err, ShouldErrLike, "custom")    // `err.Error()` ShouldContainSubstring "custom"
    86  //	So(err, ShouldErrLike, io.EOF)      // `err.Error()` ShouldContainSubstring io.EOF.Error()
    87  //	So(err, ShouldErrLike, "EOF")       // `err.Error()` ShouldContainSubstring "EOF"
    88  //	So(err, ShouldErrLike,
    89  //	   "thing", "other", "etc.")        // `err.Error()` contains all of these substrings.
    90  //	So(nilErr, ShouldErrLike, nil)      // nilErr ShouldBeNil
    91  //	So(nonNilErr, ShouldErrLike, "")    // nonNilErr ShouldNotBeNil
    92  func ShouldErrLike(actual any, expected ...any) string {
    93  	if len(expected) == 0 {
    94  		return "ShouldErrLike requires 1 or more expected values, got 0"
    95  	}
    96  
    97  	// If we have multiple expected arguments, they must all be non-nil
    98  	if len(expected) > 1 {
    99  		for _, e := range expected {
   100  			if e == nil {
   101  				return "ShouldErrLike only accepts `nil` on the right hand side as the sole argument."
   102  			}
   103  		}
   104  	}
   105  
   106  	if expected[0] == nil { // this can only happen if len(expected) == 1
   107  		return assertions.ShouldBeNil(actual)
   108  	} else if actual == nil {
   109  		return assertions.ShouldNotBeNil(actual)
   110  	}
   111  
   112  	ae, ok := actual.(error)
   113  	if !ok {
   114  		return assertions.ShouldImplement(actual, (*error)(nil))
   115  	}
   116  
   117  	for _, expect := range expected {
   118  		switch x := expect.(type) {
   119  		case string:
   120  			if ret := assertions.ShouldContainSubstring(ae.Error(), x); ret != "" {
   121  				return ret
   122  			}
   123  		case error:
   124  			if ret := assertions.ShouldContainSubstring(ae.Error(), x.Error()); ret != "" {
   125  				return ret
   126  			}
   127  		default:
   128  			return fmt.Sprintf("unexpected argument type %T, expected string or error", expect)
   129  		}
   130  	}
   131  
   132  	return ""
   133  }
   134  
   135  // ShouldPanicLike is the same as ShouldErrLike, but with the exception that it
   136  // takes a panic'ing func() as its first argument, instead of the error itself.
   137  func ShouldPanicLike(function any, expected ...any) (ret string) {
   138  	f, ok := function.(func())
   139  	if !ok {
   140  		return fmt.Sprintf("unexpected argument type %T, expected `func()`", function)
   141  	}
   142  	defer func() {
   143  		ret = ShouldErrLike(recover(), expected...)
   144  	}()
   145  	f()
   146  	return ShouldErrLike(nil, expected...)
   147  }
   148  
   149  // ShouldUnwrapTo asserts that an error, when unwrapped, equals another error.
   150  //
   151  // The actual field will be unwrapped using errors.Unwrap and then compared to
   152  // the error in expected.
   153  func ShouldUnwrapTo(actual any, expected ...any) string {
   154  	act, ok := actual.(error)
   155  	if !ok {
   156  		return fmt.Sprintf("ShouldUnwrapTo requires an error actual type, got %T", act)
   157  	}
   158  
   159  	if len(expected) != 1 {
   160  		return fmt.Sprintf("ShouldUnwrapTo requires exactly one expected value, got %d", len(expected))
   161  	}
   162  	exp, ok := expected[0].(error)
   163  	if !ok {
   164  		return fmt.Sprintf("ShouldUnwrapTo requires an error expected type, got %T", expected[0])
   165  	}
   166  
   167  	return assertions.ShouldEqual(errors.Unwrap(act), exp)
   168  }