github.com/grailbio/base@v0.0.11/cmd/ticket-server/k8sblesser_test.go (about)

     1  package main
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"os"
     9  	"os/exec"
    10  	"path"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/aws/aws-sdk-go/service/eks"
    15  	ticketServerUtil "github.com/grailbio/base/cmd/ticket-server/testutil"
    16  	"github.com/grailbio/base/errors"
    17  	"github.com/grailbio/base/security/identity"
    18  	"github.com/grailbio/base/vcontext"
    19  	"github.com/grailbio/testutil"
    20  
    21  	assert "github.com/stretchr/testify/assert"
    22  
    23  	auth "k8s.io/api/authentication/v1"
    24  	client "k8s.io/client-go/kubernetes/typed/authentication/v1"
    25  	rest "k8s.io/client-go/rest"
    26  
    27  	"v.io/v23/naming"
    28  	"v.io/x/ref"
    29  )
    30  
    31  type FakeAWSSession struct {
    32  }
    33  
    34  type FakeAuthV1Client struct {
    35  	client.AuthenticationV1Interface
    36  	RESTClientReturn   rest.Interface
    37  	TokenReviewsReturn client.TokenReviewInterface
    38  }
    39  
    40  func (w *FakeAuthV1Client) RESTClient() rest.Interface {
    41  	return w.RESTClientReturn
    42  }
    43  
    44  func (w *FakeAuthV1Client) TokenReviews() client.TokenReviewInterface {
    45  	return w.TokenReviewsReturn
    46  }
    47  
    48  type FakeTokenReviews struct {
    49  	client.TokenReviewInterface
    50  	TokenReviewReturn *auth.TokenReview
    51  }
    52  
    53  func (t *FakeTokenReviews) Create(*auth.TokenReview) (*auth.TokenReview, error) {
    54  	var err error
    55  	return t.TokenReviewReturn, err
    56  }
    57  
    58  // FakeAWSSessionWrapper mocks the session wrapper used to isolate
    59  type FakeAWSSessionWrapper struct {
    60  	session               *FakeAWSSession
    61  	GetAuthV1ClientReturn client.AuthenticationV1Interface
    62  	ListEKSClustersReturn *eks.ListClustersOutput
    63  	AllEKSClusters        map[string]*eks.DescribeClusterOutput
    64  }
    65  
    66  func (w *FakeAWSSessionWrapper) DescribeEKSCluster(input *eks.DescribeClusterInput, roleARN string, region string) (*eks.DescribeClusterOutput, error) {
    67  	var err error
    68  	return w.AllEKSClusters[*input.Name], err
    69  }
    70  
    71  func (w *FakeAWSSessionWrapper) GetAuthV1Client(ctx context.Context, headers map[string]string, caCrt string, region string, endpoint string) (client.AuthenticationV1Interface, error) {
    72  	var err error
    73  	return w.GetAuthV1ClientReturn, err
    74  }
    75  
    76  func (w *FakeAWSSessionWrapper) ListEKSClusters(input *eks.ListClustersInput, roleARN string, region string) (*eks.ListClustersOutput, error) {
    77  	var err error
    78  	return w.ListEKSClustersReturn, err
    79  }
    80  
    81  // FakeContext mocks contexts so that we can pass them in to simulate logging, etc
    82  type FakeContext struct {
    83  	context.Context
    84  }
    85  
    86  // required to simulate logging.
    87  func (c *FakeContext) Value(key interface{}) interface{} {
    88  	return nil
    89  }
    90  
    91  // ClusterHelper generates all the cluster attributes used in a test
    92  type ClusterHelper struct {
    93  	Name          string
    94  	Arn           string
    95  	Crt           string
    96  	CrtEnc        string
    97  	RoleARN       string
    98  	Endpoint      string
    99  	Cluster       *eks.Cluster
   100  	ClusterOutput *eks.DescribeClusterOutput
   101  }
   102  
   103  func newClusterHelper(name, acctNum, crt, roleARN, region string, tags map[string]*string) *ClusterHelper {
   104  	fakeAccountName := "ACCTNAMEFOR" + name
   105  
   106  	ch := ClusterHelper{
   107  		Name:     name,
   108  		Arn:      "arn:aws:iam::" + acctNum + ":role/" + name,
   109  		Crt:      crt,
   110  		CrtEnc:   base64.StdEncoding.EncodeToString([]byte(crt)),
   111  		RoleARN:  roleARN,
   112  		Endpoint: "https://" + fakeAccountName + ".sk1." + region + ".eks.amazonaws.com",
   113  	}
   114  
   115  	ch.Cluster = &eks.Cluster{
   116  		Name:     &ch.Name,
   117  		RoleArn:  &ch.RoleARN,
   118  		Endpoint: &ch.Endpoint,
   119  		Tags:     tags,
   120  		Arn:      &ch.Arn,
   121  		CertificateAuthority: &eks.Certificate{
   122  			Data: &ch.CrtEnc,
   123  		},
   124  	}
   125  
   126  	ch.ClusterOutput = &eks.DescribeClusterOutput{
   127  		Cluster: ch.Cluster,
   128  	}
   129  
   130  	return &ch
   131  }
   132  
   133  // Note: we cannot test
   134  func TestK8sBlesser(t *testing.T) {
   135  	emptyTags := make(map[string]*string)
   136  	randomTag := "test"
   137  	emptyTags["RandomTag"] = &randomTag
   138  	acctNum := "111111111111"
   139  
   140  	ctx := vcontext.Background()
   141  	assert.NoError(t, ref.EnvClearCredentials())
   142  
   143  	t.Run("init", func(t *testing.T) {
   144  		fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}}
   145  		accountIDs := []string{"abc123456"}
   146  		awsRegions := []string{"us-west-2"}
   147  		testRole := "test-role"
   148  		compareAWSConn := newAwsConn(fakeSessionWrapper, testRole, awsRegions, accountIDs)
   149  		blesser := newK8sBlesser(fakeSessionWrapper, time.Hour, testRole, accountIDs, awsRegions)
   150  
   151  		// test that awsConn was configured
   152  		assert.Equal(t, blesser.awsConn, compareAWSConn)
   153  	})
   154  
   155  	t.Run("awsConn", func(t *testing.T) {
   156  		fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}}
   157  		accountIDs := []string{acctNum}
   158  		awsRegions := []string{"us-west-2"}
   159  		testRole := "test-role"
   160  		testRegion := "us-west-2"
   161  		wantCluster := newClusterHelper("test-cluster", acctNum, "fake-crt", testRole, testRegion, emptyTags)
   162  		otherCluster1 := newClusterHelper("other-cluster1", acctNum, "other-crt1", testRole, testRegion, emptyTags)
   163  		otherCluster2 := newClusterHelper("other-cluster2", acctNum, "other-crt2", "another-role", testRegion, emptyTags)
   164  
   165  		clusters := []string{wantCluster.Name, otherCluster1.Name, otherCluster2.Name}
   166  		var clusterOutputs = make(map[string]*eks.DescribeClusterOutput)
   167  		clusterOutputs[wantCluster.Name] = wantCluster.ClusterOutput
   168  		clusterOutputs[otherCluster1.Name] = otherCluster1.ClusterOutput
   169  		clusterOutputs[otherCluster2.Name] = otherCluster2.ClusterOutput
   170  
   171  		clusterPtrs := []*string{}
   172  		for i := range clusters {
   173  			clusterPtrs = append(clusterPtrs, &clusters[i])
   174  		}
   175  
   176  		fakeSessionWrapper.ListEKSClustersReturn = &eks.ListClustersOutput{
   177  			Clusters: clusterPtrs,
   178  		}
   179  
   180  		fakeSessionWrapper.AllEKSClusters = clusterOutputs
   181  
   182  		assert.NoError(t, ref.EnvClearCredentials())
   183  
   184  		blesser := newK8sBlesser(fakeSessionWrapper, time.Hour, testRole, accountIDs, awsRegions)
   185  
   186  		clustersOutput := blesser.awsConn.GetClusters(ctx, testRegion)
   187  		assert.Equal(t, clustersOutput, []*eks.Cluster{wantCluster.Cluster, otherCluster1.Cluster, otherCluster2.Cluster})
   188  
   189  		foundEksCluster, _ := blesser.awsConn.GetEKSCluster(ctx, testRegion, wantCluster.Crt)
   190  		assert.NotNil(t, foundEksCluster)
   191  	})
   192  
   193  	t.Run("k8sConn", func(t *testing.T) {
   194  		var (
   195  			foundUsername string
   196  			k8sConn       *k8sConn
   197  			err           error
   198  		)
   199  		fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}}
   200  		testRole := "test-role"
   201  		testToken := "test-token"
   202  		testRegion := "us-west-2"
   203  		testUsername := "system:serviceaccount:default:someService"
   204  		cluster := newClusterHelper("test-cluster", acctNum, "fake-crt", testRole, testRegion, emptyTags)
   205  
   206  		fakeTokenReviews := &FakeTokenReviews{}
   207  		fakeTokenReviews.TokenReviewReturn = &auth.TokenReview{
   208  			Status: auth.TokenReviewStatus{
   209  				User: auth.UserInfo{
   210  					Username: testUsername,
   211  				},
   212  				Authenticated: true,
   213  			},
   214  		}
   215  		fakeContext := &FakeContext{}
   216  		fakeAuthV1Client := &FakeAuthV1Client{}
   217  		fakeAuthV1Client.TokenReviewsReturn = fakeTokenReviews
   218  
   219  		fakeSessionWrapper.GetAuthV1ClientReturn = fakeAuthV1Client
   220  		k8sConn = newK8sConn(fakeSessionWrapper, cluster.Cluster, testRegion, cluster.Crt, testToken)
   221  
   222  		foundUsername, err = k8sConn.GetK8sUsername(fakeContext)
   223  		assert.NoError(t, err)
   224  		assert.NotNil(t, foundUsername)
   225  		assert.Equal(t, testUsername, foundUsername)
   226  
   227  		// test failure outputs
   228  		fakeTokenReviews.TokenReviewReturn = &auth.TokenReview{
   229  			Status: auth.TokenReviewStatus{
   230  				User: auth.UserInfo{
   231  					Username: "",
   232  				},
   233  				Authenticated: false,
   234  			},
   235  		}
   236  		k8sConn = newK8sConn(fakeSessionWrapper, cluster.Cluster, testRegion, cluster.Crt, testToken)
   237  		foundUsername, err = k8sConn.GetK8sUsername(fakeContext)
   238  		assert.NotNil(t, err)
   239  		assert.Empty(t, foundUsername)
   240  		assert.Equal(t, err, errors.New("requestToken authentication failed"))
   241  	})
   242  
   243  	t.Run("CreateK8sExtension", func(t *testing.T) {
   244  		var (
   245  			err       error
   246  			cluster   *ClusterHelper
   247  			extension string
   248  		)
   249  		testRole := "test-role"
   250  		testRegion := "us-west-2"
   251  		testNamespace := "default"
   252  		clusterName := "test-cluster"
   253  		serviceAccountName := "someService"
   254  		testUsername := "system:serviceaccount:" + testNamespace + ":" + serviceAccountName
   255  		fakeContext := &FakeContext{}
   256  
   257  		// test default cluster naming
   258  		cluster = newClusterHelper(clusterName, acctNum, "fake-crt", testRole, testRegion, emptyTags)
   259  		extension, err = CreateK8sExtension(fakeContext, cluster.Cluster, testUsername, testNamespace)
   260  		assert.NoError(t, err)
   261  		assert.Equal(t, "k8s:"+acctNum+":test-cluster:someService", extension)
   262  
   263  		// test cluster a/b
   264  		tags := make(map[string]*string)
   265  		clusterMode := "A"
   266  		tags["ClusterName"] = &clusterName
   267  		tags["ClusterMode"] = &clusterMode
   268  		cluster = newClusterHelper(clusterName+"-a", acctNum, "fake-crt", testRole, testRegion, tags)
   269  		extension, err = CreateK8sExtension(fakeContext, cluster.Cluster, testUsername, testNamespace)
   270  		assert.Nil(t, err)
   271  		assert.Equal(t, "k8s:"+acctNum+":"+clusterName+":"+serviceAccountName, extension)
   272  	})
   273  
   274  	t.Run("BlessK8s", func(t *testing.T) {
   275  		testRole := "test-role"
   276  		testToken := "test-token"
   277  		testRegion := "us-west-2"
   278  		testNamespace := "default"
   279  		clusterName := "test-cluster"
   280  		serviceAccountName := "someService"
   281  		testUsername := "system:serviceaccount:" + testNamespace + ":" + serviceAccountName
   282  		accountIDs := []string{acctNum}
   283  		awsRegions := []string{testRegion}
   284  
   285  		// tags for the ab Cluster
   286  		tags := make(map[string]*string)
   287  		clusterMode := "A"
   288  		tags["ClusterName"] = &clusterName
   289  		tags["ClusterMode"] = &clusterMode
   290  
   291  		// setup fake clusters, lg = legacy, ab = with cluster a/b
   292  		lgCluster := newClusterHelper(clusterName, acctNum, "lg-crt", testRole, testRegion, emptyTags)
   293  		abCluster := newClusterHelper(clusterName+"-a", acctNum, "ab-crt", testRole, testRegion, tags)
   294  
   295  		// creating clusters list
   296  		clusters := []string{lgCluster.Name, abCluster.Name}
   297  
   298  		// outputs list for the desired client output
   299  		var clusterOutputs = make(map[string]*eks.DescribeClusterOutput)
   300  		clusterOutputs[lgCluster.Name] = lgCluster.ClusterOutput
   301  		clusterOutputs[abCluster.Name] = abCluster.ClusterOutput
   302  
   303  		// assigning pointer to cluster names to cluster ptrs list
   304  		clusterPtrs := []*string{}
   305  		for i := range clusters {
   306  			clusterPtrs = append(clusterPtrs, &clusters[i])
   307  		}
   308  		// setup fake token reviews
   309  		fakeTokenReviews := &FakeTokenReviews{}
   310  		fakeTokenReviews.TokenReviewReturn = &auth.TokenReview{
   311  			Status: auth.TokenReviewStatus{
   312  				User: auth.UserInfo{
   313  					Username: testUsername,
   314  				},
   315  				Authenticated: true,
   316  			},
   317  		}
   318  
   319  		// setup fake authv1 client
   320  		fakeAuthV1Client := &FakeAuthV1Client{}
   321  		fakeAuthV1Client.TokenReviewsReturn = fakeTokenReviews
   322  
   323  		// setup fake session wrapper
   324  		fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}}
   325  		fakeSessionWrapper.ListEKSClustersReturn = &eks.ListClustersOutput{
   326  			Clusters: clusterPtrs,
   327  		}
   328  		fakeSessionWrapper.AllEKSClusters = clusterOutputs
   329  		fakeSessionWrapper.GetAuthV1ClientReturn = fakeAuthV1Client
   330  
   331  		assert.NoError(t, ref.EnvClearCredentials())
   332  
   333  		// setup fake blessings server
   334  		pathEnv := "PATH=" + os.Getenv("PATH")
   335  		exe := testutil.GoExecutable(t, "//go/src/github.com/grailbio/base/cmd/grail-access/grail-access")
   336  
   337  		var blesserEndpoint naming.Endpoint
   338  		ctx, blesserEndpoint = ticketServerUtil.RunBlesserServer(
   339  			ctx,
   340  			t,
   341  			identity.K8sBlesserServer(newK8sBlesser(fakeSessionWrapper, time.Hour, testRole, accountIDs, awsRegions)),
   342  		)
   343  
   344  		var (
   345  			tmpDir           string
   346  			cleanUp          func()
   347  			stdout           string
   348  			principalDir     string
   349  			principalCleanUp func()
   350  			cmd              *exec.Cmd
   351  		)
   352  
   353  		// create local crt, namespace, and tokens for the legacy cluster
   354  		tmpDir, cleanUp = testutil.TempDir(t, "", "")
   355  		defer cleanUp()
   356  
   357  		assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "caCrt"), []byte(lgCluster.Crt), 0644))
   358  		assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "namespace"), []byte(testNamespace), 0644))
   359  		assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "token"), []byte(testToken), 0644))
   360  
   361  		// Run grail-access to create a principal and bless it with the k8s flow.
   362  		principalDir, principalCleanUp = testutil.TempDir(t, "", "")
   363  		defer principalCleanUp()
   364  		cmd = exec.Command(exe,
   365  			"-dir", principalDir,
   366  			"-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address),
   367  			"-k8s",
   368  			"-ca-crt", path.Join(tmpDir, "caCrt"),
   369  			"-namespace", path.Join(tmpDir, "namespace"),
   370  			"-token", path.Join(tmpDir, "token"),
   371  		)
   372  		cmd.Env = []string{pathEnv}
   373  		stdout, _ = ticketServerUtil.RunAndCapture(t, cmd)
   374  		assert.Contains(t, stdout, "k8s:111111111111:test-cluster:someService")
   375  
   376  		// create local crt, namespace, and tokens for the a/b cluster
   377  		tmpDir, cleanUp = testutil.TempDir(t, "", "")
   378  		defer cleanUp()
   379  
   380  		assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "caCrt"), []byte(abCluster.Crt), 0644))
   381  		assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "namespace"), []byte(testNamespace), 0644))
   382  		assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "token"), []byte(testToken), 0644))
   383  
   384  		// Run grail-access to create a principal and bless it with the k8s flow.
   385  		principalDir, principalCleanUp = testutil.TempDir(t, "", "")
   386  		defer principalCleanUp()
   387  		cmd = exec.Command(exe,
   388  			"-dir", principalDir,
   389  			"-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address),
   390  			"-k8s",
   391  			"-ca-crt", path.Join(tmpDir, "caCrt"),
   392  			"-namespace", path.Join(tmpDir, "namespace"),
   393  			"-token", path.Join(tmpDir, "token"),
   394  		)
   395  		cmd.Env = []string{pathEnv}
   396  		stdout, _ = ticketServerUtil.RunAndCapture(t, cmd)
   397  		assert.Contains(t, stdout, "k8s:111111111111:test-cluster:someService")
   398  	})
   399  }