github.com/letsencrypt/boulder@v0.20251208.0/ctpolicy/ctpolicy_test.go (about)

     1  package ctpolicy
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/jmhodges/clock"
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"google.golang.org/grpc"
    14  
    15  	"github.com/letsencrypt/boulder/core"
    16  	"github.com/letsencrypt/boulder/ctpolicy/loglist"
    17  	berrors "github.com/letsencrypt/boulder/errors"
    18  	blog "github.com/letsencrypt/boulder/log"
    19  	"github.com/letsencrypt/boulder/metrics"
    20  	pubpb "github.com/letsencrypt/boulder/publisher/proto"
    21  	"github.com/letsencrypt/boulder/test"
    22  )
    23  
    24  type mockPub struct{}
    25  
    26  func (mp *mockPub) SubmitToSingleCTWithResult(_ context.Context, _ *pubpb.Request, _ ...grpc.CallOption) (*pubpb.Result, error) {
    27  	return &pubpb.Result{Sct: []byte{0}}, nil
    28  }
    29  
    30  type mockFailPub struct{}
    31  
    32  func (mp *mockFailPub) SubmitToSingleCTWithResult(_ context.Context, _ *pubpb.Request, _ ...grpc.CallOption) (*pubpb.Result, error) {
    33  	return nil, errors.New("BAD")
    34  }
    35  
    36  type mockSlowPub struct{}
    37  
    38  func (mp *mockSlowPub) SubmitToSingleCTWithResult(ctx context.Context, _ *pubpb.Request, _ ...grpc.CallOption) (*pubpb.Result, error) {
    39  	<-ctx.Done()
    40  	return nil, errors.New("timed out")
    41  }
    42  
    43  func TestGetSCTs(t *testing.T) {
    44  	expired, cancel := context.WithDeadline(context.Background(), time.Now())
    45  	defer cancel()
    46  	missingSCTErr := berrors.MissingSCTs
    47  	testCases := []struct {
    48  		name       string
    49  		mock       pubpb.PublisherClient
    50  		logs       loglist.List
    51  		ctx        context.Context
    52  		result     core.SCTDERs
    53  		expectErr  string
    54  		berrorType *berrors.ErrorType
    55  	}{
    56  		{
    57  			name: "basic success case",
    58  			mock: &mockPub{},
    59  			logs: loglist.List{
    60  				{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
    61  				{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
    62  				{Name: "LogB1", Operator: "OperB", Url: "UrlB1", Key: []byte("KeyB1")},
    63  				{Name: "LogC1", Operator: "OperC", Url: "UrlC1", Key: []byte("KeyC1")},
    64  			},
    65  			ctx:    context.Background(),
    66  			result: core.SCTDERs{[]byte{0}, []byte{0}},
    67  		},
    68  		{
    69  			name: "basic failure case",
    70  			mock: &mockFailPub{},
    71  			logs: loglist.List{
    72  				{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
    73  				{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
    74  				{Name: "LogB1", Operator: "OperB", Url: "UrlB1", Key: []byte("KeyB1")},
    75  				{Name: "LogC1", Operator: "OperC", Url: "UrlC1", Key: []byte("KeyC1")},
    76  			},
    77  			ctx:        context.Background(),
    78  			expectErr:  "failed to get 2 SCTs, got 4 error(s)",
    79  			berrorType: &missingSCTErr,
    80  		},
    81  		{
    82  			name: "parent context timeout failure case",
    83  			mock: &mockSlowPub{},
    84  			logs: loglist.List{
    85  				{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
    86  				{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
    87  				{Name: "LogB1", Operator: "OperB", Url: "UrlB1", Key: []byte("KeyB1")},
    88  				{Name: "LogC1", Operator: "OperC", Url: "UrlC1", Key: []byte("KeyC1")},
    89  			},
    90  			ctx:        expired,
    91  			expectErr:  "failed to get 2 SCTs before ctx finished",
    92  			berrorType: &missingSCTErr,
    93  		},
    94  	}
    95  
    96  	for _, tc := range testCases {
    97  		t.Run(tc.name, func(t *testing.T) {
    98  			ctp := New(tc.mock, tc.logs, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
    99  			ret, err := ctp.GetSCTs(tc.ctx, []byte{0}, time.Time{})
   100  			if tc.result != nil {
   101  				test.AssertDeepEquals(t, ret, tc.result)
   102  			} else if tc.expectErr != "" {
   103  				if !strings.Contains(err.Error(), tc.expectErr) {
   104  					t.Errorf("Error %q did not match expected %q", err, tc.expectErr)
   105  				}
   106  				if tc.berrorType != nil {
   107  					test.AssertErrorIs(t, err, *tc.berrorType)
   108  				}
   109  			}
   110  		})
   111  	}
   112  }
   113  
   114  type mockFailOnePub struct {
   115  	badURL string
   116  }
   117  
   118  func (mp *mockFailOnePub) SubmitToSingleCTWithResult(_ context.Context, req *pubpb.Request, _ ...grpc.CallOption) (*pubpb.Result, error) {
   119  	if req.LogURL == mp.badURL {
   120  		return nil, errors.New("BAD")
   121  	}
   122  	return &pubpb.Result{Sct: []byte{0}}, nil
   123  }
   124  
   125  func TestGetSCTsMetrics(t *testing.T) {
   126  	ctp := New(&mockFailOnePub{badURL: "UrlA1"}, loglist.List{
   127  		{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
   128  		{Name: "LogB1", Operator: "OperB", Url: "UrlB1", Key: []byte("KeyB1")},
   129  		{Name: "LogC1", Operator: "OperC", Url: "UrlC1", Key: []byte("KeyC1")},
   130  	}, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
   131  	_, err := ctp.GetSCTs(context.Background(), []byte{0}, time.Time{})
   132  	test.AssertNotError(t, err, "GetSCTs failed")
   133  	test.AssertMetricWithLabelsEquals(t, ctp.winnerCounter, prometheus.Labels{"url": "UrlB1", "result": succeeded}, 1)
   134  	test.AssertMetricWithLabelsEquals(t, ctp.winnerCounter, prometheus.Labels{"url": "UrlC1", "result": succeeded}, 1)
   135  }
   136  
   137  func TestGetSCTsFailMetrics(t *testing.T) {
   138  	// Ensure the proper metrics are incremented when GetSCTs fails.
   139  	ctp := New(&mockFailOnePub{badURL: "UrlA1"}, loglist.List{
   140  		{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
   141  		{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
   142  	}, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
   143  	_, err := ctp.GetSCTs(context.Background(), []byte{0}, time.Time{})
   144  	test.AssertError(t, err, "GetSCTs should have failed")
   145  	test.AssertErrorIs(t, err, berrors.MissingSCTs)
   146  	test.AssertMetricWithLabelsEquals(t, ctp.winnerCounter, prometheus.Labels{"url": "UrlA1", "result": failed}, 1)
   147  
   148  	// Ensure the proper metrics are incremented when GetSCTs times out.
   149  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   150  	defer cancel()
   151  
   152  	ctp = New(&mockSlowPub{}, loglist.List{
   153  		{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
   154  		{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
   155  	}, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
   156  	_, err = ctp.GetSCTs(ctx, []byte{0}, time.Time{})
   157  	test.AssertError(t, err, "GetSCTs should have timed out")
   158  	test.AssertErrorIs(t, err, berrors.MissingSCTs)
   159  	test.AssertContains(t, err.Error(), context.DeadlineExceeded.Error())
   160  	test.AssertMetricWithLabelsEquals(t, ctp.winnerCounter, prometheus.Labels{"url": "UrlA1", "result": failed}, 1)
   161  }
   162  
   163  func TestLogListMetrics(t *testing.T) {
   164  	fc := clock.NewFake()
   165  	Tomorrow := fc.Now().Add(24 * time.Hour)
   166  	NextWeek := fc.Now().Add(7 * 24 * time.Hour)
   167  
   168  	// Multiple operator groups with configured logs.
   169  	ctp := New(&mockPub{}, loglist.List{
   170  		{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1"), EndExclusive: Tomorrow},
   171  		{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2"), EndExclusive: NextWeek},
   172  		{Name: "LogB1", Operator: "OperB", Url: "UrlB1", Key: []byte("KeyB1"), EndExclusive: Tomorrow},
   173  	}, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
   174  	test.AssertMetricWithLabelsEquals(t, ctp.shardExpiryGauge, prometheus.Labels{"operator": "OperA", "logID": "LogA1"}, 86400)
   175  	test.AssertMetricWithLabelsEquals(t, ctp.shardExpiryGauge, prometheus.Labels{"operator": "OperA", "logID": "LogA2"}, 604800)
   176  	test.AssertMetricWithLabelsEquals(t, ctp.shardExpiryGauge, prometheus.Labels{"operator": "OperB", "logID": "LogB1"}, 86400)
   177  }
   178  
   179  func TestCompliantSet(t *testing.T) {
   180  	for _, tc := range []struct {
   181  		name    string
   182  		results []result
   183  		want    core.SCTDERs
   184  	}{
   185  		{
   186  			name:    "nil input",
   187  			results: nil,
   188  			want:    nil,
   189  		},
   190  		{
   191  			name:    "zero length input",
   192  			results: []result{},
   193  			want:    nil,
   194  		},
   195  		{
   196  			name: "only one result",
   197  			results: []result{
   198  				{log: loglist.Log{Operator: "A", Tiled: false}, sct: []byte("sct1")},
   199  			},
   200  			want: nil,
   201  		},
   202  		{
   203  			name: "only one good result",
   204  			results: []result{
   205  				{log: loglist.Log{Operator: "A", Tiled: false}, sct: []byte("sct1")},
   206  				{log: loglist.Log{Operator: "B", Tiled: false}, err: errors.New("oops")},
   207  			},
   208  			want: nil,
   209  		},
   210  		{
   211  			name: "only one operator",
   212  			results: []result{
   213  				{log: loglist.Log{Operator: "A", Tiled: false}, sct: []byte("sct1")},
   214  				{log: loglist.Log{Operator: "A", Tiled: false}, sct: []byte("sct2")},
   215  			},
   216  			want: nil,
   217  		},
   218  		{
   219  			name: "all tiled",
   220  			results: []result{
   221  				{log: loglist.Log{Operator: "A", Tiled: true}, sct: []byte("sct1")},
   222  				{log: loglist.Log{Operator: "B", Tiled: true}, sct: []byte("sct2")},
   223  			},
   224  			want: nil,
   225  		},
   226  		{
   227  			name: "happy path",
   228  			results: []result{
   229  				{log: loglist.Log{Operator: "A", Tiled: false}, err: errors.New("oops")},
   230  				{log: loglist.Log{Operator: "A", Tiled: true}, sct: []byte("sct2")},
   231  				{log: loglist.Log{Operator: "A", Tiled: false}, sct: []byte("sct3")},
   232  				{log: loglist.Log{Operator: "B", Tiled: false}, err: errors.New("oops")},
   233  				{log: loglist.Log{Operator: "B", Tiled: true}, sct: []byte("sct4")},
   234  				{log: loglist.Log{Operator: "B", Tiled: false}, sct: []byte("sct6")},
   235  				{log: loglist.Log{Operator: "C", Tiled: false}, err: errors.New("oops")},
   236  				{log: loglist.Log{Operator: "C", Tiled: true}, sct: []byte("sct8")},
   237  				{log: loglist.Log{Operator: "C", Tiled: false}, sct: []byte("sct9")},
   238  			},
   239  			// The second and sixth results should be picked, because first and fourth
   240  			// are skipped for being errors, and fifth is skipped for also being tiled.
   241  			want: core.SCTDERs{[]byte("sct2"), []byte("sct6")},
   242  		},
   243  	} {
   244  		t.Run(tc.name, func(t *testing.T) {
   245  			got := compliantSet(tc.results)
   246  			if len(got) != len(tc.want) {
   247  				t.Fatalf("compliantSet(%#v) returned %d SCTs, but want %d", tc.results, len(got), len(tc.want))
   248  			}
   249  			for i, sct := range tc.want {
   250  				if !bytes.Equal(got[i], sct) {
   251  					t.Errorf("compliantSet(%#v) returned unexpected SCT at index %d", tc.results, i)
   252  				}
   253  			}
   254  		})
   255  	}
   256  }