github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/utils/zip/zip.go (about)

     1  // Copyright 2011-2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package zip
     5  
     6  import (
     7  	"archive/zip"
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"strings"
    15  
    16      "launchpad.net/juju-core/utils"
    17  )
    18  
    19  // FindAll returns the cleaned path of every file in the supplied zip reader.
    20  func FindAll(reader *zip.Reader) ([]string, error) {
    21  	return Find(reader, "*")
    22  }
    23  
    24  // Find returns the cleaned path of every file in the supplied zip reader whose
    25  // base name matches the supplied pattern, which is interpreted as in path.Match.
    26  func Find(reader *zip.Reader, pattern string) ([]string, error) {
    27  	// path.Match will only return an error if the pattern is not
    28  	// valid (*and* the supplied name is not empty, hence "check").
    29  	if _, err := path.Match(pattern, "check"); err != nil {
    30  		return nil, err
    31  	}
    32  	var matches []string
    33  	for _, zipFile := range reader.File {
    34  		cleanPath := path.Clean(zipFile.Name)
    35  		baseName := path.Base(cleanPath)
    36  		if match, _ := path.Match(pattern, baseName); match {
    37  			matches = append(matches, cleanPath)
    38  		}
    39  	}
    40  	return matches, nil
    41  }
    42  
    43  // ExtractAll extracts the supplied zip reader to the target path, overwriting
    44  // existing files and directories only where necessary.
    45  func ExtractAll(reader *zip.Reader, targetRoot string) error {
    46  	return Extract(reader, targetRoot, "")
    47  }
    48  
    49  // Extract extracts files from the supplied zip reader, from the (internal, slash-
    50  // separated) source path into the (external, OS-specific) target path. If the
    51  // source path does not reference a directory, the referenced file will be written
    52  // directly to the target path.
    53  func Extract(reader *zip.Reader, targetRoot, sourceRoot string) error {
    54  	sourceRoot = path.Clean(sourceRoot)
    55  	if sourceRoot == "." {
    56  		sourceRoot = ""
    57  	}
    58  	if !isSanePath(sourceRoot) {
    59  		return fmt.Errorf("cannot extract files rooted at %q", sourceRoot)
    60  	}
    61  	extractor := extractor{targetRoot, sourceRoot}
    62  	for _, zipFile := range reader.File {
    63  		if err := extractor.extract(zipFile); err != nil {
    64  			cleanName := path.Clean(zipFile.Name)
    65  			return fmt.Errorf("cannot extract %q: %v", cleanName, err)
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  type extractor struct {
    72  	targetRoot string
    73  	sourceRoot string
    74  }
    75  
    76  // targetPath returns the target path for a given zip file and whether
    77  // it should be extracted.
    78  func (x extractor) targetPath(zipFile *zip.File) (string, bool) {
    79  	cleanPath := path.Clean(zipFile.Name)
    80  	if cleanPath == x.sourceRoot {
    81  		return x.targetRoot, true
    82  	}
    83  	if x.sourceRoot != "" {
    84  		mustPrefix := x.sourceRoot + "/"
    85  		if !strings.HasPrefix(cleanPath, mustPrefix) {
    86  			return "", false
    87  		}
    88  		cleanPath = cleanPath[len(mustPrefix):]
    89  	}
    90  	return filepath.Join(x.targetRoot, filepath.FromSlash(cleanPath)), true
    91  }
    92  
    93  func (x extractor) extract(zipFile *zip.File) error {
    94  	targetPath, ok := x.targetPath(zipFile)
    95  	if !ok {
    96  		return nil
    97  	}
    98  	parentPath := filepath.Dir(targetPath)
    99  	if err := os.MkdirAll(parentPath, 0777); err != nil {
   100  		return err
   101  	}
   102  	mode := zipFile.Mode()
   103  	modePerm := mode & os.ModePerm
   104  	modeType := mode & os.ModeType
   105  	switch modeType {
   106  	case os.ModeDir:
   107  		return x.writeDir(targetPath, modePerm)
   108  	case os.ModeSymlink:
   109  		return x.writeSymlink(targetPath, zipFile)
   110  	case 0:
   111  		return x.writeFile(targetPath, zipFile, modePerm)
   112  	}
   113  	return fmt.Errorf("unknown file type %d", modeType)
   114  }
   115  
   116  func (x extractor) writeDir(targetPath string, modePerm os.FileMode) error {
   117  	fileInfo, err := os.Lstat(targetPath)
   118  	switch {
   119  	case err == nil:
   120  		mode := fileInfo.Mode()
   121  		if mode.IsDir() {
   122  			if mode&os.ModePerm != modePerm {
   123  				return os.Chmod(targetPath, modePerm)
   124  			}
   125  			return nil
   126  		}
   127  		fallthrough
   128  	case !os.IsNotExist(err):
   129  		if err := os.RemoveAll(targetPath); err != nil {
   130  			return err
   131  		}
   132  	}
   133  	return os.MkdirAll(targetPath, modePerm)
   134  }
   135  
   136  func (x extractor) writeFile(targetPath string, zipFile *zip.File, modePerm os.FileMode) error {
   137  	if _, err := os.Lstat(targetPath); !os.IsNotExist(err) {
   138  		if err := os.RemoveAll(targetPath); err != nil {
   139  			return err
   140  		}
   141  	}
   142  	writer, err := os.OpenFile(targetPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, modePerm)
   143  	if err != nil {
   144  		return err
   145  	}
   146  	defer writer.Close()
   147  	return copyTo(writer, zipFile)
   148  }
   149  
   150  func (x extractor) writeSymlink(targetPath string, zipFile *zip.File) error {
   151  	symlinkTarget, err := x.checkSymlink(targetPath, zipFile)
   152  	if err != nil {
   153  		return err
   154  	}
   155  	if _, err := os.Lstat(targetPath); !os.IsNotExist(err) {
   156  		if err := os.RemoveAll(targetPath); err != nil {
   157  			return err
   158  		}
   159  	}
   160  	return utils.Symlink(symlinkTarget, targetPath)
   161  }
   162  
   163  func (x extractor) checkSymlink(targetPath string, zipFile *zip.File) (string, error) {
   164  	var buffer bytes.Buffer
   165  	if err := copyTo(&buffer, zipFile); err != nil {
   166  		return "", err
   167  	}
   168  	symlinkTarget := buffer.String()
   169  	if filepath.IsAbs(symlinkTarget) {
   170  		return "", fmt.Errorf("symlink %q is absolute", symlinkTarget)
   171  	}
   172  	finalPath := filepath.Join(filepath.Dir(targetPath), symlinkTarget)
   173  	relativePath, err := filepath.Rel(x.targetRoot, finalPath)
   174  	if err != nil {
   175  		// Not tested, because I don't know how to trigger this condition.
   176  		return "", fmt.Errorf("symlink %q not comprehensible", symlinkTarget)
   177  	}
   178  	if !isSanePath(relativePath) {
   179  		return "", fmt.Errorf("symlink %q leads out of scope", symlinkTarget)
   180  	}
   181  	return symlinkTarget, nil
   182  }
   183  
   184  func copyTo(writer io.Writer, zipFile *zip.File) error {
   185  	reader, err := zipFile.Open()
   186  	if err != nil {
   187  		return err
   188  	}
   189  	_, err = io.Copy(writer, reader)
   190  	reader.Close()
   191  	return err
   192  }
   193  
   194  func isSanePath(path string) bool {
   195  	if path == ".." || strings.HasPrefix(path, "../") {
   196  		return false
   197  	}
   198  	return true
   199  }