github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/security/ticket/helper_test.go (about)

     1  package ticket_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"github.com/Schaudge/grailbase/security/ticket"
     7  	"reflect"
     8  	"testing"
     9  	"v.io/v23/context"
    10  )
    11  
    12  func TestGetter_path(t *testing.T) {
    13  	t.Run("it joins with slashes", func(t *testing.T) {
    14  		want := "ok"
    15  		client := mockString("string/key", want)
    16  		got, err := client.GetString(testContext(), "string", "key")
    17  		if err != nil {
    18  			t.Fatal(err)
    19  		}
    20  		if got != want {
    21  			t.Errorf("got %v, want %v", got, want)
    22  		}
    23  	})
    24  }
    25  
    26  func TestGetter_getTicket(t *testing.T) {
    27  	t.Run("it can error", func(t *testing.T) {
    28  		client := mockString("some/key", "ok")
    29  		_, err := client.GetString(testContext(), "other/key")
    30  		if err == nil {
    31  			t.Fatal("want error, got nil")
    32  		}
    33  	})
    34  }
    35  
    36  func TestGetter_GetData(t *testing.T) {
    37  	key := "data/key"
    38  	want := []byte{1, 2, 3}
    39  	client := mockData(key, want)
    40  
    41  	got, err := client.GetData(testContext(), key)
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  	if !bytes.Equal(got, want) {
    46  		t.Errorf("got %v, want %v", got, want)
    47  	}
    48  }
    49  
    50  func TestGetter_GetString(t *testing.T) {
    51  	key := "string/key"
    52  	want := "this is just a test"
    53  	client := mockString(key, want)
    54  
    55  	got, err := client.GetString(testContext(), key)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	if got != want {
    60  		t.Errorf("got %v, want %v", got, want)
    61  	}
    62  }
    63  
    64  func TestGetter_GetAws(t *testing.T) {
    65  	key := "aws/key"
    66  	want := ticket.AwsTicket{
    67  		AwsAssumeRoleBuilder: &ticket.AwsAssumeRoleBuilder{
    68  			Region: "region",
    69  			Role:   "role",
    70  			TtlSec: 123,
    71  		},
    72  		AwsCredentials: ticket.AwsCredentials{
    73  			Region:          "region",
    74  			AccessKeyId:     "accessKeyID",
    75  			SecretAccessKey: "secretAccessKey",
    76  			SessionToken:    "sessionToken",
    77  			Expiration:      "expiration",
    78  		},
    79  	}
    80  	client := mockAws(key, want)
    81  
    82  	got, err := client.GetAws(testContext(), key)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	if !reflect.DeepEqual(got, want) {
    87  		t.Errorf("got %v, want %v", got, want)
    88  	}
    89  }
    90  
    91  func TestGetter_GetS3(t *testing.T) {
    92  	key := "s3/key"
    93  	want := ticket.S3Ticket{
    94  		AwsAssumeRoleBuilder: &ticket.AwsAssumeRoleBuilder{
    95  			Region: "region",
    96  			Role:   "role",
    97  			TtlSec: 123,
    98  		},
    99  		Endpoint: "endpoint",
   100  		Bucket:   "bucket",
   101  		Prefix:   "prefix",
   102  	}
   103  	client := mockS3(key, want)
   104  
   105  	got, err := client.GetS3(testContext(), key)
   106  	if err != nil {
   107  		t.Fatal(err)
   108  	}
   109  	if !reflect.DeepEqual(got, want) {
   110  		t.Errorf("got %v, want %v", got, want)
   111  	}
   112  }
   113  
   114  func TestGetter_GetSshCertificate(t *testing.T) {
   115  	key := "ssh/key"
   116  	want := ticket.SshCertificateTicket{
   117  		Username: "username",
   118  	}
   119  	client := mockSshCertificate(key, want)
   120  
   121  	got, err := client.GetSshCertificate(testContext(), key)
   122  	if err != nil {
   123  		t.Fatal(err)
   124  	}
   125  	if !reflect.DeepEqual(got, want) {
   126  		t.Errorf("got %v, want %v", got, want)
   127  	}
   128  }
   129  
   130  func TestGetter_GetEcr(t *testing.T) {
   131  	key := "ecr/key"
   132  	want := ticket.EcrTicket{
   133  		AwsAssumeRoleBuilder: &ticket.AwsAssumeRoleBuilder{
   134  			Region: "region",
   135  			Role:   "role",
   136  			TtlSec: 123,
   137  		},
   138  		AuthorizationToken: "authorizationToken",
   139  		Expiration:         "expiration",
   140  		Endpoint:           "endpoint",
   141  	}
   142  	client := mockEcr(key, want)
   143  
   144  	got, err := client.GetEcr(testContext(), key)
   145  	if err != nil {
   146  		t.Fatal(err)
   147  	}
   148  	if !reflect.DeepEqual(got, want) {
   149  		t.Errorf("got %v, want %v", got, want)
   150  	}
   151  }
   152  
   153  func TestGetter_GetTlsServer(t *testing.T) {
   154  	key := "TlsServer/key"
   155  	want := ticket.TlsServerTicket{
   156  		TlsCertAuthorityBuilder: &ticket.TlsCertAuthorityBuilder{
   157  			Authority:  "authority",
   158  			TtlSec:     123,
   159  			CommonName: "commonName",
   160  		},
   161  		Credentials: ticket.TlsCredentials{
   162  			AuthorityCert: "authorityCert",
   163  			Cert:          "cert",
   164  			Key:           "key",
   165  		},
   166  	}
   167  	client := mockTlsServer(key, want)
   168  
   169  	got, err := client.GetTlsServer(testContext(), key)
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	if !reflect.DeepEqual(got, want) {
   174  		t.Errorf("got %v, want %v", got, want)
   175  	}
   176  }
   177  
   178  func TestGetter_GetTlsClient(t *testing.T) {
   179  	key := "TlsClient/key"
   180  	want := ticket.TlsClientTicket{
   181  		TlsCertAuthorityBuilder: &ticket.TlsCertAuthorityBuilder{
   182  			Authority:  "authority",
   183  			TtlSec:     123,
   184  			CommonName: "commonName",
   185  		},
   186  		Credentials: ticket.TlsCredentials{
   187  			AuthorityCert: "authorityCert",
   188  			Cert:          "cert",
   189  			Key:           "key",
   190  		},
   191  	}
   192  	client := mockTlsClient(key, want)
   193  
   194  	got, err := client.GetTlsClient(testContext(), key)
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  	if !reflect.DeepEqual(got, want) {
   199  		t.Errorf("got %v, want %v", got, want)
   200  	}
   201  }
   202  
   203  func TestGetter_GetDocker(t *testing.T) {
   204  	key := "Docker/key"
   205  	want := ticket.DockerTicket{
   206  		TlsCertAuthorityBuilder: &ticket.TlsCertAuthorityBuilder{
   207  			Authority:  "authority",
   208  			TtlSec:     123,
   209  			CommonName: "commonName",
   210  		},
   211  		Credentials: ticket.TlsCredentials{
   212  			AuthorityCert: "authorityCert",
   213  			Cert:          "cert",
   214  			Key:           "key",
   215  		},
   216  		Url: "url",
   217  	}
   218  	client := mockDocker(key, want)
   219  
   220  	got, err := client.GetDocker(testContext(), key)
   221  	if err != nil {
   222  		t.Fatal(err)
   223  	}
   224  	if !reflect.DeepEqual(got, want) {
   225  		t.Errorf("got %v, want %v", got, want)
   226  	}
   227  }
   228  
   229  func TestGetter_GetDockerServer(t *testing.T) {
   230  	key := "DockerServer/key"
   231  	want := ticket.DockerServerTicket{}
   232  	client := mockDockerServer(key, want)
   233  
   234  	got, err := client.GetDockerServer(testContext(), key)
   235  	if err != nil {
   236  		t.Fatal(err)
   237  	}
   238  	if !reflect.DeepEqual(got, want) {
   239  		t.Errorf("got %v, want %v", got, want)
   240  	}
   241  }
   242  
   243  func TestGetter_GetDockerClient(t *testing.T) {
   244  	key := "DockerClient/key"
   245  	want := ticket.DockerClientTicket{}
   246  	client := mockDockerClient(key, want)
   247  
   248  	got, err := client.GetDockerClient(testContext(), key)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	if !reflect.DeepEqual(got, want) {
   253  		t.Errorf("got %v, want %v", got, want)
   254  	}
   255  }
   256  
   257  func TestGetter_GetB2(t *testing.T) {
   258  	key := "B2/key"
   259  	want := ticket.B2Ticket{}
   260  	client := mockB2(key, want)
   261  
   262  	got, err := client.GetB2(testContext(), key)
   263  	if err != nil {
   264  		t.Fatal(err)
   265  	}
   266  	if !reflect.DeepEqual(got, want) {
   267  		t.Errorf("got %v, want %v", got, want)
   268  	}
   269  }
   270  
   271  func TestGetter_GetVanadium(t *testing.T) {
   272  	key := "Vanadium/key"
   273  	want := ticket.VanadiumTicket{}
   274  	client := mockVanadium(key, want)
   275  
   276  	got, err := client.GetVanadium(testContext(), key)
   277  	if err != nil {
   278  		t.Fatal(err)
   279  	}
   280  	if !reflect.DeepEqual(got, want) {
   281  		t.Errorf("got %v, want %v", got, want)
   282  	}
   283  }
   284  
   285  // testContext creates a nil context which is safe to use with a mock client.
   286  func testContext() *context.T {
   287  	return nil
   288  }
   289  
   290  // mock is a shim to convert a ticket value to a ticket.Getter.
   291  func mock(expectKey string, value interface{}) ticket.Getter {
   292  	return func(_ *context.T, gotKey string) (ticket.Ticket, error) {
   293  		if gotKey == expectKey {
   294  			return value.(ticket.Ticket), nil
   295  		}
   296  
   297  		return nil, fmt.Errorf("ticket not found")
   298  	}
   299  }
   300  
   301  func mockData(key string, data []byte) ticket.Getter {
   302  	return mock(key, ticket.TicketGenericTicket{
   303  		Value: ticket.GenericTicket{Data: data},
   304  	})
   305  }
   306  
   307  func mockString(key string, s string) ticket.Getter {
   308  	return mock(key, ticket.TicketGenericTicket{
   309  		Value: ticket.GenericTicket{Data: []byte(s)},
   310  	})
   311  }
   312  
   313  func mockAws(key string, aws ticket.AwsTicket) ticket.Getter {
   314  	return mock(key, ticket.TicketAwsTicket{Value: aws})
   315  }
   316  func mockS3(key string, s3 ticket.S3Ticket) ticket.Getter {
   317  	return mock(key, ticket.TicketS3Ticket{Value: s3})
   318  }
   319  func mockSshCertificate(key string, ssh ticket.SshCertificateTicket) ticket.Getter {
   320  	return mock(key, ticket.TicketSshCertificateTicket{Value: ssh})
   321  }
   322  func mockEcr(key string, ecr ticket.EcrTicket) ticket.Getter {
   323  	return mock(key, ticket.TicketEcrTicket{Value: ecr})
   324  }
   325  func mockTlsServer(key string, TlsServer ticket.TlsServerTicket) ticket.Getter {
   326  	return mock(key, ticket.TicketTlsServerTicket{Value: TlsServer})
   327  }
   328  func mockTlsClient(key string, TlsClient ticket.TlsClientTicket) ticket.Getter {
   329  	return mock(key, ticket.TicketTlsClientTicket{Value: TlsClient})
   330  }
   331  func mockDocker(key string, Docker ticket.DockerTicket) ticket.Getter {
   332  	return mock(key, ticket.TicketDockerTicket{Value: Docker})
   333  }
   334  func mockDockerServer(key string, DockerServer ticket.DockerServerTicket) ticket.Getter {
   335  	return mock(key, ticket.TicketDockerServerTicket{Value: DockerServer})
   336  }
   337  func mockDockerClient(key string, DockerClient ticket.DockerClientTicket) ticket.Getter {
   338  	return mock(key, ticket.TicketDockerClientTicket{Value: DockerClient})
   339  }
   340  func mockB2(key string, b2 ticket.B2Ticket) ticket.Getter {
   341  	return mock(key, ticket.TicketB2Ticket{Value: b2})
   342  }
   343  func mockVanadium(key string, vanadium ticket.VanadiumTicket) ticket.Getter {
   344  	return mock(key, ticket.TicketVanadiumTicket{Value: vanadium})
   345  }