github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/subshell/cmd/env.go (about) 1 package cmd 2 3 import ( 4 "log" 5 "os" 6 "path/filepath" 7 "strings" 8 9 "github.com/ActiveState/cli/internal/locale" 10 "github.com/ActiveState/cli/internal/logging" 11 "github.com/ActiveState/cli/internal/osutils" 12 ) 13 14 type OpenKeyFn func(path string) (osutils.RegistryKey, error) 15 16 type CmdEnv struct { 17 openKeyFn OpenKeyFn 18 // whether this updates the system environment 19 userScope bool 20 } 21 22 func NewCmdEnv(userScope bool) *CmdEnv { 23 openKeyFn := osutils.OpenSystemKey 24 if userScope { 25 openKeyFn = osutils.OpenUserKey 26 } 27 logging.Debug("Opening registry, userScope: %v", userScope) 28 return &CmdEnv{ 29 openKeyFn: openKeyFn, 30 userScope: userScope, 31 } 32 } 33 34 func getEnvironmentPath(userScope bool) string { 35 if userScope { 36 return "Environment" 37 } 38 return `SYSTEM\ControlSet001\Control\Session Manager\Environment` 39 } 40 41 // unsetUserEnv clears a state cool configured environment variable 42 // It only does this if the value equals the expected value (meaning if we can verify that state tool was in fact 43 // responsible for setting it) 44 func (c *CmdEnv) unset(keyName, oldValue string) error { 45 envPath := getEnvironmentPath(c.userScope) 46 key, err := c.openKeyFn(envPath) 47 if err != nil { 48 return locale.WrapError(err, "err_windows_registry") 49 } 50 defer key.Close() 51 52 logging.Debug("Unsetting key %s in %s", keyName, envPath) 53 54 v, _, err := key.GetStringValue(keyName) 55 if err != nil { 56 if osutils.IsNotExistError(err) { 57 return nil 58 } 59 return locale.WrapError(err, "err_windows_registry") 60 } 61 62 // Special handling if the key is PATH 63 if keyName == "PATH" { 64 updatedPath := cleanPath(v, oldValue) 65 return key.SetExpandStringValue(keyName, updatedPath) 66 } 67 68 // Check if we are responsible for the value and delete if so 69 if v == oldValue { 70 logging.Debug("Removing environment key %s", keyName) 71 return key.DeleteValue(keyName) 72 } 73 74 return nil 75 } 76 77 func cleanPath(keyValue, oldEntry string) string { 78 oldEntries := make(map[string]bool) 79 for _, entry := range strings.Split(oldEntry, string(os.PathListSeparator)) { 80 oldEntries[filepath.Clean(entry)] = true 81 } 82 83 var newValue []string 84 for _, entry := range strings.Split(keyValue, string(os.PathListSeparator)) { 85 if oldEntries[filepath.Clean(entry)] { 86 logging.Debug("Dropping path entry: %s", entry) 87 continue 88 } 89 newValue = append(newValue, entry) 90 } 91 return strings.Join(newValue, string(os.PathListSeparator)) 92 } 93 94 // Set sets a variable in the user environment and saves the original as a backup 95 func (c *CmdEnv) Set(name, newValue string) error { 96 key, err := c.openKeyFn(getEnvironmentPath(c.userScope)) 97 if err != nil { 98 return locale.WrapError(err, "err_windows_registry") 99 } 100 defer key.Close() 101 102 // Check if we're going to be overriding 103 _, valType, err := key.GetStringValue(name) 104 if err != nil && !osutils.IsNotExistError(err) { 105 return locale.WrapError(err, "err_windows_registry") 106 } 107 108 return osutils.SetStringValue(key, name, valType, newValue) 109 } 110 111 // Get retrieves a variable from the user environment, this prioritizes a backup if it exists 112 func (c *CmdEnv) Get(name string) (string, error) { 113 key, err := c.openKeyFn(getEnvironmentPath(c.userScope)) 114 if err != nil { 115 return "", locale.WrapError(err, "err_windows_registry") 116 } 117 defer key.Close() 118 119 v, _, err := key.GetStringValue(name) 120 if err != nil && !osutils.IsNotExistError(err) { 121 return v, locale.WrapError(err, "err_windows_registry") 122 } 123 return v, nil 124 } 125 126 // GetUnsafe is an alias for `get` intended for use by tests/integration tests, don't use for anything else! 127 func (c *CmdEnv) GetUnsafe(name string) string { 128 r, f := c.Get(name) 129 if f != nil { 130 log.Fatalf("GetUnsafe failed with: %s", f.Error()) 131 } 132 return r 133 }