github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/subshell/cmd/env.go (about)

     1  package cmd
     2  
     3  import (
     4  	"log"
     5  	"os"
     6  	"path/filepath"
     7  	"strings"
     8  
     9  	"github.com/ActiveState/cli/internal/locale"
    10  	"github.com/ActiveState/cli/internal/logging"
    11  	"github.com/ActiveState/cli/internal/osutils"
    12  )
    13  
    14  type OpenKeyFn func(path string) (osutils.RegistryKey, error)
    15  
    16  type CmdEnv struct {
    17  	openKeyFn OpenKeyFn
    18  	// whether this updates the system environment
    19  	userScope bool
    20  }
    21  
    22  func NewCmdEnv(userScope bool) *CmdEnv {
    23  	openKeyFn := osutils.OpenSystemKey
    24  	if userScope {
    25  		openKeyFn = osutils.OpenUserKey
    26  	}
    27  	logging.Debug("Opening registry, userScope: %v", userScope)
    28  	return &CmdEnv{
    29  		openKeyFn: openKeyFn,
    30  		userScope: userScope,
    31  	}
    32  }
    33  
    34  func getEnvironmentPath(userScope bool) string {
    35  	if userScope {
    36  		return "Environment"
    37  	}
    38  	return `SYSTEM\ControlSet001\Control\Session Manager\Environment`
    39  }
    40  
    41  // unsetUserEnv clears a state cool configured environment variable
    42  // It only does this if the value equals the expected value (meaning if we can verify that state tool was in fact
    43  // responsible for setting it)
    44  func (c *CmdEnv) unset(keyName, oldValue string) error {
    45  	envPath := getEnvironmentPath(c.userScope)
    46  	key, err := c.openKeyFn(envPath)
    47  	if err != nil {
    48  		return locale.WrapError(err, "err_windows_registry")
    49  	}
    50  	defer key.Close()
    51  
    52  	logging.Debug("Unsetting key %s in %s", keyName, envPath)
    53  
    54  	v, _, err := key.GetStringValue(keyName)
    55  	if err != nil {
    56  		if osutils.IsNotExistError(err) {
    57  			return nil
    58  		}
    59  		return locale.WrapError(err, "err_windows_registry")
    60  	}
    61  
    62  	// Special handling if the key is PATH
    63  	if keyName == "PATH" {
    64  		updatedPath := cleanPath(v, oldValue)
    65  		return key.SetExpandStringValue(keyName, updatedPath)
    66  	}
    67  
    68  	// Check if we are responsible for the value and delete if so
    69  	if v == oldValue {
    70  		logging.Debug("Removing environment key %s", keyName)
    71  		return key.DeleteValue(keyName)
    72  	}
    73  
    74  	return nil
    75  }
    76  
    77  func cleanPath(keyValue, oldEntry string) string {
    78  	oldEntries := make(map[string]bool)
    79  	for _, entry := range strings.Split(oldEntry, string(os.PathListSeparator)) {
    80  		oldEntries[filepath.Clean(entry)] = true
    81  	}
    82  
    83  	var newValue []string
    84  	for _, entry := range strings.Split(keyValue, string(os.PathListSeparator)) {
    85  		if oldEntries[filepath.Clean(entry)] {
    86  			logging.Debug("Dropping path entry: %s", entry)
    87  			continue
    88  		}
    89  		newValue = append(newValue, entry)
    90  	}
    91  	return strings.Join(newValue, string(os.PathListSeparator))
    92  }
    93  
    94  // Set sets a variable in the user environment and saves the original as a backup
    95  func (c *CmdEnv) Set(name, newValue string) error {
    96  	key, err := c.openKeyFn(getEnvironmentPath(c.userScope))
    97  	if err != nil {
    98  		return locale.WrapError(err, "err_windows_registry")
    99  	}
   100  	defer key.Close()
   101  
   102  	// Check if we're going to be overriding
   103  	_, valType, err := key.GetStringValue(name)
   104  	if err != nil && !osutils.IsNotExistError(err) {
   105  		return locale.WrapError(err, "err_windows_registry")
   106  	}
   107  
   108  	return osutils.SetStringValue(key, name, valType, newValue)
   109  }
   110  
   111  // Get retrieves a variable from the user environment, this prioritizes a backup if it exists
   112  func (c *CmdEnv) Get(name string) (string, error) {
   113  	key, err := c.openKeyFn(getEnvironmentPath(c.userScope))
   114  	if err != nil {
   115  		return "", locale.WrapError(err, "err_windows_registry")
   116  	}
   117  	defer key.Close()
   118  
   119  	v, _, err := key.GetStringValue(name)
   120  	if err != nil && !osutils.IsNotExistError(err) {
   121  		return v, locale.WrapError(err, "err_windows_registry")
   122  	}
   123  	return v, nil
   124  }
   125  
   126  // GetUnsafe is an alias for `get` intended for use by tests/integration tests, don't use for anything else!
   127  func (c *CmdEnv) GetUnsafe(name string) string {
   128  	r, f := c.Get(name)
   129  	if f != nil {
   130  		log.Fatalf("GetUnsafe failed with: %s", f.Error())
   131  	}
   132  	return r
   133  }