istio.io/istio@v0.0.0-20240520182934-d79c90f27776/security/pkg/server/ca/authenticate/oidc_test.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package authenticate
    15  
    16  import (
    17  	"context"
    18  	"crypto/rand"
    19  	"crypto/rsa"
    20  	"encoding/json"
    21  	"fmt"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"reflect"
    25  	"strconv"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/go-jose/go-jose/v3"
    30  	"google.golang.org/grpc/metadata"
    31  
    32  	"istio.io/api/security/v1beta1"
    33  	"istio.io/istio/pkg/security"
    34  	"istio.io/istio/pkg/spiffe"
    35  )
    36  
    37  const (
    38  	bearerTokenPrefix = "Bearer "
    39  )
    40  
    41  type jwksServer struct {
    42  	key jose.JSONWebKeySet
    43  	t   *testing.T
    44  }
    45  
    46  func (k *jwksServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    47  	if err := json.NewEncoder(w).Encode(k.key); err != nil {
    48  		k.t.Fatalf("failed to encode the jwks: %v", err)
    49  	}
    50  }
    51  
    52  func TestNewJwtAuthenticator(t *testing.T) {
    53  	tests := []struct {
    54  		name      string
    55  		expectErr bool
    56  		jwtRule   string
    57  	}{
    58  		{
    59  			name:      "jwt rule with jwks_uri",
    60  			expectErr: false,
    61  			jwtRule:   `{"issuer": "foo", "jwks_uri": "baz", "audiences": ["aud1", "aud2"]}`,
    62  		},
    63  		{
    64  			name: "jwt rule with OIDC config expected to fail",
    65  			// "foo/.well-known/openid-configuration" is expected to fail
    66  			expectErr: true,
    67  			jwtRule:   `{"issuer": "foo", "audiences": ["aud1", "aud2"]}`,
    68  		},
    69  	}
    70  
    71  	for _, tt := range tests {
    72  		t.Run(tt.name, func(t *testing.T) {
    73  			jwtRule := v1beta1.JWTRule{}
    74  			err := json.Unmarshal([]byte(tt.jwtRule), &jwtRule)
    75  			if err != nil {
    76  				t.Fatalf("failed at unmarshal the jwt rule (%v), err: %v",
    77  					tt.jwtRule, err)
    78  			}
    79  			_, err = NewJwtAuthenticator(&jwtRule)
    80  			gotErr := err != nil
    81  			if gotErr != tt.expectErr {
    82  				t.Errorf("expect error is %v while actual error is %v", tt.expectErr, gotErr)
    83  			}
    84  		})
    85  	}
    86  }
    87  
    88  func TestCheckAudience(t *testing.T) {
    89  	tests := []struct {
    90  		name        string
    91  		expectRet   bool
    92  		audToCheck  []string
    93  		audExpected []string
    94  	}{
    95  		{
    96  			name:        "audience is in the expected set",
    97  			expectRet:   true,
    98  			audToCheck:  []string{"aud1"},
    99  			audExpected: []string{"aud1", "aud2"},
   100  		},
   101  		{
   102  			name:        "audience is NOT in the expected set",
   103  			expectRet:   false,
   104  			audToCheck:  []string{"aud3"},
   105  			audExpected: []string{"aud1", "aud2"},
   106  		},
   107  		{
   108  			name:        "one of the audiences is in the expected set",
   109  			expectRet:   true,
   110  			audToCheck:  []string{"aud1", "aud3"},
   111  			audExpected: []string{"aud1", "aud2"},
   112  		},
   113  	}
   114  
   115  	for _, tt := range tests {
   116  		t.Run(tt.name, func(t *testing.T) {
   117  			ret := checkAudience(tt.audToCheck, tt.audExpected)
   118  			if ret != tt.expectRet {
   119  				t.Errorf("expected return is %v while actual return is %v", tt.expectRet, ret)
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func TestOIDCAuthenticate(t *testing.T) {
   126  	// Create a JWKS server
   127  	rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
   128  	if err != nil {
   129  		t.Fatalf("failed to generate a private key: %v", err)
   130  	}
   131  	key := jose.JSONWebKey{Algorithm: string(jose.RS256), Key: rsaKey}
   132  	keySet := jose.JSONWebKeySet{}
   133  	keySet.Keys = append(keySet.Keys, key.Public())
   134  	server := httptest.NewServer(&jwksServer{key: keySet})
   135  	defer server.Close()
   136  
   137  	spiffe.SetTrustDomain("baz.svc.id.goog")
   138  
   139  	// Create a JWT authenticator
   140  	jwtRuleStr := `{"issuer": "` + server.URL + `", "jwks_uri": "` + server.URL + `", "audiences": ["baz.svc.id.goog"]}`
   141  	jwtRule := v1beta1.JWTRule{}
   142  	err = json.Unmarshal([]byte(jwtRuleStr), &jwtRule)
   143  	if err != nil {
   144  		t.Fatalf("failed at unmarshal jwt rule")
   145  	}
   146  	authenticator, err := NewJwtAuthenticator(&jwtRule)
   147  	if err != nil {
   148  		t.Fatalf("failed to create the JWT authenticator: %v", err)
   149  	}
   150  
   151  	// Create a valid JWT token
   152  	expStr := strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)
   153  	claims := `{"iss": "` + server.URL + `", "aud": ["baz.svc.id.goog"], "sub": "system:serviceaccount:bar:foo", "exp": ` + expStr + `}`
   154  	token, err := generateJWT(&key, []byte(claims))
   155  	if err != nil {
   156  		t.Fatalf("failed to generate JWT: %v", err)
   157  	}
   158  	// Create an expired JWT token
   159  	expiredStr := strconv.FormatInt(time.Now().Add(-time.Hour).Unix(), 10)
   160  	expiredClaims := `{"iss": "` + server.URL + `", "aud": ["baz.svc.id.goog"], "sub": "system:serviceaccount:bar:foo", "exp": ` + expiredStr + `}`
   161  	expiredToken, err := generateJWT(&key, []byte(expiredClaims))
   162  	if err != nil {
   163  		t.Fatalf("failed to generate an expired JWT: %v", err)
   164  	}
   165  	// Create a JWT token with wrong audience
   166  	claimsWrongAudience := `{"iss": "` + server.URL + `", "aud": ["wrong-audience"], "sub": "system:serviceaccount:bar:foo", "exp": ` + expStr + `}`
   167  	tokenWrongAudience, err := generateJWT(&key, []byte(claimsWrongAudience))
   168  	if err != nil {
   169  		t.Fatalf("failed to generate JWT: %v", err)
   170  	}
   171  	// Create a JWT token with invalid subject, which is not prefixed with "system:serviceaccount"
   172  	claimsWrongSubject := `{"iss": "` + server.URL + `", "aud": ["baz.svc.id.goog"], "sub": "bar:foo", "exp": ` + expStr + `}`
   173  	tokenInvalidSubject, err := generateJWT(&key, []byte(claimsWrongSubject))
   174  	if err != nil {
   175  		t.Fatalf("failed to generate JWT: %v", err)
   176  	}
   177  
   178  	tests := map[string]struct {
   179  		token      string
   180  		expectErr  bool
   181  		expectedID string
   182  	}{
   183  		"No bearer token": {
   184  			expectErr: true,
   185  		},
   186  		"Valid token": {
   187  			token:      token,
   188  			expectErr:  false,
   189  			expectedID: spiffe.MustGenSpiffeURI("bar", "foo"),
   190  		},
   191  		"Expired token": {
   192  			token:     expiredToken,
   193  			expectErr: true,
   194  		},
   195  		"Token with wrong audience": {
   196  			token:     tokenWrongAudience,
   197  			expectErr: true,
   198  		},
   199  		"Token with invalid subject": {
   200  			token:     tokenInvalidSubject,
   201  			expectErr: true,
   202  		},
   203  	}
   204  
   205  	for name, tc := range tests {
   206  		t.Run(name, func(t *testing.T) {
   207  			ctx := context.Background()
   208  			md := metadata.MD{}
   209  			if tc.token != "" {
   210  				token := bearerTokenPrefix + tc.token
   211  				md.Append("authorization", token)
   212  			}
   213  			ctx = metadata.NewIncomingContext(ctx, md)
   214  
   215  			actualCaller, err := authenticator.Authenticate(security.AuthContext{GrpcContext: ctx})
   216  			gotErr := err != nil
   217  			if gotErr != tc.expectErr {
   218  				t.Errorf("gotErr (%v) whereas expectErr (%v)", gotErr, tc.expectErr)
   219  			}
   220  			if gotErr {
   221  				return
   222  			}
   223  			expectedCaller := &security.Caller{
   224  				AuthSource: security.AuthSourceIDToken,
   225  				Identities: []string{tc.expectedID},
   226  			}
   227  			if !reflect.DeepEqual(actualCaller, expectedCaller) {
   228  				t.Errorf("%v: unexpected caller (want %v but got %v)", name, expectedCaller, actualCaller)
   229  			}
   230  		})
   231  	}
   232  }
   233  
   234  func generateJWT(key *jose.JSONWebKey, claims []byte) (string, error) {
   235  	signer, err := jose.NewSigner(jose.SigningKey{
   236  		Algorithm: jose.SignatureAlgorithm(key.Algorithm),
   237  		Key:       key,
   238  	}, nil)
   239  	if err != nil {
   240  		return "", fmt.Errorf("failed to create a signer: %v", err)
   241  	}
   242  	signature, err := signer.Sign(claims)
   243  	if err != nil {
   244  		return "", fmt.Errorf("failed to sign claims: %v", err)
   245  	}
   246  	jwt, err := signature.CompactSerialize()
   247  	if err != nil {
   248  		return "", fmt.Errorf("failed to serialize the JWT: %v", err)
   249  	}
   250  	return jwt, nil
   251  }