istio.io/istio@v0.0.0-20240520182934-d79c90f27776/security/pkg/pki/util/san_test.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"crypto/x509/pkix"
    19  	"encoding/asn1"
    20  	"net/netip"
    21  	"reflect"
    22  	"testing"
    23  )
    24  
    25  func getSANExtension(identites []Identity, t *testing.T) *pkix.Extension {
    26  	ext, err := BuildSANExtension(identites)
    27  	if err != nil {
    28  		t.Errorf("A unexpected error has been encountered (error: %v)", err)
    29  	}
    30  	return ext
    31  }
    32  
    33  func TestBuildSubjectAltNameExtension(t *testing.T) {
    34  	uriIdentity := Identity{Type: TypeURI, Value: []byte("spiffe://test.domain.com/ns/default/sa/default")}
    35  	ipIdentity := Identity{Type: TypeIP, Value: netip.MustParseAddr("10.0.0.1").AsSlice()}
    36  	dnsIdentity := Identity{Type: TypeDNS, Value: []byte("test.domain.com")}
    37  
    38  	testCases := map[string]struct {
    39  		hosts       string
    40  		expectedExt *pkix.Extension
    41  	}{
    42  		"URI host": {
    43  			hosts:       "spiffe://test.domain.com/ns/default/sa/default",
    44  			expectedExt: getSANExtension([]Identity{uriIdentity}, t),
    45  		},
    46  		"IP host": {
    47  			hosts:       "10.0.0.1",
    48  			expectedExt: getSANExtension([]Identity{ipIdentity}, t),
    49  		},
    50  		"DNS host": {
    51  			hosts:       "test.domain.com",
    52  			expectedExt: getSANExtension([]Identity{dnsIdentity}, t),
    53  		},
    54  		"URI, IP and DNS hosts": {
    55  			hosts:       "spiffe://test.domain.com/ns/default/sa/default,10.0.0.1,test.domain.com",
    56  			expectedExt: getSANExtension([]Identity{uriIdentity, ipIdentity, dnsIdentity}, t),
    57  		},
    58  	}
    59  
    60  	for id, tc := range testCases {
    61  		if ext, err := BuildSubjectAltNameExtension(tc.hosts); err != nil {
    62  			t.Errorf("Case %q: a unexpected error has been encountered (error: %v)", id, err)
    63  		} else if !reflect.DeepEqual(ext, tc.expectedExt) {
    64  			t.Errorf("Case %q: unexpected extension returned: want %v but got %v", id, tc.expectedExt, ext)
    65  		}
    66  	}
    67  }
    68  
    69  func TestBuildAndExtractIdentities(t *testing.T) {
    70  	ids := []Identity{
    71  		{Type: TypeDNS, Value: []byte("test.domain.com")},
    72  		{Type: TypeIP, Value: []byte("10.0.0.1")},
    73  		{Type: TypeURI, Value: []byte("spiffe://test.domain.com/ns/default/sa/default")},
    74  	}
    75  	san, err := BuildSANExtension(ids)
    76  	if err != nil {
    77  		t.Errorf("A unexpected error has been encountered (error: %v)", err)
    78  	}
    79  
    80  	actualIDs, err := ExtractIDsFromSAN(san)
    81  	if err != nil {
    82  		t.Errorf("A unexpected error has been encountered (error: %v)", err)
    83  	}
    84  
    85  	if !reflect.DeepEqual(actualIDs, ids) {
    86  		t.Errorf("Unmatched identities: before encoding: %v, after decoding %v", ids, actualIDs)
    87  	}
    88  
    89  	if !san.Critical {
    90  		t.Errorf("SAN field is not critical.")
    91  	}
    92  }
    93  
    94  func TestBuildSANExtensionWithError(t *testing.T) {
    95  	id := Identity{Type: 10}
    96  	if _, err := BuildSANExtension([]Identity{id}); err == nil {
    97  		t.Error("Expecting error to be returned but got nil")
    98  	}
    99  }
   100  
   101  func TestExtractIDsFromSANWithError(t *testing.T) {
   102  	testCases := map[string]struct {
   103  		ext *pkix.Extension
   104  	}{
   105  		"Wrong OID": {
   106  			ext: &pkix.Extension{
   107  				Id: asn1.ObjectIdentifier{1, 2, 3},
   108  			},
   109  		},
   110  		"Wrong encoding": {
   111  			ext: &pkix.Extension{
   112  				Id:    oidSubjectAlternativeName,
   113  				Value: []byte("bad value"),
   114  			},
   115  		},
   116  	}
   117  
   118  	for id, tc := range testCases {
   119  		if _, err := ExtractIDsFromSAN(tc.ext); err == nil {
   120  			t.Errorf("%v: Expecting error to be returned but got nil", id)
   121  		}
   122  	}
   123  }
   124  
   125  func TestExtractIDsFromSANWithBadEncoding(t *testing.T) {
   126  	ext := &pkix.Extension{
   127  		Id:    oidSubjectAlternativeName,
   128  		Value: []byte("bad value"),
   129  	}
   130  
   131  	if _, err := ExtractIDsFromSAN(ext); err == nil {
   132  		t.Error("Expecting error to be returned but got nil")
   133  	}
   134  }
   135  
   136  func TestExtractSANExtension(t *testing.T) {
   137  	testCases := map[string]struct {
   138  		exts  []pkix.Extension
   139  		found bool
   140  	}{
   141  		"No extension": {
   142  			exts:  []pkix.Extension{},
   143  			found: false,
   144  		},
   145  		"An extensions with wrong OID": {
   146  			exts: []pkix.Extension{
   147  				{Id: asn1.ObjectIdentifier{1, 2, 3}},
   148  			},
   149  			found: false,
   150  		},
   151  		"Correct SAN extension": {
   152  			exts: []pkix.Extension{
   153  				{Id: asn1.ObjectIdentifier{1, 2, 3}},
   154  				{Id: asn1.ObjectIdentifier{2, 5, 29, 17}},
   155  				{Id: asn1.ObjectIdentifier{3, 2, 1}},
   156  			},
   157  			found: true,
   158  		},
   159  	}
   160  
   161  	for id, tc := range testCases {
   162  		found := ExtractSANExtension(tc.exts) != nil
   163  		if found != tc.found {
   164  			t.Errorf("Case %q: expect `found` to be %t but got %t", id, tc.found, found)
   165  		}
   166  	}
   167  }
   168  
   169  func TestExtractIDs(t *testing.T) {
   170  	id := "test.id"
   171  	sanExt, err := BuildSANExtension([]Identity{
   172  		{Type: TypeURI, Value: []byte(id)},
   173  	})
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  
   178  	testCases := map[string]struct {
   179  		exts           []pkix.Extension
   180  		expectedIDs    []string
   181  		expectedErrMsg string
   182  	}{
   183  		"Empty extension list": {
   184  			exts:           []pkix.Extension{},
   185  			expectedIDs:    nil,
   186  			expectedErrMsg: "the SAN extension does not exist",
   187  		},
   188  		"Extensions without SAN": {
   189  			exts: []pkix.Extension{
   190  				{Id: asn1.ObjectIdentifier{1, 2, 3, 4}},
   191  				{Id: asn1.ObjectIdentifier{3, 2, 1}},
   192  			},
   193  			expectedIDs:    nil,
   194  			expectedErrMsg: "the SAN extension does not exist",
   195  		},
   196  		"Extensions with bad SAN": {
   197  			exts: []pkix.Extension{
   198  				{Id: asn1.ObjectIdentifier{2, 5, 29, 17}, Value: []byte("bad san bytes")},
   199  			},
   200  			expectedIDs:    nil,
   201  			expectedErrMsg: "failed to extract identities from SAN extension (error asn1: syntax error: data truncated)",
   202  		},
   203  		"Extensions with incorrectly encoded SAN": {
   204  			exts: []pkix.Extension{
   205  				{Id: asn1.ObjectIdentifier{2, 5, 29, 17}, Value: append(copyBytes(sanExt.Value), 'x')},
   206  			},
   207  			expectedIDs:    nil,
   208  			expectedErrMsg: "failed to extract identities from SAN extension (error the SAN extension is incorrectly encoded)",
   209  		},
   210  		"Extensions with SAN": {
   211  			exts: []pkix.Extension{
   212  				{Id: asn1.ObjectIdentifier{1, 2, 3, 4}},
   213  				*sanExt,
   214  				{Id: asn1.ObjectIdentifier{3, 2, 1}},
   215  			},
   216  			expectedIDs: []string{id},
   217  		},
   218  	}
   219  
   220  	for id, tc := range testCases {
   221  		actualIDs, err := ExtractIDs(tc.exts)
   222  		if !reflect.DeepEqual(actualIDs, tc.expectedIDs) {
   223  			t.Errorf("Case %q: unexpected identities: want %v but got %v", id, tc.expectedIDs, actualIDs)
   224  		}
   225  		if tc.expectedErrMsg != "" {
   226  			if err == nil {
   227  				t.Errorf("Case %q: no error message returned: want %s", id, tc.expectedErrMsg)
   228  			} else if tc.expectedErrMsg != err.Error() {
   229  				t.Errorf("Case %q: unexpected error message: want %s but got %s", id, tc.expectedErrMsg, err.Error())
   230  			}
   231  		}
   232  	}
   233  }