google.golang.org/grpc@v1.74.2/internal/xds/bootstrap/tlscreds/bundle_test.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package tlscreds
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"encoding/pem"
    26  	"errors"
    27  	"fmt"
    28  	"os"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"google.golang.org/grpc"
    34  	"google.golang.org/grpc/credentials/tls/certprovider"
    35  	"google.golang.org/grpc/internal/credentials/spiffe"
    36  	"google.golang.org/grpc/internal/grpctest"
    37  	"google.golang.org/grpc/internal/stubserver"
    38  	"google.golang.org/grpc/internal/testutils"
    39  	"google.golang.org/grpc/testdata"
    40  
    41  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    42  	testpb "google.golang.org/grpc/interop/grpc_testing"
    43  )
    44  
    45  const defaultTestTimeout = 5 * time.Second
    46  
    47  type s struct {
    48  	grpctest.Tester
    49  }
    50  
    51  func Test(t *testing.T) {
    52  	grpctest.RunSubTests(t, s{})
    53  }
    54  
    55  type failingProvider struct{}
    56  
    57  func (f failingProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) {
    58  	return nil, errors.New("test error")
    59  }
    60  
    61  func (f failingProvider) Close() {}
    62  
    63  func (s) TestFailingProvider(t *testing.T) {
    64  	s := stubserver.StartTestService(t, nil, grpc.Creds(testutils.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)))
    65  	defer s.Stop()
    66  
    67  	cfg := fmt.Sprintf(`{
    68                 "ca_certificate_file": "%s",
    69                 "certificate_file": "%s",
    70                 "private_key_file": "%s",
    71  			   "spiffe_trust_bundle_map_file": "%s"
    72         }`,
    73  		testdata.Path("x509/server_ca_cert.pem"),
    74  		testdata.Path("x509/client1_cert.pem"),
    75  		testdata.Path("x509/client1_key.pem"),
    76  		testdata.Path("spiffe_end2end/client_spiffebundle.json"))
    77  	tlsBundle, stop, err := NewBundle([]byte(cfg))
    78  	if err != nil {
    79  		t.Fatalf("Failed to create TLS bundle: %v", err)
    80  	}
    81  	stop()
    82  
    83  	// Force a provider that returns an error, and make sure the client fails
    84  	// the handshake.
    85  	creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds)
    86  	if !ok {
    87  		t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials())
    88  	}
    89  	creds.provider = &failingProvider{}
    90  
    91  	conn, err := grpc.NewClient(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"))
    92  	if err != nil {
    93  		t.Fatalf("Error dialing: %v", err)
    94  	}
    95  	defer conn.Close()
    96  
    97  	client := testgrpc.NewTestServiceClient(conn)
    98  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    99  	defer cancel()
   100  	_, err = client.EmptyCall(ctx, &testpb.Empty{})
   101  	if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) {
   102  		t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr)
   103  	}
   104  }
   105  
   106  func rawCertsFromFile(t *testing.T, filePath string) [][]byte {
   107  	t.Helper()
   108  	rawCert, err := os.ReadFile(testdata.Path(filePath))
   109  	if err != nil {
   110  		t.Fatalf("Reading certificate file failed: %v", err)
   111  	}
   112  	block, _ := pem.Decode(rawCert)
   113  	if block == nil || block.Type != "CERTIFICATE" {
   114  		t.Fatalf("pem.Decode() failed to decode certificate in file %q", "spiffe/server1_spiffe.pem")
   115  	}
   116  	return [][]byte{block.Bytes}
   117  }
   118  
   119  func (s) TestSPIFFEVerifyFuncMismatchedCert(t *testing.T) {
   120  	spiffeBundleBytes, err := os.ReadFile(testdata.Path("spiffe_end2end/client_spiffebundle.json"))
   121  	if err != nil {
   122  		t.Fatalf("Reading spiffebundle file failed: %v", err)
   123  	}
   124  	spiffeBundle, err := spiffe.BundleMapFromBytes(spiffeBundleBytes)
   125  	if err != nil {
   126  		t.Fatalf("spiffe.BundleMapFromBytes() failed: %v", err)
   127  	}
   128  	verifyFunc := buildSPIFFEVerifyFunc(spiffeBundle)
   129  	verifiedChains := [][]*x509.Certificate{}
   130  	tests := []struct {
   131  		name            string
   132  		rawCerts        [][]byte
   133  		wantErrContains string
   134  	}{
   135  		{
   136  			name:            "mismathed cert",
   137  			rawCerts:        rawCertsFromFile(t, "spiffe/server1_spiffe.pem"),
   138  			wantErrContains: "spiffe: x509 certificate Verify failed",
   139  		},
   140  		{
   141  			name:            "bad input cert",
   142  			rawCerts:        [][]byte{[]byte("NOT_GOOD_DATA")},
   143  			wantErrContains: "spiffe: verify function could not parse input certificate",
   144  		},
   145  		{
   146  			name:            "no input bytes",
   147  			rawCerts:        nil,
   148  			wantErrContains: "no valid input certificates",
   149  		},
   150  	}
   151  	for _, tc := range tests {
   152  		t.Run(tc.name, func(t *testing.T) {
   153  			err = verifyFunc(tc.rawCerts, verifiedChains)
   154  			if err == nil {
   155  				t.Fatalf("buildSPIFFEVerifyFunc call succeeded. want failure")
   156  			}
   157  			if !strings.Contains(err.Error(), tc.wantErrContains) {
   158  				t.Fatalf("buildSPIFFEVerifyFunc got err %v want err to contain %v", err, tc.wantErrContains)
   159  			}
   160  		})
   161  	}
   162  }