github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/subshell/cmd/env_test.go (about)

     1  package cmd
     2  
     3  import (
     4  	"os"
     5  	"reflect"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/ActiveState/cli/internal/osutils"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/thoas/go-funk"
    12  )
    13  
    14  type RegistryKeyMock struct {
    15  	getCalls       []string
    16  	setCalls       []string
    17  	setExpandCalls []string
    18  	delCalls       []string
    19  
    20  	getResults map[string]RegistryValue
    21  	setResults map[string]error
    22  	delResults map[string]error
    23  }
    24  
    25  type RegistryValue struct {
    26  	Value string
    27  	Error error
    28  }
    29  
    30  func (r *RegistryKeyMock) GetStringValue(name string) (string, uint32, error) {
    31  	r.getCalls = append(r.getCalls, name)
    32  	if v, ok := r.getResults[name]; ok {
    33  		return v.Value, 0, v.Error
    34  	}
    35  	return "", 0, osutils.NotExistError()
    36  }
    37  
    38  func (r *RegistryKeyMock) SetStringValue(name, value string) error {
    39  	r.setCalls = append(r.setCalls, name+"="+value)
    40  	if v, ok := r.setResults[name]; ok {
    41  		return v
    42  	}
    43  	return nil
    44  }
    45  
    46  func (r *RegistryKeyMock) SetExpandStringValue(name, value string) error {
    47  	r.setExpandCalls = append(r.setExpandCalls, name+"="+value)
    48  	if v, ok := r.setResults[name]; ok {
    49  		return v
    50  	}
    51  	return nil
    52  }
    53  
    54  func (r *RegistryKeyMock) DeleteValue(name string) error {
    55  	r.delCalls = append(r.getCalls, name)
    56  	if v, ok := r.delResults[name]; ok {
    57  		return v
    58  	}
    59  	return nil
    60  }
    61  
    62  func (r *RegistryKeyMock) Close() error {
    63  	return nil
    64  }
    65  
    66  func TestCmdEnv_unset(t *testing.T) {
    67  	type fields struct {
    68  		registryMock *RegistryKeyMock
    69  		openKeyErr   error
    70  	}
    71  	type args struct {
    72  		name          string
    73  		ifValueEquals string
    74  	}
    75  	type want struct {
    76  		returnValue      error
    77  		registryGetCalls *[]string // nil means it should have no calls
    78  		registrySetCalls *[]string
    79  		registryDelCalls *[]string
    80  	}
    81  	tests := []struct {
    82  		name   string
    83  		fields fields
    84  		args   args
    85  		want   want
    86  	}{
    87  		{
    88  			"unset, value not equals",
    89  			fields{&RegistryKeyMock{}, nil},
    90  			args{
    91  				"key",
    92  				"value_not_equal",
    93  			},
    94  			want{
    95  				nil,
    96  				&[]string{},
    97  				nil,
    98  				nil,
    99  			},
   100  		},
   101  		{
   102  			"unset, value equals",
   103  			fields{&RegistryKeyMock{
   104  				getResults: map[string]RegistryValue{
   105  					"key": RegistryValue{"value_equals", nil},
   106  				},
   107  			}, nil},
   108  			args{
   109  				"key",
   110  				"value_equals",
   111  			},
   112  			want{
   113  				nil,
   114  				&[]string{},
   115  				nil,
   116  				&[]string{"key"},
   117  			},
   118  		},
   119  	}
   120  	for _, tt := range tests {
   121  		t.Run(tt.name, func(t *testing.T) {
   122  			c := &CmdEnv{
   123  				openKeyFn: func(path string) (osutils.RegistryKey, error) {
   124  					return tt.fields.registryMock, tt.fields.openKeyErr
   125  				},
   126  			}
   127  			if got := c.unset(tt.args.name, tt.args.ifValueEquals); !reflect.DeepEqual(got, tt.want.returnValue) {
   128  				t.Errorf("unset() = %v, want %v", got, tt.want)
   129  			}
   130  			rm := tt.fields.registryMock
   131  
   132  			registryValidator(t, rm.getCalls, tt.want.registryGetCalls, "GET")
   133  			registryValidator(t, rm.setCalls, tt.want.registrySetCalls, "SET")
   134  			registryValidator(t, rm.setExpandCalls, &[]string{}, "EXPAND")
   135  			registryValidator(t, rm.delCalls, tt.want.registryDelCalls, "DEL")
   136  		})
   137  	}
   138  }
   139  
   140  func TestCmdEnv_set(t *testing.T) {
   141  	type fields struct {
   142  		registryMock *RegistryKeyMock
   143  		openKeyErr   error
   144  	}
   145  	type args struct {
   146  		name  string
   147  		value string
   148  	}
   149  	type want struct {
   150  		returnValue      error
   151  		registryGetCalls *[]string // nil means it should have no calls
   152  		registrySetCalls *[]string
   153  	}
   154  	tests := []struct {
   155  		name   string
   156  		fields fields
   157  		args   args
   158  		want   want
   159  	}{
   160  		{
   161  			"set",
   162  			fields{&RegistryKeyMock{}, nil},
   163  			args{
   164  				"key",
   165  				"value",
   166  			},
   167  			want{
   168  				nil,
   169  				&[]string{},
   170  				&[]string{"key=value", "!key_original"},
   171  			},
   172  		},
   173  	}
   174  	for _, tt := range tests {
   175  		t.Run(tt.name, func(t *testing.T) {
   176  			c := &CmdEnv{
   177  				openKeyFn: func(path string) (osutils.RegistryKey, error) {
   178  					return tt.fields.registryMock, tt.fields.openKeyErr
   179  				},
   180  			}
   181  			if got := c.Set(tt.args.name, tt.args.value); !reflect.DeepEqual(got, tt.want.returnValue) {
   182  				t.Errorf("set() = %v, want %v", got, tt.want)
   183  			}
   184  			rm := tt.fields.registryMock
   185  
   186  			registryValidator(t, rm.getCalls, tt.want.registryGetCalls, "GET")
   187  			registryValidator(t, rm.setCalls, tt.want.registrySetCalls, "SET")
   188  		})
   189  	}
   190  }
   191  
   192  func TestCmdEnv_get(t *testing.T) {
   193  	type fields struct {
   194  		registryMock *RegistryKeyMock
   195  		openKeyErr   error
   196  	}
   197  	type args struct {
   198  		name string
   199  	}
   200  	type want struct {
   201  		returnValue      string
   202  		returnFailure    error
   203  		registryGetCalls *[]string // nil means it should have no calls
   204  	}
   205  	tests := []struct {
   206  		name   string
   207  		fields fields
   208  		args   args
   209  		want   want
   210  	}{
   211  		{
   212  			"get nonexist",
   213  			fields{&RegistryKeyMock{}, nil},
   214  			args{
   215  				"key",
   216  			},
   217  			want{
   218  				"",
   219  				nil,
   220  				&[]string{"key"},
   221  			},
   222  		},
   223  		{
   224  			"get existing",
   225  			fields{&RegistryKeyMock{
   226  				getResults: map[string]RegistryValue{
   227  					"key": RegistryValue{"value", nil},
   228  				},
   229  			}, nil},
   230  			args{
   231  				"key",
   232  			},
   233  			want{
   234  				"value",
   235  				nil,
   236  				&[]string{"key"},
   237  			},
   238  		},
   239  	}
   240  	for _, tt := range tests {
   241  		t.Run(tt.name, func(t *testing.T) {
   242  			c := &CmdEnv{
   243  				openKeyFn: func(path string) (osutils.RegistryKey, error) {
   244  					return tt.fields.registryMock, tt.fields.openKeyErr
   245  				},
   246  			}
   247  			got, gotFail := c.Get(tt.args.name)
   248  			if !reflect.DeepEqual(got, tt.want.returnValue) {
   249  				t.Errorf("get() = %v, want %v", got, tt.want)
   250  			}
   251  			if !reflect.DeepEqual(gotFail, tt.want.returnFailure) {
   252  				t.Errorf("get() err = %v, want %v", gotFail, tt.want)
   253  			}
   254  
   255  			rm := tt.fields.registryMock
   256  			registryValidator(t, rm.getCalls, tt.want.registryGetCalls, "GET")
   257  		})
   258  	}
   259  }
   260  
   261  func registryValidator(t *testing.T, got []string, want *[]string, name string) {
   262  	if want == nil && len(got) > 0 {
   263  		t.Errorf("%s: registry should have no calls but got: %v", name, got)
   264  		t.FailNow()
   265  	}
   266  
   267  	if want != nil {
   268  		for _, v := range *want {
   269  			exclude := strings.HasPrefix(v, "!")
   270  			if exclude {
   271  				v = strings.TrimPrefix(v, "!")
   272  			}
   273  			contains := funk.Contains(got, v)
   274  			if exclude && contains {
   275  				t.Errorf("%s: should not contain: %s, calls: %v", name, v, got)
   276  			}
   277  			if !exclude && !contains {
   278  				t.Errorf("%s: should have contained: %s, calls: %v", name, v, got)
   279  			}
   280  		}
   281  	}
   282  }
   283  
   284  func TestCleanPath(t *testing.T) {
   285  	paths := []string{"foo", "bar", "baz/quux"}
   286  	path := strings.Join(paths, string(os.PathListSeparator))
   287  
   288  	cleaned := cleanPath(path, "bar")
   289  	assert.Equal(t, cleaned, strings.Join([]string{"foo", "baz/quux"}, string(os.PathListSeparator)))
   290  
   291  	cleaned = cleanPath(path, strings.Join([]string{"bar", "baz/quux"}, string(os.PathListSeparator)))
   292  	assert.Equal(t, cleaned, "foo")
   293  }