github.com/dtroyer-salad/og2/v2@v2.0.0-20240412154159-c47231610877/registry/remote/credentials/registry_test.go (about)

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package credentials
    17  
    18  import (
    19  	"context"
    20  	"encoding/base64"
    21  	"errors"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"net/url"
    25  	"reflect"
    26  	"testing"
    27  
    28  	"oras.land/oras-go/v2/registry/remote"
    29  	"oras.land/oras-go/v2/registry/remote/auth"
    30  )
    31  
    32  // testStore implements the Store interface, used for testing purpose.
    33  type testStore struct {
    34  	storage map[string]auth.Credential
    35  }
    36  
    37  func (t *testStore) Get(ctx context.Context, serverAddress string) (auth.Credential, error) {
    38  	return t.storage[serverAddress], nil
    39  }
    40  
    41  func (t *testStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) error {
    42  	if len(t.storage) == 0 {
    43  		t.storage = make(map[string]auth.Credential)
    44  	}
    45  	t.storage[serverAddress] = cred
    46  	return nil
    47  }
    48  
    49  func (t *testStore) Delete(ctx context.Context, serverAddress string) error {
    50  	delete(t.storage, serverAddress)
    51  	return nil
    52  }
    53  
    54  func TestLogin(t *testing.T) {
    55  	// create a test registry
    56  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		wantedAuthHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(testUsername+":"+testPassword))
    58  		authHeader := r.Header.Get("Authorization")
    59  		if authHeader != wantedAuthHeader {
    60  			w.Header().Set("Www-Authenticate", `Basic realm="Test Server"`)
    61  			w.WriteHeader(http.StatusUnauthorized)
    62  		}
    63  	}))
    64  	defer ts.Close()
    65  	uri, _ := url.Parse(ts.URL)
    66  	reg, err := remote.NewRegistry(uri.Host)
    67  	if err != nil {
    68  		t.Fatalf("cannot create test registry: %v", err)
    69  	}
    70  	reg.PlainHTTP = true
    71  	// create a test store
    72  	s := &testStore{}
    73  	tests := []struct {
    74  		name     string
    75  		ctx      context.Context
    76  		registry *remote.Registry
    77  		cred     auth.Credential
    78  		wantErr  bool
    79  	}{
    80  		{
    81  			name:    "login succeeds",
    82  			ctx:     context.Background(),
    83  			cred:    auth.Credential{Username: testUsername, Password: testPassword},
    84  			wantErr: false,
    85  		},
    86  		{
    87  			name:    "login fails (incorrect password)",
    88  			ctx:     context.Background(),
    89  			cred:    auth.Credential{Username: testUsername, Password: "whatever"},
    90  			wantErr: true,
    91  		},
    92  		{
    93  			name:    "login fails (nil context makes remote.Ping fails)",
    94  			ctx:     nil,
    95  			cred:    auth.Credential{Username: testUsername, Password: testPassword},
    96  			wantErr: true,
    97  		},
    98  	}
    99  	for _, tt := range tests {
   100  		t.Run(tt.name, func(t *testing.T) {
   101  			// login to test registry
   102  			err := Login(tt.ctx, s, reg, tt.cred)
   103  			if (err != nil) != tt.wantErr {
   104  				t.Fatalf("Login() error = %v, wantErr %v", err, tt.wantErr)
   105  			}
   106  			if err != nil {
   107  				return
   108  			}
   109  			if got := s.storage[reg.Reference.Registry]; !reflect.DeepEqual(got, tt.cred) {
   110  				t.Fatalf("Stored credential = %v, want %v", got, tt.cred)
   111  			}
   112  			s.Delete(tt.ctx, reg.Reference.Registry)
   113  		})
   114  	}
   115  }
   116  
   117  func TestLogin_unsupportedClient(t *testing.T) {
   118  	var testClient http.Client
   119  	reg, err := remote.NewRegistry("whatever")
   120  	if err != nil {
   121  		t.Fatalf("cannot create test registry: %v", err)
   122  	}
   123  	reg.PlainHTTP = true
   124  	reg.Client = &testClient
   125  	ctx := context.Background()
   126  
   127  	s := &testStore{}
   128  	cred := auth.EmptyCredential
   129  	err = Login(ctx, s, reg, cred)
   130  	if wantErr := ErrClientTypeUnsupported; !errors.Is(err, wantErr) {
   131  		t.Errorf("Login() error = %v, wantErr %v", err, wantErr)
   132  	}
   133  }
   134  
   135  func TestLogout(t *testing.T) {
   136  	// create a test store
   137  	s := &testStore{}
   138  	s.storage = map[string]auth.Credential{
   139  		"localhost:2333":              {Username: "test_user", Password: "test_word"},
   140  		"https://index.docker.io/v1/": {Username: "user", Password: "word"},
   141  	}
   142  	tests := []struct {
   143  		name         string
   144  		ctx          context.Context
   145  		store        Store
   146  		registryName string
   147  		wantErr      bool
   148  	}{
   149  		{
   150  			name:         "logout of regular registry",
   151  			ctx:          context.Background(),
   152  			registryName: "localhost:2333",
   153  			wantErr:      false,
   154  		},
   155  		{
   156  			name:         "logout of docker.io",
   157  			ctx:          context.Background(),
   158  			registryName: "docker.io",
   159  			wantErr:      false,
   160  		},
   161  	}
   162  	for _, tt := range tests {
   163  		t.Run(tt.name, func(t *testing.T) {
   164  			if err := Logout(tt.ctx, s, tt.registryName); (err != nil) != tt.wantErr {
   165  				t.Fatalf("Logout() error = %v, wantErr %v", err, tt.wantErr)
   166  			}
   167  			if s.storage[tt.registryName] != auth.EmptyCredential {
   168  				t.Error("Credentials are not deleted")
   169  			}
   170  		})
   171  	}
   172  }
   173  
   174  func Test_mapHostname(t *testing.T) {
   175  	tests := []struct {
   176  		name string
   177  		host string
   178  		want string
   179  	}{
   180  		{
   181  			name: "map docker.io to https://index.docker.io/v1/",
   182  			host: "docker.io",
   183  			want: "https://index.docker.io/v1/",
   184  		},
   185  		{
   186  			name: "do not map other host names",
   187  			host: "localhost:2333",
   188  			want: "localhost:2333",
   189  		},
   190  	}
   191  	for _, tt := range tests {
   192  		t.Run(tt.name, func(t *testing.T) {
   193  			if got := ServerAddressFromRegistry(tt.host); got != tt.want {
   194  				t.Errorf("mapHostname() = %v, want %v", got, tt.want)
   195  			}
   196  		})
   197  	}
   198  }
   199  
   200  func TestCredential(t *testing.T) {
   201  	// create a test store
   202  	s := &testStore{}
   203  	s.storage = map[string]auth.Credential{
   204  		"localhost:2333":              {Username: "test_user", Password: "test_word"},
   205  		"https://index.docker.io/v1/": {Username: "user", Password: "word"},
   206  	}
   207  	// create a test client using Credential
   208  	testClient := &auth.Client{}
   209  	testClient.Credential = Credential(s)
   210  	tests := []struct {
   211  		name           string
   212  		registry       string
   213  		wantCredential auth.Credential
   214  	}{
   215  		{
   216  			name:           "get credentials for localhost:2333",
   217  			registry:       "localhost:2333",
   218  			wantCredential: auth.Credential{Username: "test_user", Password: "test_word"},
   219  		},
   220  		{
   221  			name:           "get credentials for registry-1.docker.io",
   222  			registry:       "registry-1.docker.io",
   223  			wantCredential: auth.Credential{Username: "user", Password: "word"},
   224  		},
   225  		{
   226  			name:           "get credentials for a registry not stored",
   227  			registry:       "localhost:6666",
   228  			wantCredential: auth.EmptyCredential,
   229  		},
   230  		{
   231  			name:           "get credentials for an empty string",
   232  			registry:       "",
   233  			wantCredential: auth.EmptyCredential,
   234  		},
   235  	}
   236  	for _, tt := range tests {
   237  		t.Run(tt.name, func(t *testing.T) {
   238  			got, err := testClient.Credential(context.Background(), tt.registry)
   239  			if err != nil {
   240  				t.Errorf("could not get credential: %v", err)
   241  			}
   242  			if !reflect.DeepEqual(got, tt.wantCredential) {
   243  				t.Errorf("Credential() = %v, want %v", got, tt.wantCredential)
   244  			}
   245  		})
   246  	}
   247  }