github.com/psrajat/prototool@v1.3.0/internal/protoc/downloader.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package protoc
    22  
    23  import (
    24  	"archive/zip"
    25  	"bytes"
    26  	"crypto/sha512"
    27  	"encoding/base64"
    28  	"fmt"
    29  	"io/ioutil"
    30  	"net/http"
    31  	"os"
    32  	"os/exec"
    33  	"path/filepath"
    34  	"runtime"
    35  	"strings"
    36  	"sync"
    37  
    38  	"github.com/uber/prototool/internal/file"
    39  	"github.com/uber/prototool/internal/settings"
    40  	"github.com/uber/prototool/internal/vars"
    41  	"go.uber.org/multierr"
    42  	"go.uber.org/zap"
    43  )
    44  
    45  type downloader struct {
    46  	lock sync.RWMutex
    47  
    48  	logger    *zap.Logger
    49  	cachePath string
    50  	protocURL string
    51  	config    settings.Config
    52  
    53  	// the looked-up and verified to exist base path
    54  	cachedBasePath string
    55  
    56  	// If set, Prototool will invoke protoc and include
    57  	// the well-known-types, from the configured binPath
    58  	// and wktPath.
    59  	protocBinPath string
    60  	protocWKTPath string
    61  }
    62  
    63  func newDownloader(config settings.Config, options ...DownloaderOption) (*downloader, error) {
    64  	downloader := &downloader{
    65  		config: config,
    66  		logger: zap.NewNop(),
    67  	}
    68  	for _, option := range options {
    69  		option(downloader)
    70  	}
    71  	if downloader.config.Compile.ProtobufVersion == "" {
    72  		downloader.config.Compile.ProtobufVersion = vars.DefaultProtocVersion
    73  	}
    74  	if downloader.protocBinPath != "" || downloader.protocWKTPath != "" {
    75  		if downloader.protocURL != "" {
    76  			return nil, fmt.Errorf("cannot use protoc-url in combination with either protoc-bin-path or protoc-wkt-path")
    77  		}
    78  		if downloader.protocBinPath == "" || downloader.protocWKTPath == "" {
    79  			return nil, fmt.Errorf("both protoc-bin-path and protoc-wkt-path must be set")
    80  		}
    81  		cleanBinPath := filepath.Clean(downloader.protocBinPath)
    82  		if _, err := os.Stat(cleanBinPath); os.IsNotExist(err) {
    83  			return nil, err
    84  		}
    85  		cleanWKTPath := filepath.Clean(downloader.protocWKTPath)
    86  		if _, err := os.Stat(cleanWKTPath); os.IsNotExist(err) {
    87  			return nil, err
    88  		}
    89  		protobufPath := filepath.Join(cleanWKTPath, "google", "protobuf")
    90  		info, err := os.Stat(protobufPath)
    91  		if os.IsNotExist(err) {
    92  			return nil, err
    93  		}
    94  		if !info.IsDir() {
    95  			return nil, fmt.Errorf("%q is not a valid well-known types directory", protobufPath)
    96  		}
    97  		downloader.protocBinPath = cleanBinPath
    98  		downloader.protocWKTPath = cleanWKTPath
    99  	}
   100  	return downloader, nil
   101  }
   102  
   103  func (d *downloader) Download() (string, error) {
   104  	d.lock.RLock()
   105  	cachedBasePath := d.cachedBasePath
   106  	d.lock.RUnlock()
   107  	if cachedBasePath != "" {
   108  		return cachedBasePath, nil
   109  	}
   110  	return d.cache()
   111  }
   112  
   113  func (d *downloader) ProtocPath() (string, error) {
   114  	if d.protocBinPath != "" {
   115  		return d.protocBinPath, nil
   116  	}
   117  	basePath, err := d.Download()
   118  	if err != nil {
   119  		return "", err
   120  	}
   121  	return filepath.Join(basePath, "bin", "protoc"), nil
   122  }
   123  
   124  func (d *downloader) WellKnownTypesIncludePath() (string, error) {
   125  	if d.protocWKTPath != "" {
   126  		return d.protocWKTPath, nil
   127  	}
   128  	basePath, err := d.Download()
   129  	if err != nil {
   130  		return "", err
   131  	}
   132  	return filepath.Join(basePath, "include"), nil
   133  }
   134  
   135  func (d *downloader) Delete() error {
   136  	basePath, err := d.getBasePathNoVersion()
   137  	if err != nil {
   138  		return err
   139  	}
   140  	d.cachedBasePath = ""
   141  	d.logger.Debug("deleting", zap.String("path", basePath))
   142  	return os.RemoveAll(basePath)
   143  }
   144  
   145  func (d *downloader) cache() (string, error) {
   146  	if d.protocBinPath != "" {
   147  		return d.protocBinPath, nil
   148  	}
   149  
   150  	d.lock.Lock()
   151  	defer d.lock.Unlock()
   152  
   153  	basePath, err := d.getBasePath()
   154  	if err != nil {
   155  		return "", err
   156  	}
   157  	if err := d.checkDownloaded(basePath); err != nil {
   158  		if err := d.download(basePath); err != nil {
   159  			return "", err
   160  		}
   161  		if err := d.checkDownloaded(basePath); err != nil {
   162  			return "", err
   163  		}
   164  		d.logger.Debug("protobuf downloaded", zap.String("path", basePath))
   165  	} else {
   166  		d.logger.Debug("protobuf already downloaded", zap.String("path", basePath))
   167  	}
   168  
   169  	d.cachedBasePath = basePath
   170  	return basePath, nil
   171  }
   172  
   173  func (d *downloader) checkDownloaded(basePath string) error {
   174  	buffer := bytes.NewBuffer(nil)
   175  	cmd := exec.Command(filepath.Join(basePath, "bin", "protoc"), "--version")
   176  	cmd.Stdout = buffer
   177  	if err := cmd.Run(); err != nil {
   178  		return err
   179  	}
   180  	if d.protocURL != "" {
   181  		// skip version check since we do not know the version
   182  		return nil
   183  	}
   184  	output := strings.TrimSpace(buffer.String())
   185  	d.logger.Debug("output from protoc --version", zap.String("output", output))
   186  	expected := fmt.Sprintf("libprotoc %s", d.config.Compile.ProtobufVersion)
   187  	if output != expected {
   188  		return fmt.Errorf("expected %s from protoc --version, got %s", expected, output)
   189  	}
   190  	return nil
   191  }
   192  
   193  func (d *downloader) download(basePath string) (retErr error) {
   194  	return d.downloadInternal(basePath, runtime.GOOS, runtime.GOARCH)
   195  }
   196  
   197  func (d *downloader) downloadInternal(basePath string, goos string, goarch string) (retErr error) {
   198  	data, err := d.getDownloadData(goos, goarch)
   199  	if err != nil {
   200  		return err
   201  	}
   202  	// this is a working but hacky unzip
   203  	// there must be a library for this
   204  	// we don't properly copy directories, modification times, etc
   205  	readerAt := bytes.NewReader(data)
   206  	zipReader, err := zip.NewReader(readerAt, int64(len(data)))
   207  	if err != nil {
   208  		return err
   209  	}
   210  	for _, file := range zipReader.File {
   211  		fileMode := file.Mode()
   212  		d.logger.Debug("found protobuf file in zip", zap.String("fileName", file.Name), zap.Any("fileMode", fileMode))
   213  		if fileMode.IsDir() {
   214  			continue
   215  		}
   216  		readCloser, err := file.Open()
   217  		if err != nil {
   218  			return err
   219  		}
   220  		defer func() {
   221  			retErr = multierr.Append(retErr, readCloser.Close())
   222  		}()
   223  		fileData, err := ioutil.ReadAll(readCloser)
   224  		if err != nil {
   225  			return err
   226  		}
   227  		writeFilePath := filepath.Join(basePath, file.Name)
   228  		if err := os.MkdirAll(filepath.Dir(writeFilePath), 0755); err != nil {
   229  			return err
   230  		}
   231  		writeFile, err := os.OpenFile(writeFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode)
   232  		if err != nil {
   233  			return err
   234  		}
   235  		defer func() {
   236  			retErr = multierr.Append(retErr, writeFile.Close())
   237  		}()
   238  		if _, err := writeFile.Write(fileData); err != nil {
   239  			return err
   240  		}
   241  		d.logger.Debug("wrote protobuf file", zap.String("path", writeFilePath))
   242  	}
   243  	return nil
   244  }
   245  
   246  func (d *downloader) getDownloadData(goos string, goarch string) (_ []byte, retErr error) {
   247  	url, err := d.getProtocURL(goos, goarch)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	defer func() {
   252  		if retErr == nil {
   253  			d.logger.Debug("downloaded protobuf zip file", zap.String("url", url))
   254  		}
   255  	}()
   256  
   257  	switch {
   258  	case strings.HasPrefix(url, "file://"):
   259  		return ioutil.ReadFile(strings.TrimPrefix(url, "file://"))
   260  	case strings.HasPrefix(url, "http://"), strings.HasPrefix(url, "https://"):
   261  		response, err := http.Get(url)
   262  		if err != nil || response.StatusCode != http.StatusOK {
   263  			// if there is not given protocURL, we tried to
   264  			// download this from GitHub Releases, so add
   265  			// extra context to the error message
   266  			if d.protocURL == "" {
   267  				return nil, fmt.Errorf("error downloading %s: %v\nMake sure GitHub Releases has a proper protoc zip file of the form protoc-VERSION-OS-ARCH.zip at https://github.com/protocolbuffers/protobuf/releases/v%s\nNote that many micro versions do not have this, and no version before 3.0.0-beta-2 has this", url, err, d.config.Compile.ProtobufVersion)
   268  			}
   269  			return nil, err
   270  		}
   271  		defer func() {
   272  			if response.Body != nil {
   273  				retErr = multierr.Append(retErr, response.Body.Close())
   274  			}
   275  		}()
   276  		return ioutil.ReadAll(response.Body)
   277  	default:
   278  		return nil, fmt.Errorf("unknown url, can only handle http, https, file: %s", url)
   279  	}
   280  
   281  }
   282  
   283  func (d *downloader) getProtocURL(goos string, goarch string) (string, error) {
   284  	if d.protocURL != "" {
   285  		return d.protocURL, nil
   286  	}
   287  	_, unameM, err := getUnameSUnameMPaths(goos, goarch)
   288  	if err != nil {
   289  		return "", err
   290  	}
   291  	protocS, err := getProtocSPath(goos)
   292  	if err != nil {
   293  		return "", err
   294  	}
   295  	return fmt.Sprintf(
   296  		"https://github.com/protocolbuffers/protobuf/releases/download/v%s/protoc-%s-%s-%s.zip",
   297  		d.config.Compile.ProtobufVersion,
   298  		d.config.Compile.ProtobufVersion,
   299  		protocS,
   300  		unameM,
   301  	), nil
   302  }
   303  
   304  func (d *downloader) getBasePath() (string, error) {
   305  	basePathNoVersion, err := d.getBasePathNoVersion()
   306  	if err != nil {
   307  		return "", err
   308  	}
   309  	return filepath.Join(basePathNoVersion, d.getBasePathVersionPart()), nil
   310  }
   311  
   312  func (d *downloader) getBasePathNoVersion() (string, error) {
   313  	basePath := d.cachePath
   314  	var err error
   315  	if basePath == "" {
   316  		basePath, err = getDefaultBasePath()
   317  		if err != nil {
   318  			return "", err
   319  		}
   320  	} else {
   321  		basePath, err = file.AbsClean(basePath)
   322  		if err != nil {
   323  			return "", err
   324  		}
   325  	}
   326  	if err := file.CheckAbs(basePath); err != nil {
   327  		return "", err
   328  	}
   329  	return filepath.Join(basePath, "protobuf"), nil
   330  }
   331  
   332  func (d *downloader) getBasePathVersionPart() string {
   333  	if d.protocURL != "" {
   334  		// we don't know the version or what is going on here
   335  		hash := sha512.New()
   336  		_, _ = hash.Write([]byte(d.protocURL))
   337  		return base64.URLEncoding.EncodeToString(hash.Sum(nil))
   338  	}
   339  	return d.config.Compile.ProtobufVersion
   340  }
   341  
   342  func getDefaultBasePath() (string, error) {
   343  	return getDefaultBasePathInternal(runtime.GOOS, runtime.GOARCH, os.Getenv)
   344  }
   345  
   346  func getDefaultBasePathInternal(goos string, goarch string, getenvFunc func(string) string) (string, error) {
   347  	unameS, unameM, err := getUnameSUnameMPaths(goos, goarch)
   348  	if err != nil {
   349  		return "", err
   350  	}
   351  	xdgCacheHome := getenvFunc("XDG_CACHE_HOME")
   352  	if xdgCacheHome != "" {
   353  		return filepath.Join(xdgCacheHome, "prototool", unameS, unameM), nil
   354  	}
   355  	home := getenvFunc("HOME")
   356  	if home == "" {
   357  		return "", fmt.Errorf("HOME is not set")
   358  	}
   359  	switch unameS {
   360  	case "Darwin":
   361  		return filepath.Join(home, "Library", "Caches", "prototool", unameS, unameM), nil
   362  	case "Linux":
   363  		return filepath.Join(home, ".cache", "prototool", unameS, unameM), nil
   364  	default:
   365  		return "", fmt.Errorf("invalid value for uname -s: %v", unameS)
   366  	}
   367  }
   368  
   369  func getProtocSPath(goos string) (string, error) {
   370  	switch goos {
   371  	case "darwin":
   372  		return "osx", nil
   373  	case "linux":
   374  		return "linux", nil
   375  	default:
   376  		return "", fmt.Errorf("unsupported value for runtime.GOOS: %v", goos)
   377  	}
   378  }
   379  
   380  func getUnameSUnameMPaths(goos string, goarch string) (string, string, error) {
   381  	var unameS string
   382  	switch goos {
   383  	case "darwin":
   384  		unameS = "Darwin"
   385  	case "linux":
   386  		unameS = "Linux"
   387  	default:
   388  		return "", "", fmt.Errorf("unsupported value for runtime.GOOS: %v", goos)
   389  	}
   390  	var unameM string
   391  	switch goarch {
   392  	case "amd64":
   393  		unameM = "x86_64"
   394  	default:
   395  		return "", "", fmt.Errorf("unsupported value for runtime.GOARCH: %v", goarch)
   396  	}
   397  	return unameS, unameM, nil
   398  }