github.com/april1989/origin-go-tools@v0.0.32/cmd/getgo/path.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build !plan9
     6  
     7  package main
     8  
     9  import (
    10  	"bufio"
    11  	"context"
    12  	"fmt"
    13  	"os"
    14  	"os/user"
    15  	"path/filepath"
    16  	"runtime"
    17  	"strings"
    18  )
    19  
    20  const (
    21  	bashConfig = ".bash_profile"
    22  	zshConfig  = ".zshrc"
    23  )
    24  
    25  // appendToPATH adds the given path to the PATH environment variable and
    26  // persists it for future sessions.
    27  func appendToPATH(value string) error {
    28  	if isInPATH(value) {
    29  		return nil
    30  	}
    31  	return persistEnvVar("PATH", pathVar+envSeparator+value)
    32  }
    33  
    34  func isInPATH(dir string) bool {
    35  	p := os.Getenv("PATH")
    36  
    37  	paths := strings.Split(p, envSeparator)
    38  	for _, d := range paths {
    39  		if d == dir {
    40  			return true
    41  		}
    42  	}
    43  
    44  	return false
    45  }
    46  
    47  func getHomeDir() (string, error) {
    48  	home := os.Getenv(homeKey)
    49  	if home != "" {
    50  		return home, nil
    51  	}
    52  
    53  	u, err := user.Current()
    54  	if err != nil {
    55  		return "", err
    56  	}
    57  	return u.HomeDir, nil
    58  }
    59  
    60  func checkStringExistsFile(filename, value string) (bool, error) {
    61  	file, err := os.OpenFile(filename, os.O_RDONLY, 0600)
    62  	if err != nil {
    63  		if os.IsNotExist(err) {
    64  			return false, nil
    65  		}
    66  		return false, err
    67  	}
    68  	defer file.Close()
    69  
    70  	scanner := bufio.NewScanner(file)
    71  	for scanner.Scan() {
    72  		line := scanner.Text()
    73  		if line == value {
    74  			return true, nil
    75  		}
    76  	}
    77  
    78  	return false, scanner.Err()
    79  }
    80  
    81  func appendToFile(filename, value string) error {
    82  	verbosef("Adding %q to %s", value, filename)
    83  
    84  	ok, err := checkStringExistsFile(filename, value)
    85  	if err != nil {
    86  		return err
    87  	}
    88  	if ok {
    89  		// Nothing to do.
    90  		return nil
    91  	}
    92  
    93  	f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer f.Close()
    98  
    99  	_, err = f.WriteString(lineEnding + value + lineEnding)
   100  	return err
   101  }
   102  
   103  func isShell(name string) bool {
   104  	return strings.Contains(currentShell(), name)
   105  }
   106  
   107  // persistEnvVarWindows sets an environment variable in the Windows
   108  // registry.
   109  func persistEnvVarWindows(name, value string) error {
   110  	_, err := runCommand(context.Background(), "powershell", "-command",
   111  		fmt.Sprintf(`[Environment]::SetEnvironmentVariable("%s", "%s", "User")`, name, value))
   112  	return err
   113  }
   114  
   115  func persistEnvVar(name, value string) error {
   116  	if runtime.GOOS == "windows" {
   117  		if err := persistEnvVarWindows(name, value); err != nil {
   118  			return err
   119  		}
   120  
   121  		if isShell("cmd.exe") || isShell("powershell.exe") {
   122  			return os.Setenv(strings.ToUpper(name), value)
   123  		}
   124  		// User is in bash, zsh, etc.
   125  		// Also set the environment variable in their shell config.
   126  	}
   127  
   128  	rc, err := shellConfigFile()
   129  	if err != nil {
   130  		return err
   131  	}
   132  
   133  	line := fmt.Sprintf("export %s=%s", strings.ToUpper(name), value)
   134  	if err := appendToFile(rc, line); err != nil {
   135  		return err
   136  	}
   137  
   138  	return os.Setenv(strings.ToUpper(name), value)
   139  }
   140  
   141  func shellConfigFile() (string, error) {
   142  	home, err := getHomeDir()
   143  	if err != nil {
   144  		return "", err
   145  	}
   146  
   147  	switch {
   148  	case isShell("bash"):
   149  		return filepath.Join(home, bashConfig), nil
   150  	case isShell("zsh"):
   151  		return filepath.Join(home, zshConfig), nil
   152  	default:
   153  		return "", fmt.Errorf("%q is not a supported shell", currentShell())
   154  	}
   155  }