github.com/jingweno/gh@v2.1.1-0.20221007190738-04a7985fa9a1+incompatible/commands/updater.go (about)

     1  package commands
     2  
     3  import (
     4  	"archive/zip"
     5  	"fmt"
     6  	goupdate "github.com/inconshreveable/go-update"
     7  	"github.com/jingweno/gh/git"
     8  	"github.com/jingweno/gh/github"
     9  	"github.com/jingweno/gh/utils"
    10  	"io"
    11  	"io/ioutil"
    12  	"math/rand"
    13  	"net/http"
    14  	"os"
    15  	"path/filepath"
    16  	"runtime"
    17  	"strings"
    18  	"time"
    19  )
    20  
    21  const ghAutoUpdateConfig = "gh.autoUpdate"
    22  
    23  func NewUpdater() *Updater {
    24  	version := os.Getenv("GH_VERSION")
    25  	if version == "" {
    26  		version = Version
    27  	}
    28  
    29  	timestampPath := filepath.Join(os.Getenv("HOME"), ".config", "gh-update")
    30  	return &Updater{
    31  		Host:           github.DefaultHost(),
    32  		CurrentVersion: version,
    33  		timestampPath:  timestampPath,
    34  	}
    35  }
    36  
    37  type Updater struct {
    38  	Host           string
    39  	CurrentVersion string
    40  	timestampPath  string
    41  }
    42  
    43  func (updater *Updater) timeToUpdate() bool {
    44  	if updater.CurrentVersion == "dev" || readTime(updater.timestampPath).After(time.Now()) {
    45  		return false
    46  	}
    47  
    48  	// the next update is in about 14 days
    49  	wait := 13*24*time.Hour + randDuration(24*time.Hour)
    50  	return writeTime(updater.timestampPath, time.Now().Add(wait))
    51  }
    52  
    53  func (updater *Updater) PromptForUpdate() (err error) {
    54  	if !updater.timeToUpdate() {
    55  		return
    56  	}
    57  
    58  	releaseName, version := updater.latestReleaseNameAndVersion()
    59  	if version != "" && version != updater.CurrentVersion {
    60  		switch autoUpdateConfig() {
    61  		case "always":
    62  			err = updater.updateTo(releaseName, version)
    63  		case "never":
    64  			return
    65  		default:
    66  			fmt.Println("There is a newer version of gh available.")
    67  			fmt.Print("Would you like to update? ([Y]es/[N]o/[A]lways/N[e]ver): ")
    68  			var confirm string
    69  			fmt.Scan(&confirm)
    70  
    71  			always := utils.IsOption(confirm, "a", "always")
    72  			if always || utils.IsOption(confirm, "y", "yes") {
    73  				err = updater.updateTo(releaseName, version)
    74  			}
    75  
    76  			saveAutoUpdateConfiguration(confirm, always)
    77  		}
    78  	}
    79  
    80  	return
    81  }
    82  
    83  func (updater *Updater) Update() (err error) {
    84  	releaseName, version := updater.latestReleaseNameAndVersion()
    85  	if version == "" {
    86  		fmt.Println("There is no newer version of gh available.")
    87  		return
    88  	}
    89  
    90  	if version == updater.CurrentVersion {
    91  		fmt.Printf("You're already on the latest version: %s\n", version)
    92  	} else {
    93  		err = updater.updateTo(releaseName, version)
    94  	}
    95  
    96  	return
    97  }
    98  
    99  func (updater *Updater) latestReleaseNameAndVersion() (name, version string) {
   100  	// Create Client with a stub Credentials
   101  	c := github.Client{Credentials: &github.Credentials{Host: updater.Host}}
   102  	name, _ = c.GhLatestTagName()
   103  	version = strings.TrimPrefix(name, "v")
   104  
   105  	return
   106  }
   107  
   108  func (updater *Updater) updateTo(releaseName, version string) (err error) {
   109  	fmt.Printf("Updating gh to %s...\n", version)
   110  	downloadURL := fmt.Sprintf("https://%s/jingweno/gh/releases/download/%s/gh_%s_%s_%s.zip", updater.Host, releaseName, version, runtime.GOOS, runtime.GOARCH)
   111  	path, err := downloadFile(downloadURL)
   112  	if err != nil {
   113  		return
   114  	}
   115  
   116  	exec, err := unzipExecutable(path)
   117  	if err != nil {
   118  		return
   119  	}
   120  
   121  	err, _ = goupdate.FromFile(exec)
   122  	if err == nil {
   123  		fmt.Println("Done!")
   124  	}
   125  
   126  	return
   127  }
   128  
   129  func unzipExecutable(path string) (exec string, err error) {
   130  	rc, err := zip.OpenReader(path)
   131  	if err != nil {
   132  		err = fmt.Errorf("Can't open zip file %s: %s", path, err)
   133  		return
   134  	}
   135  	defer rc.Close()
   136  
   137  	for _, file := range rc.File {
   138  		if !strings.HasPrefix(file.Name, "gh") {
   139  			continue
   140  		}
   141  
   142  		dir := filepath.Dir(path)
   143  		exec, err = unzipFile(file, dir)
   144  		break
   145  	}
   146  
   147  	if exec == "" && err == nil {
   148  		err = fmt.Errorf("No gh executable is found in %s", path)
   149  	}
   150  
   151  	return
   152  }
   153  
   154  func unzipFile(file *zip.File, to string) (exec string, err error) {
   155  	frc, err := file.Open()
   156  	if err != nil {
   157  		err = fmt.Errorf("Can't open zip entry %s when reading: %s", file.Name, err)
   158  		return
   159  	}
   160  	defer frc.Close()
   161  
   162  	dest := filepath.Join(to, filepath.Base(file.Name))
   163  	f, err := os.Create(dest)
   164  	if err != nil {
   165  		return
   166  	}
   167  	defer f.Close()
   168  
   169  	copied, err := io.Copy(f, frc)
   170  	if err != nil {
   171  		return
   172  	}
   173  
   174  	if uint32(copied) != file.UncompressedSize {
   175  		err = fmt.Errorf("Zip entry %s is corrupted", file.Name)
   176  		return
   177  	}
   178  
   179  	exec = f.Name()
   180  
   181  	return
   182  }
   183  
   184  func downloadFile(url string) (path string, err error) {
   185  	dir, err := ioutil.TempDir("", "gh-update")
   186  	if err != nil {
   187  		return
   188  	}
   189  
   190  	resp, err := http.Get(url)
   191  	if err != nil {
   192  		return
   193  	}
   194  	defer resp.Body.Close()
   195  
   196  	if resp.StatusCode >= 300 || resp.StatusCode < 200 {
   197  		err = fmt.Errorf("Can't download %s: %d", url, resp.StatusCode)
   198  		return
   199  	}
   200  
   201  	file, err := os.Create(filepath.Join(dir, filepath.Base(url)))
   202  	if err != nil {
   203  		return
   204  	}
   205  	defer file.Close()
   206  
   207  	_, err = io.Copy(file, resp.Body)
   208  	if err != nil {
   209  		return
   210  	}
   211  
   212  	path = file.Name()
   213  
   214  	return
   215  }
   216  
   217  func randDuration(n time.Duration) time.Duration {
   218  	return time.Duration(rand.Int63n(int64(n)))
   219  }
   220  
   221  func readTime(path string) time.Time {
   222  	p, err := ioutil.ReadFile(path)
   223  	if os.IsNotExist(err) {
   224  		return time.Time{}
   225  	}
   226  	if err != nil {
   227  		return time.Now().Add(1000 * time.Hour)
   228  	}
   229  
   230  	t, err := time.Parse(time.RFC3339, strings.TrimSpace(string(p)))
   231  	if err != nil {
   232  		return time.Time{}
   233  	}
   234  
   235  	return t
   236  }
   237  
   238  func writeTime(path string, t time.Time) bool {
   239  	return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil
   240  }
   241  
   242  func saveAutoUpdateConfiguration(confirm string, always bool) {
   243  	if always {
   244  		git.SetGlobalConfig(ghAutoUpdateConfig, "always")
   245  	} else if utils.IsOption(confirm, "e", "never") {
   246  		git.SetGlobalConfig(ghAutoUpdateConfig, "never")
   247  	}
   248  }
   249  
   250  func autoUpdateConfig() (opt string) {
   251  	opt = os.Getenv("GH_AUTOUPDATE")
   252  	if opt == "" {
   253  		opt, _ = git.GlobalConfig(ghAutoUpdateConfig)
   254  	}
   255  
   256  	return
   257  }