github.com/letsencrypt/boulder@v0.20251208.0/test/asserts.go (about)

     1  package test
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"errors"
     8  	"reflect"
     9  	"slices"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/prometheus/client_golang/prometheus"
    15  	io_prometheus_client "github.com/prometheus/client_model/go"
    16  )
    17  
    18  // Assert a boolean
    19  func Assert(t *testing.T, result bool, message string) {
    20  	t.Helper()
    21  	if !result {
    22  		t.Fatal(message)
    23  	}
    24  }
    25  
    26  // AssertNil checks that an object is nil. Being a "boxed nil" (a nil value
    27  // wrapped in a non-nil interface type) is not good enough.
    28  func AssertNil(t *testing.T, obj any, message string) {
    29  	t.Helper()
    30  	if obj != nil {
    31  		t.Fatal(message)
    32  	}
    33  }
    34  
    35  // AssertNotNil checks an object to be non-nil. Being a "boxed nil" (a nil value
    36  // wrapped in a non-nil interface type) is not good enough.
    37  // Note that there is a gap between AssertNil and AssertNotNil. Both fail when
    38  // called with a boxed nil. This is intentional: we want to avoid boxed nils.
    39  func AssertNotNil(t *testing.T, obj any, message string) {
    40  	t.Helper()
    41  	if obj == nil {
    42  		t.Fatal(message)
    43  	}
    44  	switch reflect.TypeOf(obj).Kind() {
    45  	// .IsNil() only works on chan, func, interface, map, pointer, and slice.
    46  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
    47  		if reflect.ValueOf(obj).IsNil() {
    48  			t.Fatal(message)
    49  		}
    50  	}
    51  }
    52  
    53  // AssertBoxedNil checks that an inner object is nil. This is intentional for
    54  // testing purposes only.
    55  func AssertBoxedNil(t *testing.T, obj any, message string) {
    56  	t.Helper()
    57  	typ := reflect.TypeOf(obj).Kind()
    58  	switch typ {
    59  	// .IsNil() only works on chan, func, interface, map, pointer, and slice.
    60  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
    61  		if !reflect.ValueOf(obj).IsNil() {
    62  			t.Fatal(message)
    63  		}
    64  	default:
    65  		t.Fatalf("Cannot check type \"%s\". Needs to be of type chan, func, interface, map, pointer, or slice.", typ)
    66  	}
    67  }
    68  
    69  // AssertNotError checks that err is nil
    70  func AssertNotError(t *testing.T, err error, message string) {
    71  	t.Helper()
    72  	if err != nil {
    73  		t.Fatalf("%s: %s", message, err)
    74  	}
    75  }
    76  
    77  // AssertError checks that err is non-nil
    78  func AssertError(t *testing.T, err error, message string) {
    79  	t.Helper()
    80  	if err == nil {
    81  		t.Fatalf("%s: expected error but received none", message)
    82  	}
    83  }
    84  
    85  // AssertErrorWraps checks that err can be unwrapped into the given target.
    86  // NOTE: Has the side effect of actually performing that unwrapping.
    87  func AssertErrorWraps(t *testing.T, err error, target any) {
    88  	t.Helper()
    89  	if !errors.As(err, target) {
    90  		t.Fatalf("error does not wrap an error of the expected type: %q !> %+T", err.Error(), target)
    91  	}
    92  }
    93  
    94  // AssertErrorIs checks that err wraps the given error
    95  func AssertErrorIs(t *testing.T, err error, target error) {
    96  	t.Helper()
    97  
    98  	if err == nil {
    99  		t.Fatal("err was unexpectedly nil and should not have been")
   100  	}
   101  
   102  	if !errors.Is(err, target) {
   103  		t.Fatalf("error does not wrap expected error: %q !> %q", err.Error(), target.Error())
   104  	}
   105  }
   106  
   107  // AssertEquals uses the equality operator (==) to measure one and two
   108  func AssertEquals(t *testing.T, one any, two any) {
   109  	t.Helper()
   110  	if reflect.TypeOf(one) != reflect.TypeOf(two) {
   111  		t.Fatalf("cannot test equality of different types: %T != %T", one, two)
   112  	}
   113  	if one != two {
   114  		t.Fatalf("%#v != %#v", one, two)
   115  	}
   116  }
   117  
   118  // AssertDeepEquals uses the reflect.DeepEqual method to measure one and two
   119  func AssertDeepEquals(t *testing.T, one any, two any) {
   120  	t.Helper()
   121  	if !reflect.DeepEqual(one, two) {
   122  		t.Fatalf("[%#v] !(deep)= [%#v]", one, two)
   123  	}
   124  }
   125  
   126  // AssertMarshaledEquals marshals one and two to JSON, and then uses
   127  // the equality operator to measure them
   128  func AssertMarshaledEquals(t *testing.T, one any, two any) {
   129  	t.Helper()
   130  	oneJSON, err := json.Marshal(one)
   131  	AssertNotError(t, err, "Could not marshal 1st argument")
   132  	twoJSON, err := json.Marshal(two)
   133  	AssertNotError(t, err, "Could not marshal 2nd argument")
   134  
   135  	if !bytes.Equal(oneJSON, twoJSON) {
   136  		t.Fatalf("[%s] !(json)= [%s]", oneJSON, twoJSON)
   137  	}
   138  }
   139  
   140  // AssertUnmarshaledEquals unmarshals two JSON strings (got and expected) to
   141  // a map[string]interface{} and then uses reflect.DeepEqual to check they are
   142  // the same
   143  func AssertUnmarshaledEquals(t *testing.T, got, expected string) {
   144  	t.Helper()
   145  	var gotMap, expectedMap map[string]any
   146  	err := json.Unmarshal([]byte(got), &gotMap)
   147  	AssertNotError(t, err, "Could not unmarshal 'got'")
   148  	err = json.Unmarshal([]byte(expected), &expectedMap)
   149  	AssertNotError(t, err, "Could not unmarshal 'expected'")
   150  	if len(gotMap) != len(expectedMap) {
   151  		t.Errorf("Expected %d keys, but got %d", len(expectedMap), len(gotMap))
   152  	}
   153  	for k, v := range expectedMap {
   154  		if !reflect.DeepEqual(v, gotMap[k]) {
   155  			t.Errorf("Field %q: Expected \"%v\", got \"%v\"", k, v, gotMap[k])
   156  		}
   157  	}
   158  }
   159  
   160  // AssertNotEquals uses the equality operator to measure that one and two
   161  // are different
   162  func AssertNotEquals(t *testing.T, one any, two any) {
   163  	t.Helper()
   164  	if one == two {
   165  		t.Fatalf("%#v == %#v", one, two)
   166  	}
   167  }
   168  
   169  // AssertByteEquals uses bytes.Equal to measure one and two for equality.
   170  func AssertByteEquals(t *testing.T, one []byte, two []byte) {
   171  	t.Helper()
   172  	if !bytes.Equal(one, two) {
   173  		t.Fatalf("Byte [%s] != [%s]",
   174  			base64.StdEncoding.EncodeToString(one),
   175  			base64.StdEncoding.EncodeToString(two))
   176  	}
   177  }
   178  
   179  // AssertContains determines whether needle can be found in haystack
   180  func AssertContains(t *testing.T, haystack string, needle string) {
   181  	t.Helper()
   182  	if !strings.Contains(haystack, needle) {
   183  		t.Fatalf("String [%s] does not contain [%s]", haystack, needle)
   184  	}
   185  }
   186  
   187  // AssertNotContains determines if needle is not found in haystack
   188  func AssertNotContains(t *testing.T, haystack string, needle string) {
   189  	t.Helper()
   190  	if strings.Contains(haystack, needle) {
   191  		t.Fatalf("String [%s] contains [%s]", haystack, needle)
   192  	}
   193  }
   194  
   195  // AssertSliceContains determines if needle can be found in haystack
   196  func AssertSliceContains[T comparable](t *testing.T, haystack []T, needle T) {
   197  	t.Helper()
   198  	if slices.Contains(haystack, needle) {
   199  		return
   200  	}
   201  	t.Fatalf("Slice %v does not contain %v", haystack, needle)
   202  }
   203  
   204  // AssertMetricWithLabelsEquals determines whether the value held by a prometheus Collector
   205  // (e.g. Gauge, Counter, CounterVec, etc) is equal to the expected float64.
   206  // In order to make useful assertions about just a subset of labels (e.g. for a
   207  // CounterVec with fields "host" and "valid", being able to assert that two
   208  // "valid": "true" increments occurred, without caring which host was tagged in
   209  // each), takes a set of labels and ignores any metrics which have different
   210  // label values.
   211  // Only works for simple metrics (Counters and Gauges), or for the *count*
   212  // (not value) of data points in a Histogram.
   213  func AssertMetricWithLabelsEquals(t *testing.T, c prometheus.Collector, l prometheus.Labels, expected float64) {
   214  	t.Helper()
   215  	ch := make(chan prometheus.Metric)
   216  	done := make(chan struct{})
   217  	go func() {
   218  		c.Collect(ch)
   219  		close(done)
   220  	}()
   221  	var total float64
   222  	timeout := time.After(time.Second)
   223  loop:
   224  	for {
   225  	metric:
   226  		select {
   227  		case <-timeout:
   228  			t.Fatal("timed out collecting metrics")
   229  		case <-done:
   230  			break loop
   231  		case m := <-ch:
   232  			var iom io_prometheus_client.Metric
   233  			_ = m.Write(&iom)
   234  			for _, lp := range iom.Label {
   235  				// If any of the labels on this metric have the same name as but
   236  				// different value than a label in `l`, skip this metric.
   237  				val, ok := l[lp.GetName()]
   238  				if ok && lp.GetValue() != val {
   239  					break metric
   240  				}
   241  			}
   242  			// Exactly one of the Counter, Gauge, or Histogram values will be set by
   243  			// the .Write() operation, so add them all because the others will be 0.
   244  			total += iom.Counter.GetValue()
   245  			total += iom.Gauge.GetValue()
   246  			total += float64(iom.Histogram.GetSampleCount())
   247  		}
   248  	}
   249  	if total != expected {
   250  		t.Errorf("metric with labels %+v: got %g, want %g", l, total, expected)
   251  	}
   252  }
   253  
   254  // AssertHistogramBucketCount is similar to AssertMetricWithLabelsEquals, in
   255  // that it determines whether the number of samples within a given histogram
   256  // bucket matches the expectation. The bucket to check is indicated by a single
   257  // exemplar value; whichever bucket that value falls into is the bucket whose
   258  // sample count will be compared to the expected value.
   259  func AssertHistogramBucketCount(t *testing.T, c prometheus.Collector, l prometheus.Labels, b float64, expected uint64) {
   260  	t.Helper()
   261  	ch := make(chan prometheus.Metric)
   262  	done := make(chan struct{})
   263  	go func() {
   264  		c.Collect(ch)
   265  		close(done)
   266  	}()
   267  	var total uint64
   268  	timeout := time.After(time.Second)
   269  loop:
   270  	for {
   271  	metric:
   272  		select {
   273  		case <-timeout:
   274  			t.Fatal("timed out collecting metrics")
   275  		case <-done:
   276  			break loop
   277  		case m := <-ch:
   278  			var iom io_prometheus_client.Metric
   279  			_ = m.Write(&iom)
   280  			for _, lp := range iom.Label {
   281  				// If any of the labels on this metric have the same name as but
   282  				// different value than a label in `l`, skip this metric.
   283  				val, ok := l[lp.GetName()]
   284  				if ok && lp.GetValue() != val {
   285  					break metric
   286  				}
   287  			}
   288  			lowerBucketsCount := uint64(0)
   289  			for _, bucket := range iom.Histogram.Bucket {
   290  				if b <= bucket.GetUpperBound() {
   291  					total += bucket.GetCumulativeCount() - lowerBucketsCount
   292  					break
   293  				} else {
   294  					lowerBucketsCount += bucket.GetCumulativeCount()
   295  				}
   296  			}
   297  		}
   298  	}
   299  	if total != expected {
   300  		t.Errorf("histogram with labels %+v at bucket %g: got %d, want %d", l, b, total, expected)
   301  	}
   302  }