github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/uzip/uzip.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package uzip contains utilities for file system->zip and zip->file system conversions.
     6  package uzip
     7  
     8  import (
     9  	"archive/zip"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"os"
    14  	"path/filepath"
    15  
    16  	"github.com/mvdan/u-root-coreutils/pkg/upath"
    17  )
    18  
    19  // ToZip packs the all files at dir to a zip archive at dest.
    20  func ToZip(dir, dest, comment string) (reterr error) {
    21  	if info, err := os.Stat(dir); err != nil {
    22  		return err
    23  	} else if !info.IsDir() {
    24  		return fmt.Errorf("%s is not a directory", dir)
    25  	}
    26  	archive, err := os.Create(dest)
    27  	if err != nil {
    28  		return err
    29  	}
    30  	defer func() {
    31  		if err := archive.Close(); err != nil && reterr == nil {
    32  			reterr = err
    33  		}
    34  	}()
    35  
    36  	z := zip.NewWriter(archive)
    37  	defer func() {
    38  		if comment != "" {
    39  			z.SetComment(comment)
    40  		}
    41  		if err := z.Close(); err != nil && reterr == nil {
    42  			reterr = err
    43  		}
    44  	}()
    45  
    46  	return writeDir(dir, z)
    47  }
    48  
    49  // AppendZip packs the all files at dir to a zip archive at dest.
    50  func AppendZip(dir, dest, comment string) (reterr error) {
    51  	if info, err := os.Stat(dir); err != nil {
    52  		return err
    53  	} else if !info.IsDir() {
    54  		return fmt.Errorf("%s is not a directory", dir)
    55  	}
    56  
    57  	archive, err := os.OpenFile(dest, os.O_RDWR, 0)
    58  	if err != nil {
    59  		return err
    60  	}
    61  	defer func() {
    62  		if err := archive.Close(); err != nil && reterr == nil {
    63  			reterr = err
    64  		}
    65  	}()
    66  
    67  	// Go to the end of the file because we are appending.
    68  	end, err := archive.Seek(0, os.SEEK_END)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	z := zip.NewWriter(archive)
    74  	z.SetOffset(end)
    75  	defer func() {
    76  		if comment != "" {
    77  			z.SetComment(comment)
    78  		}
    79  		if err := z.Close(); err != nil && reterr == nil {
    80  			reterr = err
    81  		}
    82  	}()
    83  
    84  	return writeDir(dir, z)
    85  }
    86  
    87  func writeDir(dir string, z *zip.Writer) error {
    88  	return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
    89  		if err != nil {
    90  			return err
    91  		}
    92  
    93  		// do not include srcDir into archive
    94  		if info.Name() == filepath.Base(dir) {
    95  			return nil
    96  		}
    97  
    98  		header, err := zip.FileInfoHeader(info)
    99  		if err != nil {
   100  			return err
   101  		}
   102  
   103  		// adjust header.Name to preserve folder structure
   104  		header.Name, err = filepath.Rel(dir, path)
   105  		if err != nil {
   106  			return err
   107  		}
   108  
   109  		if info.IsDir() {
   110  			header.Name += "/"
   111  		} else {
   112  			header.Method = zip.Deflate
   113  		}
   114  
   115  		writer, err := z.CreateHeader(header)
   116  		if err != nil {
   117  			return err
   118  		}
   119  
   120  		if info.IsDir() {
   121  			return nil
   122  		}
   123  
   124  		file, err := os.Open(path)
   125  		if err != nil {
   126  			return err
   127  		}
   128  		defer file.Close()
   129  		_, err = io.Copy(writer, file)
   130  		return err
   131  	})
   132  }
   133  
   134  // Comment returns the comment from the zip file.
   135  func Comment(file string) (string, error) {
   136  	z, err := zip.OpenReader(file)
   137  	if err != nil {
   138  		return "", err
   139  	}
   140  	defer z.Close()
   141  	return z.Comment, nil
   142  }
   143  
   144  // FromZip extracts the zip archive at src to dir.
   145  func FromZip(src, dir string) error {
   146  	z, err := zip.OpenReader(src)
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	if err = os.MkdirAll(dir, 0o755); err != nil {
   152  		return err
   153  	}
   154  
   155  	for _, file := range z.File {
   156  		path, err := upath.SafeFilepathJoin(dir, file.Name)
   157  		if err != nil {
   158  			// The behavior is to skip files which are unsafe due to
   159  			// zipslip, but continue extracting everything else.
   160  			log.Printf("Warning: Skipping file %q due to: %v", file.Name, err)
   161  			continue
   162  		}
   163  
   164  		if file.FileInfo().IsDir() {
   165  			if err = os.MkdirAll(path, file.Mode()); err != nil {
   166  				return err
   167  			}
   168  			continue
   169  		}
   170  
   171  		fileReader, err := file.Open()
   172  		if err != nil {
   173  			return err
   174  		}
   175  
   176  		targetFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode())
   177  		if err != nil {
   178  			return err
   179  		}
   180  
   181  		if _, err := io.Copy(targetFile, fileReader); err != nil {
   182  			return err
   183  		}
   184  
   185  		if err = fileReader.Close(); err != nil {
   186  			return err
   187  		}
   188  
   189  		if err = targetFile.Close(); err != nil {
   190  			return err
   191  		}
   192  	}
   193  
   194  	return nil
   195  }