github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/x509/root_unix_test.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
     6  // +build dragonfly freebsd linux netbsd openbsd solaris
     7  
     8  package x509
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"os"
    14  	"path/filepath"
    15  	"reflect"
    16  	"strings"
    17  	"testing"
    18  )
    19  
    20  const (
    21  	testDir     = "testdata"
    22  	testDirCN   = "test-dir"
    23  	testFile    = "test-file.crt"
    24  	testFileCN  = "test-file"
    25  	testMissing = "missing"
    26  )
    27  
    28  func TestEnvVars(t *testing.T) {
    29  	testCases := []struct {
    30  		name    string
    31  		fileEnv string
    32  		dirEnv  string
    33  		files   []string
    34  		dirs    []string
    35  		cns     []string
    36  	}{
    37  		{
    38  			// Environment variables override the default locations preventing fall through.
    39  			name:    "override-defaults",
    40  			fileEnv: testMissing,
    41  			dirEnv:  testMissing,
    42  			files:   []string{testFile},
    43  			dirs:    []string{testDir},
    44  			cns:     nil,
    45  		},
    46  		{
    47  			// File environment overrides default file locations.
    48  			name:    "file",
    49  			fileEnv: testFile,
    50  			dirEnv:  "",
    51  			files:   nil,
    52  			dirs:    nil,
    53  			cns:     []string{testFileCN},
    54  		},
    55  		{
    56  			// Directory environment overrides default directory locations.
    57  			name:    "dir",
    58  			fileEnv: "",
    59  			dirEnv:  testDir,
    60  			files:   nil,
    61  			dirs:    nil,
    62  			cns:     []string{testDirCN},
    63  		},
    64  		{
    65  			// File & directory environment overrides both default locations.
    66  			name:    "file+dir",
    67  			fileEnv: testFile,
    68  			dirEnv:  testDir,
    69  			files:   nil,
    70  			dirs:    nil,
    71  			cns:     []string{testFileCN, testDirCN},
    72  		},
    73  		{
    74  			// Environment variable empty / unset uses default locations.
    75  			name:    "empty-fall-through",
    76  			fileEnv: "",
    77  			dirEnv:  "",
    78  			files:   []string{testFile},
    79  			dirs:    []string{testDir},
    80  			cns:     []string{testFileCN, testDirCN},
    81  		},
    82  	}
    83  
    84  	// Save old settings so we can restore before the test ends.
    85  	origCertFiles, origCertDirectories := certFiles, certDirectories
    86  	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
    87  	defer func() {
    88  		certFiles = origCertFiles
    89  		certDirectories = origCertDirectories
    90  		err := os.Setenv(certFileEnv, origFile)
    91  		if err != nil {
    92  			panic(err)
    93  		}
    94  		err = os.Setenv(certDirEnv, origDir)
    95  		if err != nil {
    96  			panic(err)
    97  		}
    98  	}()
    99  
   100  	for _, tc := range testCases {
   101  		t.Run(tc.name, func(t *testing.T) {
   102  			if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil {
   103  				t.Fatalf("setenv %q failed: %v", certFileEnv, err)
   104  			}
   105  			if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil {
   106  				t.Fatalf("setenv %q failed: %v", certDirEnv, err)
   107  			}
   108  
   109  			certFiles, certDirectories = tc.files, tc.dirs
   110  
   111  			r, err := loadSystemRoots()
   112  			if err != nil {
   113  				t.Fatal("unexpected failure:", err)
   114  			}
   115  
   116  			if r == nil {
   117  				t.Fatal("nil roots")
   118  			}
   119  
   120  			// Verify that the returned certs match, otherwise report where the mismatch is.
   121  			for i, cn := range tc.cns {
   122  				if i >= r.len() {
   123  					t.Errorf("missing cert %v @ %v", cn, i)
   124  				} else if r.mustCert(t, i).Subject.CommonName != cn {
   125  					fmt.Printf("%#v\n", r.mustCert(t, 0).Subject)
   126  					t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn)
   127  				}
   128  			}
   129  			if r.len() > len(tc.cns) {
   130  				t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
   131  			}
   132  		})
   133  	}
   134  }
   135  
   136  // Ensure that "SSL_CERT_DIR" when used as the environment
   137  // variable delimited by colons, allows loadSystemRoots to
   138  // load all the roots from the respective directories.
   139  // See https://golang.org/issue/35325.
   140  func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
   141  	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
   142  	origCertFiles := certFiles[:]
   143  
   144  	// To prevent any other certs from being loaded in
   145  	// through "SSL_CERT_FILE" or from known "certFiles",
   146  	// clear them all, and they'll be reverting on defer.
   147  	certFiles = certFiles[:0]
   148  	err := os.Setenv(certFileEnv, "")
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  
   153  	defer func() {
   154  		certFiles = origCertFiles[:]
   155  		err := os.Setenv(certDirEnv, origDir)
   156  		if err != nil {
   157  			panic(err)
   158  		}
   159  		err = os.Setenv(certFileEnv, origFile)
   160  		if err != nil {
   161  			panic(err)
   162  		}
   163  	}()
   164  
   165  	tmpDir := t.TempDir()
   166  
   167  	rootPEMs := []string{
   168  		geoTrustRoot,
   169  		googleLeaf,
   170  		startComRoot,
   171  	}
   172  
   173  	var certDirs []string
   174  	for i, certPEM := range rootPEMs {
   175  		certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i))
   176  		if err := os.MkdirAll(certDir, 0755); err != nil {
   177  			t.Fatalf("Failed to create certificate dir: %v", err)
   178  		}
   179  		certOutFile := filepath.Join(certDir, "cert.crt")
   180  		if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil {
   181  			t.Fatalf("Failed to write certificate to file: %v", err)
   182  		}
   183  		certDirs = append(certDirs, certDir)
   184  	}
   185  
   186  	// Sanity check: the number of certDirs should be equal to the number of roots.
   187  	if g, w := len(certDirs), len(rootPEMs); g != w {
   188  		t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w)
   189  	}
   190  
   191  	// Now finally concatenate them with a colon.
   192  	colonConcatCertDirs := strings.Join(certDirs, ":")
   193  	err = os.Setenv(certDirEnv, colonConcatCertDirs)
   194  	if err != nil {
   195  		t.Fatal(err)
   196  	}
   197  	gotPool, err := loadSystemRoots()
   198  	if err != nil {
   199  		t.Fatalf("Failed to load system roots: %v", err)
   200  	}
   201  	subjects := gotPool.Subjects()
   202  	// We expect exactly len(rootPEMs) subjects back.
   203  	if g, w := len(subjects), len(rootPEMs); g != w {
   204  		t.Fatalf("Invalid number of subjects: got %d want %d", g, w)
   205  	}
   206  
   207  	wantPool := NewCertPool()
   208  	for _, certPEM := range rootPEMs {
   209  		wantPool.AppendCertsFromPEM([]byte(certPEM))
   210  	}
   211  	strCertPool := func(p *CertPool) string {
   212  		return string(bytes.Join(p.Subjects(), []byte("\n")))
   213  	}
   214  
   215  	if !certPoolEqual(gotPool, wantPool) {
   216  		g, w := strCertPool(gotPool), strCertPool(wantPool)
   217  		t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
   218  	}
   219  }
   220  
   221  func TestReadUniqueDirectoryEntries(t *testing.T) {
   222  	tmp := t.TempDir()
   223  	temp := func(base string) string { return filepath.Join(tmp, base) }
   224  	if f, err := os.Create(temp("file")); err != nil {
   225  		t.Fatal(err)
   226  	} else {
   227  		err := f.Close()
   228  		if err != nil {
   229  			t.Fatal(err)
   230  		}
   231  	}
   232  	if err := os.Symlink("target-in", temp("link-in")); err != nil {
   233  		t.Fatal(err)
   234  	}
   235  	if err := os.Symlink("../target-out", temp("link-out")); err != nil {
   236  		t.Fatal(err)
   237  	}
   238  	got, err := readUniqueDirectoryEntries(tmp)
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  	gotNames := []string{}
   243  	for _, fi := range got {
   244  		gotNames = append(gotNames, fi.Name())
   245  	}
   246  	wantNames := []string{"file", "link-out"}
   247  	if !reflect.DeepEqual(gotNames, wantNames) {
   248  		t.Errorf("got %q; want %q", gotNames, wantNames)
   249  	}
   250  }