github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/core/config/configtest/config.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package configtest
     8  
     9  import (
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"os"
    14  	"os/exec"
    15  	"path/filepath"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/spf13/viper"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  // AddDevConfigPath adds the DevConfigDir to the viper path.
    24  func AddDevConfigPath(v *viper.Viper) {
    25  	devPath := GetDevConfigDir()
    26  	if v != nil {
    27  		v.AddConfigPath(devPath)
    28  	} else {
    29  		viper.AddConfigPath(devPath)
    30  	}
    31  }
    32  
    33  func dirExists(path string) bool {
    34  	fi, err := os.Stat(path)
    35  	if err != nil {
    36  		return false
    37  	}
    38  	return fi.IsDir()
    39  }
    40  
    41  // GetDevConfigDir gets the path to the default configuration that is
    42  // maintained with the source tree. This should only be used in a
    43  // test/development context.
    44  func GetDevConfigDir() string {
    45  	path, err := gomodDevConfigDir()
    46  	if err != nil {
    47  		path, err = gopathDevConfigDir()
    48  		if err != nil {
    49  			panic(err)
    50  		}
    51  	}
    52  	return path
    53  }
    54  
    55  func gopathDevConfigDir() (string, error) {
    56  	buf := bytes.NewBuffer(nil)
    57  	cmd := exec.Command("go", "env", "GOPATH")
    58  	cmd.Stdout = buf
    59  	if err := cmd.Run(); err != nil {
    60  		return "", err
    61  	}
    62  
    63  	gopath := strings.TrimSpace(buf.String())
    64  	for _, p := range filepath.SplitList(gopath) {
    65  		devPath := filepath.Join(p, "src/github.com/hechain20/hechain/sampleconfig")
    66  		if dirExists(devPath) {
    67  			return devPath, nil
    68  		}
    69  	}
    70  
    71  	return "", fmt.Errorf("unable to find sampleconfig directory on GOPATH")
    72  }
    73  
    74  func gomodDevConfigDir() (string, error) {
    75  	buf := bytes.NewBuffer(nil)
    76  	cmd := exec.Command("go", "env", "GOMOD")
    77  	cmd.Stdout = buf
    78  
    79  	if err := cmd.Run(); err != nil {
    80  		return "", err
    81  	}
    82  
    83  	modFile := strings.TrimSpace(buf.String())
    84  	if modFile == "" {
    85  		return "", errors.New("not a module or not in module mode")
    86  	}
    87  
    88  	devPath := filepath.Join(filepath.Dir(modFile), "sampleconfig")
    89  	if !dirExists(devPath) {
    90  		return "", fmt.Errorf("%s does not exist", devPath)
    91  	}
    92  
    93  	return devPath, nil
    94  }
    95  
    96  // GetDevMspDir gets the path to the sampleconfig/msp tree that is maintained
    97  // with the source tree.  This should only be used in a test/development
    98  // context.
    99  func GetDevMspDir() string {
   100  	devDir := GetDevConfigDir()
   101  	return filepath.Join(devDir, "msp")
   102  }
   103  
   104  func SetDevFabricConfigPath(t *testing.T) (cleanup func()) {
   105  	t.Helper()
   106  
   107  	oldFabricCfgPath, resetFabricCfgPath := os.LookupEnv("FABRIC_CFG_PATH")
   108  	devConfigDir := GetDevConfigDir()
   109  
   110  	err := os.Setenv("FABRIC_CFG_PATH", devConfigDir)
   111  	require.NoError(t, err, "failed to set FABRIC_CFG_PATH")
   112  	if resetFabricCfgPath {
   113  		return func() {
   114  			err := os.Setenv("FABRIC_CFG_PATH", oldFabricCfgPath)
   115  			require.NoError(t, err)
   116  		}
   117  	}
   118  
   119  	return func() {
   120  		err := os.Unsetenv("FABRIC_CFG_PATH")
   121  		require.NoError(t, err)
   122  	}
   123  }