github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/cmd/getgo/path.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !plan9
     6  // +build !plan9
     7  
     8  package main
     9  
    10  import (
    11  	"bufio"
    12  	"context"
    13  	"fmt"
    14  	"os"
    15  	"os/user"
    16  	"path/filepath"
    17  	"runtime"
    18  	"strings"
    19  )
    20  
    21  const (
    22  	bashConfig = ".bash_profile"
    23  	zshConfig  = ".zshrc"
    24  )
    25  
    26  // appendToPATH adds the given path to the PATH environment variable and
    27  // persists it for future sessions.
    28  func appendToPATH(value string) error {
    29  	if isInPATH(value) {
    30  		return nil
    31  	}
    32  	return persistEnvVar("PATH", pathVar+envSeparator+value)
    33  }
    34  
    35  func isInPATH(dir string) bool {
    36  	p := os.Getenv("PATH")
    37  
    38  	paths := strings.Split(p, envSeparator)
    39  	for _, d := range paths {
    40  		if d == dir {
    41  			return true
    42  		}
    43  	}
    44  
    45  	return false
    46  }
    47  
    48  func getHomeDir() (string, error) {
    49  	home := os.Getenv(homeKey)
    50  	if home != "" {
    51  		return home, nil
    52  	}
    53  
    54  	u, err := user.Current()
    55  	if err != nil {
    56  		return "", err
    57  	}
    58  	return u.HomeDir, nil
    59  }
    60  
    61  func checkStringExistsFile(filename, value string) (bool, error) {
    62  	file, err := os.OpenFile(filename, os.O_RDONLY, 0600)
    63  	if err != nil {
    64  		if os.IsNotExist(err) {
    65  			return false, nil
    66  		}
    67  		return false, err
    68  	}
    69  	defer file.Close()
    70  
    71  	scanner := bufio.NewScanner(file)
    72  	for scanner.Scan() {
    73  		line := scanner.Text()
    74  		if line == value {
    75  			return true, nil
    76  		}
    77  	}
    78  
    79  	return false, scanner.Err()
    80  }
    81  
    82  func appendToFile(filename, value string) error {
    83  	verbosef("Adding %q to %s", value, filename)
    84  
    85  	ok, err := checkStringExistsFile(filename, value)
    86  	if err != nil {
    87  		return err
    88  	}
    89  	if ok {
    90  		// Nothing to do.
    91  		return nil
    92  	}
    93  
    94  	f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
    95  	if err != nil {
    96  		return err
    97  	}
    98  	defer f.Close()
    99  
   100  	_, err = f.WriteString(lineEnding + value + lineEnding)
   101  	return err
   102  }
   103  
   104  func isShell(name string) bool {
   105  	return strings.Contains(currentShell(), name)
   106  }
   107  
   108  // persistEnvVarWindows sets an environment variable in the Windows
   109  // registry.
   110  func persistEnvVarWindows(name, value string) error {
   111  	_, err := runCommand(context.Background(), "powershell", "-command",
   112  		fmt.Sprintf(`[Environment]::SetEnvironmentVariable("%s", "%s", "User")`, name, value))
   113  	return err
   114  }
   115  
   116  func persistEnvVar(name, value string) error {
   117  	if runtime.GOOS == "windows" {
   118  		if err := persistEnvVarWindows(name, value); err != nil {
   119  			return err
   120  		}
   121  
   122  		if isShell("cmd.exe") || isShell("powershell.exe") {
   123  			return os.Setenv(strings.ToUpper(name), value)
   124  		}
   125  		// User is in bash, zsh, etc.
   126  		// Also set the environment variable in their shell config.
   127  	}
   128  
   129  	rc, err := shellConfigFile()
   130  	if err != nil {
   131  		return err
   132  	}
   133  
   134  	line := fmt.Sprintf("export %s=%s", strings.ToUpper(name), value)
   135  	if err := appendToFile(rc, line); err != nil {
   136  		return err
   137  	}
   138  
   139  	return os.Setenv(strings.ToUpper(name), value)
   140  }
   141  
   142  func shellConfigFile() (string, error) {
   143  	home, err := getHomeDir()
   144  	if err != nil {
   145  		return "", err
   146  	}
   147  
   148  	switch {
   149  	case isShell("bash"):
   150  		return filepath.Join(home, bashConfig), nil
   151  	case isShell("zsh"):
   152  		return filepath.Join(home, zshConfig), nil
   153  	default:
   154  		return "", fmt.Errorf("%q is not a supported shell", currentShell())
   155  	}
   156  }