blake.io/pqx@v0.2.2-0.20231231055241-83f2254c0a07/internal/fetch/fetch.go (about)

     1  package fetch
     2  
     3  import (
     4  	"archive/tar"
     5  	"archive/zip"
     6  	"bytes"
     7  	"context"
     8  	"errors"
     9  	"io"
    10  	"net/http"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"runtime"
    15  	"strings"
    16  
    17  	"github.com/xi2/xz"
    18  	"kr.dev/errorfmt"
    19  )
    20  
    21  var envCacheDir = os.Getenv("PQX_BIN_DIR")
    22  
    23  func BinaryURL(version string) string {
    24  	const fetchURLTempl = "https://repo1.maven.org/maven2/io/zonky/test/postgres/embedded-postgres-binaries-$OS-$ARCH/$VERSION/embedded-postgres-binaries-$OS-$ARCH-$VERSION.jar"
    25  
    26  	// TODO(bmizerany): validate version
    27  	return strings.NewReplacer(
    28  		"$OS", getOS(),
    29  		"$ARCH", getArch(),
    30  		"$VERSION", version,
    31  	).Replace(fetchURLTempl)
    32  }
    33  
    34  func pgDir(version string) string {
    35  	cacheDir := envCacheDir
    36  	if cacheDir == "" {
    37  		cacheDir = os.Getenv("HOME")
    38  		if cacheDir == "" {
    39  			cacheDir, _ = os.UserHomeDir()
    40  		}
    41  		if cacheDir == "" {
    42  			panic("no HOME; try setting PQX_BIN_DIR instead")
    43  		}
    44  	}
    45  	return filepath.Join(cacheDir, ".cache/pqx", version)
    46  }
    47  
    48  func Binary(ctx context.Context, version string) (binDir string, err error) {
    49  	defer errorfmt.Handlef("fetchBinary: %w", &err)
    50  
    51  	dir := pgDir(version)
    52  	if err := os.MkdirAll(dir, 0755); err != nil {
    53  		return "", err
    54  	}
    55  
    56  	binDir = path.Join(dir, "bin")
    57  	_, err = os.Stat(binDir)
    58  	if err == nil {
    59  		// already cached
    60  		// TODO(bmizerany): validate the dir has what we think it has?
    61  		return binDir, nil
    62  	}
    63  
    64  	binURL := BinaryURL(version)
    65  	defer errorfmt.Handlef("%s: %w", binURL, &err)
    66  
    67  	req, err := http.NewRequestWithContext(ctx, "GET", binURL, nil)
    68  	if err != nil {
    69  		return "", err
    70  	}
    71  	res, err := http.DefaultClient.Do(req)
    72  	if err != nil {
    73  		return "", err
    74  	}
    75  	defer res.Body.Close()
    76  
    77  	if err := extractJar(ctx, dir, res.Body); err != nil {
    78  		return "", err
    79  	}
    80  
    81  	return binDir, nil
    82  }
    83  
    84  func extractJar(ctx context.Context, dir string, r io.Reader) (err error) {
    85  	defer errorfmt.Handlef("extractJar: %w", &err)
    86  
    87  	buf, size, err := slurp(r)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	zr, err := zip.NewReader(buf, size)
    93  	if err != nil {
    94  		return err
    95  	}
    96  
    97  	for _, f := range zr.File {
    98  		matched, err := path.Match("postgres-*.txz", f.Name)
    99  		if err != nil {
   100  			return err
   101  		}
   102  		if !matched {
   103  			continue
   104  		}
   105  
   106  		o, err := f.Open()
   107  		if err != nil {
   108  			return err
   109  		}
   110  		defer o.Close()
   111  		return extractTxn(ctx, dir, o)
   112  	}
   113  
   114  	return errors.New("no postgres-*.txz found in archive")
   115  }
   116  
   117  func extractTxn(ctx context.Context, dir string, r io.Reader) (err error) {
   118  	defer errorfmt.Handlef("extractTxn: %w", &err)
   119  
   120  	xr, err := xz.NewReader(r, 0)
   121  	if err != nil {
   122  		return err
   123  	}
   124  	tr := tar.NewReader(xr)
   125  	for {
   126  		h, err := tr.Next()
   127  		if err == io.EOF {
   128  			break
   129  		}
   130  		if err != nil {
   131  			return err
   132  		}
   133  
   134  		name := filepath.Join(dir, h.Name)
   135  		if err := os.Mkdir(filepath.Dir(name), 0755); err != nil {
   136  			if !os.IsExist(err) {
   137  				return err
   138  			}
   139  		}
   140  
   141  		switch h.Typeflag {
   142  		case tar.TypeReg:
   143  			f, err := os.OpenFile(name, os.O_CREATE|os.O_RDWR, os.FileMode(h.Mode))
   144  			if err != nil {
   145  				return err
   146  			}
   147  			if _, err := io.Copy(f, tr); err != nil {
   148  				f.Close()
   149  				return err
   150  			}
   151  			if err := f.Close(); err != nil {
   152  				return err
   153  			}
   154  		case tar.TypeSymlink:
   155  			if err := os.RemoveAll(name); err != nil {
   156  				return err
   157  			}
   158  			if err := os.Symlink(h.Linkname, name); err != nil {
   159  				return err
   160  			}
   161  		}
   162  	}
   163  	return nil
   164  }
   165  
   166  func getOS() string {
   167  	goos := runtime.GOOS
   168  	_, err := os.Stat("/etc/alpine-release")
   169  	if os.IsExist(err) {
   170  		return goos + "-alpine"
   171  	}
   172  	return goos
   173  }
   174  
   175  // TODO(bmizerany): Add support for 32bit machines?
   176  var archLookup = map[string]string{
   177  	"amd":   "amd64",
   178  	"arm64": "arm64v8",
   179  	"ppc64": "ppc64le",
   180  }
   181  
   182  func getArch() string {
   183  	goarch := runtime.GOARCH
   184  	if arch := archLookup[goarch]; arch != "" {
   185  		return arch
   186  	}
   187  	return runtime.GOARCH
   188  }
   189  
   190  func slurp(r io.Reader) (*bytes.Reader, int64, error) {
   191  	data, err := io.ReadAll(r)
   192  	if err != nil {
   193  		return nil, 0, err
   194  	}
   195  	return bytes.NewReader(data), int64(len(data)), nil
   196  }