github.com/zebozhuang/go@v0.0.0-20200207033046-f8a98f6f5c5d/src/crypto/x509/root_unix_test.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build dragonfly freebsd linux netbsd openbsd solaris
     6  
     7  package x509
     8  
     9  import (
    10  	"fmt"
    11  	"os"
    12  	"testing"
    13  )
    14  
    15  const (
    16  	testDir     = "testdata"
    17  	testDirCN   = "test-dir"
    18  	testFile    = "test-file.crt"
    19  	testFileCN  = "test-file"
    20  	testMissing = "missing"
    21  )
    22  
    23  func TestEnvVars(t *testing.T) {
    24  	testCases := []struct {
    25  		name    string
    26  		fileEnv string
    27  		dirEnv  string
    28  		files   []string
    29  		dirs    []string
    30  		cns     []string
    31  	}{
    32  		{
    33  			// Environment variables override the default locations preventing fall through.
    34  			name:    "override-defaults",
    35  			fileEnv: testMissing,
    36  			dirEnv:  testMissing,
    37  			files:   []string{testFile},
    38  			dirs:    []string{testDir},
    39  			cns:     nil,
    40  		},
    41  		{
    42  			// File environment overrides default file locations.
    43  			name:    "file",
    44  			fileEnv: testFile,
    45  			dirEnv:  "",
    46  			files:   nil,
    47  			dirs:    nil,
    48  			cns:     []string{testFileCN},
    49  		},
    50  		{
    51  			// Directory environment overrides default directory locations.
    52  			name:    "dir",
    53  			fileEnv: "",
    54  			dirEnv:  testDir,
    55  			files:   nil,
    56  			dirs:    nil,
    57  			cns:     []string{testDirCN},
    58  		},
    59  		{
    60  			// File & directory environment overrides both default locations.
    61  			name:    "file+dir",
    62  			fileEnv: testFile,
    63  			dirEnv:  testDir,
    64  			files:   nil,
    65  			dirs:    nil,
    66  			cns:     []string{testFileCN, testDirCN},
    67  		},
    68  		{
    69  			// Environment variable empty / unset uses default locations.
    70  			name:    "empty-fall-through",
    71  			fileEnv: "",
    72  			dirEnv:  "",
    73  			files:   []string{testFile},
    74  			dirs:    []string{testDir},
    75  			cns:     []string{testFileCN, testDirCN},
    76  		},
    77  	}
    78  
    79  	// Save old settings so we can restore before the test ends.
    80  	origCertFiles, origCertDirectories := certFiles, certDirectories
    81  	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
    82  	defer func() {
    83  		certFiles = origCertFiles
    84  		certDirectories = origCertDirectories
    85  		os.Setenv(certFileEnv, origFile)
    86  		os.Setenv(certDirEnv, origDir)
    87  	}()
    88  
    89  	for _, tc := range testCases {
    90  		t.Run(tc.name, func(t *testing.T) {
    91  			if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil {
    92  				t.Fatalf("setenv %q failed: %v", certFileEnv, err)
    93  			}
    94  			if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil {
    95  				t.Fatalf("setenv %q failed: %v", certDirEnv, err)
    96  			}
    97  
    98  			certFiles, certDirectories = tc.files, tc.dirs
    99  
   100  			r, err := loadSystemRoots()
   101  			if err != nil {
   102  				t.Fatal("unexpected failure:", err)
   103  			}
   104  
   105  			if r == nil {
   106  				if tc.cns == nil {
   107  					// Expected nil
   108  					return
   109  				}
   110  				t.Fatal("nil roots")
   111  			}
   112  
   113  			// Verify that the returned certs match, otherwise report where the mismatch is.
   114  			for i, cn := range tc.cns {
   115  				if i >= len(r.certs) {
   116  					t.Errorf("missing cert %v @ %v", cn, i)
   117  				} else if r.certs[i].Subject.CommonName != cn {
   118  					fmt.Printf("%#v\n", r.certs[0].Subject)
   119  					t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn)
   120  				}
   121  			}
   122  			if len(r.certs) > len(tc.cns) {
   123  				t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns))
   124  			}
   125  		})
   126  	}
   127  }