github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/internal/env/env_test.go (about)

     1  package env_test
     2  
     3  import (
     4  	"errors"
     5  	"os"
     6  	"path/filepath"
     7  	"testing"
     8  
     9  	"github.com/google/go-cmp/cmp"
    10  	"github.com/quickfeed/quickfeed/internal/env"
    11  )
    12  
    13  func TestScmProviderEnv(t *testing.T) {
    14  	want := "github"
    15  	got := env.ScmProvider()
    16  	if got != want {
    17  		t.Errorf("ScmProvider() = %s, wanted %s", got, want)
    18  	}
    19  
    20  	env.SetFakeProvider(t)
    21  	want = "fake"
    22  	got = env.ScmProvider()
    23  	if got != want {
    24  		t.Errorf("ScmProvider() = %s, wanted %s", got, want)
    25  	}
    26  }
    27  
    28  func TestLoad(t *testing.T) {
    29  	fi, err := os.CreateTemp("", ".env")
    30  	if err != nil {
    31  		t.Fatal(err)
    32  	}
    33  	defer func() {
    34  		fi.Close()
    35  		if err = os.Remove(fi.Name()); err != nil {
    36  			t.Fatal(err)
    37  		}
    38  	}()
    39  
    40  	want := map[string]string{
    41  		"QUICKFEED":           os.Getenv("QUICKFEED"),
    42  		"SOME_PATH":           "/quickfeed/root",
    43  		"QUICKFEED_TEST_ENV":  "test",
    44  		"QUICKFEED_TEST_ENV2": "test2",
    45  		"QUICKFEED_TEST_ENV3": "test3",
    46  		"QUICKFEED_TEST_ENV4": "test4 xyz",
    47  		"QUICKFEED_TEST_ENV5": "test5 = zyx",
    48  		"SOME_CERT_FILE":      "/quickfeed/root/cert/fullchain.pem",
    49  		"SOME_KEY_FILE":       filepath.Join(os.Getenv("QUICKFEED"), "cert", "fullchain.pem"),
    50  		"WITHOUT_QUOTES":      filepath.Join(os.Getenv("QUICKFEED"), "cert", "fullchain.pem"),
    51  	}
    52  
    53  	input := `QUICKFEED_TEST_ENV=test
    54  QUICKFEED_TEST_ENV2= test2
    55  
    56  QUICKFEED_TEST_ENV3=test3
    57  # Comment
    58  QUICKFEED_TEST_ENV4=test4 xyz
    59  ## Another comment
    60  QUICKFEED_TEST_ENV5=test5 = zyx
    61  # Variable to be expanded into other vars
    62  SOME_PATH=/quickfeed/root
    63  # Cert file and key file expanded
    64  SOME_CERT_FILE=$SOME_PATH/cert/fullchain.pem
    65  SOME_KEY_FILE=$QUICKFEED/cert/fullchain.pem
    66  WITHOUT_QUOTES="$QUICKFEED/cert/fullchain.pem"
    67  `
    68  	if _, err = fi.WriteString(input); err != nil {
    69  		t.Fatal(err)
    70  	}
    71  
    72  	if err = env.Load(fi.Name()); err != nil {
    73  		t.Fatal(err)
    74  	}
    75  
    76  	for k, v := range want {
    77  		if got := os.Getenv(k); got != v {
    78  			t.Errorf("os.Getenv(%q) = %q, wanted %q", k, got, v)
    79  		}
    80  	}
    81  }
    82  
    83  func TestExistsLogic(t *testing.T) {
    84  	const (
    85  		E = true  // exists
    86  		Ø = false // does not exist
    87  	)
    88  	type exist struct {
    89  		file bool
    90  		bak  bool
    91  	}
    92  
    93  	exists := func(filename string) bool {
    94  		_, err := os.Stat(filename)
    95  		return err == nil
    96  	}
    97  
    98  	msg := func(e bool) string {
    99  		if e {
   100  			return `"exists", wanted "does not exist"`
   101  		}
   102  		return `"does not exist", wanted "exists"`
   103  	}
   104  
   105  	const baseFilename = "env"
   106  	existsErr := env.ExistsError("dummy")   // will be replaced with other error with correct t.TempDir()
   107  	missingErr := env.MissingError("dummy") // will be replaced with other error with correct t.TempDir()
   108  
   109  	tests := []struct {
   110  		name    string
   111  		before  exist
   112  		after   exist
   113  		wantErr error
   114  	}{
   115  		{name: "NoFileExists   ", before: exist{file: Ø, bak: Ø}, after: exist{file: Ø, bak: Ø}, wantErr: missingErr},
   116  		{name: "EnvFileExists  ", before: exist{file: E, bak: Ø}, after: exist{file: E, bak: Ø}, wantErr: nil},
   117  		{name: "BakFileExists  ", before: exist{file: Ø, bak: E}, after: exist{file: Ø, bak: E}, wantErr: existsErr},
   118  		{name: "BothFilesExists", before: exist{file: E, bak: E}, after: exist{file: E, bak: E}, wantErr: existsErr},
   119  	}
   120  	for _, test := range tests {
   121  		t.Run(test.name, func(t *testing.T) {
   122  			var (
   123  				dir         = t.TempDir()
   124  				filename    = filepath.Join(dir, baseFilename)
   125  				bakFilename = filename + ".bak"
   126  			)
   127  			if test.before.file {
   128  				if _, err := os.Create(filename); err != nil {
   129  					t.Fatal(err)
   130  				}
   131  			}
   132  			if test.before.bak {
   133  				if _, err := os.Create(bakFilename); err != nil {
   134  					t.Fatal(err)
   135  				}
   136  			}
   137  			if errors.Is(test.wantErr, existsErr) {
   138  				// use error with correct t.TempDir()
   139  				test.wantErr = env.ExistsError(bakFilename)
   140  			}
   141  			if errors.Is(test.wantErr, missingErr) {
   142  				// use error with correct t.TempDir()
   143  				test.wantErr = env.MissingError(filename)
   144  			}
   145  			if err := env.Prepared(filename); !errors.Is(err, test.wantErr) {
   146  				t.Errorf("Prepared(%q) = %v, wanted %v", filepath.Base(filename), err, test.wantErr)
   147  			}
   148  			if exists(filename) != test.after.file {
   149  				t.Errorf("%q: %s", filepath.Base(filename), msg(test.after.file))
   150  			}
   151  			if exists(bakFilename) != test.after.bak {
   152  				t.Errorf("%q: %s", filepath.Base(bakFilename), msg(test.after.bak))
   153  			}
   154  		})
   155  	}
   156  }
   157  
   158  func TestSave(t *testing.T) {
   159  	fi, err := os.CreateTemp("", ".env")
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	defer func() {
   164  		fi.Close()
   165  		if err = os.Remove(fi.Name()); err != nil {
   166  			t.Fatal(err)
   167  		}
   168  	}()
   169  
   170  	prevContent := `QUICKFEED_TEST_ENV=test
   171  QUICKFEED_TEST_ENV2=test2
   172  QUICKFEED_CLIENT_ID=321
   173  QUICKFEED=/mumbo/jumbo
   174  `
   175  	if _, err = fi.WriteString(prevContent); err != nil {
   176  		t.Fatal(err)
   177  	}
   178  
   179  	want := map[string]string{
   180  		"QUICKFEED_APP_ID":        "weird al",
   181  		"QUICKFEED_APP_KEY":       "$QUICKFEED/internal/config/github/quickfeed.pem",
   182  		"QUICKFEED_CLIENT_ID":     "123",
   183  		"QUICKFEED_CLIENT_SECRET": "456",
   184  		"QUICKFEED_KEY_FILE":      "$QUICKFEED/internal/config/certs/privkey.pem",
   185  		"QUICKFEED_CERT_FILE":     "$QUICKFEED/internal/config/certs/fullchain.pem",
   186  		"QUICKFEED":               os.Getenv("QUICKFEED"),
   187  		"SOME_PATH":               "/quickfeed/root",
   188  		"SPEEDY":                  "$QUICKFEED/gonzales",
   189  	}
   190  	if err = env.Save(fi.Name(), want); err != nil {
   191  		t.Fatal(err)
   192  	}
   193  
   194  	if err = env.Load(fi.Name()); err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	for k, v := range want {
   199  		expVal := os.ExpandEnv(v)
   200  		if got := os.Getenv(k); got != expVal {
   201  			t.Errorf("os.Getenv(%q) = %q, wanted %q", k, got, expVal)
   202  		}
   203  	}
   204  	if os.Getenv("QUICKFEED_TEST_ENV") != "test" {
   205  		t.Errorf("os.Getenv(%q) = %q, wanted %q", "QUICKFEED_TEST_ENV", os.Getenv("QUICKFEED_TEST_ENV"), "test")
   206  	}
   207  	if os.Getenv("QUICKFEED_TEST_ENV2") != "test2" {
   208  		t.Errorf("os.Getenv(%q) = %q, wanted %q", "QUICKFEED_TEST_ENV", os.Getenv("QUICKFEED_TEST_ENV"), "test2")
   209  	}
   210  }
   211  
   212  func TestWhitelist(t *testing.T) {
   213  	test := []struct {
   214  		domains string
   215  		want    []string
   216  		err     bool
   217  	}{
   218  		{"", nil, true},
   219  		{",", nil, true},
   220  		{"localhost", nil, true},
   221  		{"localhost,example.com", nil, true},
   222  		{"123.12.1.1", nil, true},
   223  		{"172.31.120.166", nil, true},
   224  		{"84.22.1.92", nil, true},
   225  		{"example.com, www.example.com, localhost", nil, true},
   226  		{"example.com, www.example.com,127.0.0.1:8080", nil, true},
   227  		{"a.com, b.com, c.com", []string{"a.com", "b.com", "c.com"}, false},
   228  		{"a.com,b.com,c.com", []string{"a.com", "b.com", "c.com"}, false},
   229  		{"example.com, www.example.com", []string{"example.com", "www.example.com"}, false},
   230  		{"example.com, www.example.com,", []string{"example.com", "www.example.com"}, false},
   231  		{"example.com, www.example.com,,, , , ", []string{"example.com", "www.example.com"}, false},
   232  	}
   233  
   234  	for _, tc := range test {
   235  		t.Setenv("QUICKFEED_WHITELIST", tc.domains)
   236  		got, err := env.Whitelist()
   237  		if err != nil && !tc.err {
   238  			t.Errorf("Whitelist() = %v", err)
   239  		}
   240  		if diff := cmp.Diff(tc.want, got); diff != "" {
   241  			t.Errorf("Whitelist() mismatch (-want +got):\n%s", diff)
   242  		}
   243  	}
   244  }