github.com/0xKiwi/rules_go@v0.24.3/go/tools/builders/embed.go (about)

     1  // Copyright 2017 The Bazel Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // embed generates a .go file from the contents of a list of data files. It is
    16  // invoked by go_embed_data as an action.
    17  package main
    18  
    19  import (
    20  	"archive/tar"
    21  	"archive/zip"
    22  	"bufio"
    23  	"errors"
    24  	"flag"
    25  	"fmt"
    26  	"io"
    27  	"log"
    28  	"os"
    29  	"path"
    30  	"path/filepath"
    31  	"strconv"
    32  	"strings"
    33  	"text/template"
    34  	"unicode/utf8"
    35  )
    36  
    37  var headerTpl = template.Must(template.New("embed").Parse(`// Generated by go_embed_data for {{.Label}}. DO NOT EDIT.
    38  
    39  package {{.Package}}
    40  
    41  `))
    42  
    43  var multiFooterTpl = template.Must(template.New("embed").Parse(`
    44  var {{.Var}} = map[string]{{.Type}}{
    45  {{- range $i, $f := .FoundSources}}
    46  	{{$.Key $f}}: {{$.Var}}_{{$i}},
    47  {{- end}}
    48  }
    49  
    50  `))
    51  
    52  func main() {
    53  	log.SetPrefix("embed: ")
    54  	log.SetFlags(0) // don't print timestamps
    55  	if err := run(os.Args); err != nil {
    56  		log.Fatal(err)
    57  	}
    58  }
    59  
    60  type configuration struct {
    61  	Label, Package, Var      string
    62  	Multi                    bool
    63  	sources                  []string
    64  	FoundSources             []string
    65  	out, workspace           string
    66  	flatten, unpack, strData bool
    67  }
    68  
    69  func (c *configuration) Type() string {
    70  	if c.strData {
    71  		return "string"
    72  	} else {
    73  		return "[]byte"
    74  	}
    75  }
    76  
    77  func (c *configuration) Key(filename string) string {
    78  	workspacePrefix := "external/" + c.workspace + "/"
    79  	key := filepath.FromSlash(strings.TrimPrefix(filename, workspacePrefix))
    80  	if c.flatten {
    81  		key = path.Base(filename)
    82  	}
    83  	return strconv.Quote(key)
    84  }
    85  
    86  func run(args []string) error {
    87  	c, err := newConfiguration(args)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	f, err := os.Create(c.out)
    93  	if err != nil {
    94  		return err
    95  	}
    96  	defer f.Close()
    97  	w := bufio.NewWriter(f)
    98  	defer w.Flush()
    99  
   100  	if err := headerTpl.Execute(w, c); err != nil {
   101  		return err
   102  	}
   103  
   104  	if c.Multi {
   105  		return embedMultipleFiles(c, w)
   106  	}
   107  	return embedSingleFile(c, w)
   108  }
   109  
   110  func newConfiguration(args []string) (*configuration, error) {
   111  	var c configuration
   112  	flags := flag.NewFlagSet("embed", flag.ExitOnError)
   113  	flags.StringVar(&c.Label, "label", "", "Label of the rule being executed (required)")
   114  	flags.StringVar(&c.Package, "package", "", "Go package name (required)")
   115  	flags.StringVar(&c.Var, "var", "", "Variable name (required)")
   116  	flags.BoolVar(&c.Multi, "multi", false, "Whether the variable is a map or a single value")
   117  	flags.StringVar(&c.out, "out", "", "Go file to generate (required)")
   118  	flags.StringVar(&c.workspace, "workspace", "", "Name of the workspace (required)")
   119  	flags.BoolVar(&c.flatten, "flatten", false, "Whether to access files by base name")
   120  	flags.BoolVar(&c.strData, "string", false, "Whether to store contents as strings")
   121  	flags.BoolVar(&c.unpack, "unpack", false, "Whether to treat files as archives to unpack.")
   122  	flags.Parse(args[1:])
   123  	if c.Label == "" {
   124  		return nil, errors.New("error: -label option not provided")
   125  	}
   126  	if c.Package == "" {
   127  		return nil, errors.New("error: -package option not provided")
   128  	}
   129  	if c.Var == "" {
   130  		return nil, errors.New("error: -var option not provided")
   131  	}
   132  	if c.out == "" {
   133  		return nil, errors.New("error: -out option not provided")
   134  	}
   135  	if c.workspace == "" {
   136  		return nil, errors.New("error: -workspace option not provided")
   137  	}
   138  	c.sources = flags.Args()
   139  	if !c.Multi && len(c.sources) != 1 {
   140  		return nil, fmt.Errorf("error: -multi flag not given, so want exactly one source; got %d", len(c.sources))
   141  	}
   142  	if c.unpack {
   143  		if !c.Multi {
   144  			return nil, errors.New("error: -multi flag is required for -unpack mode.")
   145  		}
   146  		for _, src := range c.sources {
   147  			if ext := filepath.Ext(src); ext != ".zip" && ext != ".tar" {
   148  				return nil, fmt.Errorf("error: -unpack flag expects .zip or .tar extension (got %q)", ext)
   149  			}
   150  		}
   151  	}
   152  	return &c, nil
   153  }
   154  
   155  func embedSingleFile(c *configuration, w io.Writer) error {
   156  	dataBegin, dataEnd := "\"", "\"\n"
   157  	if !c.strData {
   158  		dataBegin, dataEnd = "[]byte(\"", "\")\n"
   159  	}
   160  
   161  	if _, err := fmt.Fprintf(w, "var %s = %s", c.Var, dataBegin); err != nil {
   162  		return err
   163  	}
   164  	if err := embedFileContents(w, c.sources[0]); err != nil {
   165  		return err
   166  	}
   167  	_, err := fmt.Fprint(w, dataEnd)
   168  	return err
   169  }
   170  
   171  func embedMultipleFiles(c *configuration, w io.Writer) error {
   172  	dataBegin, dataEnd := "\"", "\"\n"
   173  	if !c.strData {
   174  		dataBegin, dataEnd = "[]byte(\"", "\")\n"
   175  	}
   176  
   177  	if _, err := fmt.Fprint(w, "var (\n"); err != nil {
   178  		return err
   179  	}
   180  	if err := findSources(c, func(i int, f io.Reader) error {
   181  		if _, err := fmt.Fprintf(w, "\t%s_%d = %s", c.Var, i, dataBegin); err != nil {
   182  			return err
   183  		}
   184  		if _, err := io.Copy(&escapeWriter{w}, f); err != nil {
   185  			return err
   186  		}
   187  		if _, err := fmt.Fprint(w, dataEnd); err != nil {
   188  			return err
   189  		}
   190  		return nil
   191  	}); err != nil {
   192  		return err
   193  	}
   194  	if _, err := fmt.Fprint(w, ")\n"); err != nil {
   195  		return err
   196  	}
   197  	if err := multiFooterTpl.Execute(w, c); err != nil {
   198  		return err
   199  	}
   200  	return nil
   201  }
   202  
   203  func findSources(c *configuration, cb func(i int, f io.Reader) error) error {
   204  	if c.unpack {
   205  		for _, filename := range c.sources {
   206  			ext := filepath.Ext(filename)
   207  			if ext == ".zip" {
   208  				if err := findZipSources(c, filename, cb); err != nil {
   209  					return err
   210  				}
   211  			} else if ext == ".tar" {
   212  				if err := findTarSources(c, filename, cb); err != nil {
   213  					return err
   214  				}
   215  			} else {
   216  				panic("unknown archive extension: " + ext)
   217  			}
   218  		}
   219  		return nil
   220  	}
   221  	for _, filename := range c.sources {
   222  		f, err := os.Open(filename)
   223  		if err != nil {
   224  			return err
   225  		}
   226  		err = cb(len(c.FoundSources), bufio.NewReader(f))
   227  		f.Close()
   228  		if err != nil {
   229  			return err
   230  		}
   231  		c.FoundSources = append(c.FoundSources, filename)
   232  	}
   233  	return nil
   234  }
   235  
   236  func findZipSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error {
   237  	r, err := zip.OpenReader(filename)
   238  	if err != nil {
   239  		return err
   240  	}
   241  	defer r.Close()
   242  	for _, file := range r.File {
   243  		f, err := file.Open()
   244  		if err != nil {
   245  			return err
   246  		}
   247  		err = cb(len(c.FoundSources), f)
   248  		f.Close()
   249  		if err != nil {
   250  			return err
   251  		}
   252  		c.FoundSources = append(c.FoundSources, file.Name)
   253  	}
   254  	return nil
   255  }
   256  
   257  func findTarSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error {
   258  	tf, err := os.Open(filename)
   259  	if err != nil {
   260  		return err
   261  	}
   262  	defer tf.Close()
   263  	reader := tar.NewReader(bufio.NewReader(tf))
   264  	for {
   265  		h, err := reader.Next()
   266  		if err == io.EOF {
   267  			return nil
   268  		}
   269  		if err != nil {
   270  			return err
   271  		}
   272  		if h.Typeflag != tar.TypeReg {
   273  			continue
   274  		}
   275  		if err := cb(len(c.FoundSources), &io.LimitedReader{
   276  			R: reader,
   277  			N: h.Size,
   278  		}); err != nil {
   279  			return err
   280  		}
   281  		c.FoundSources = append(c.FoundSources, h.Name)
   282  	}
   283  }
   284  
   285  func embedFileContents(w io.Writer, filename string) error {
   286  	f, err := os.Open(filename)
   287  	if err != nil {
   288  		return err
   289  	}
   290  	defer f.Close()
   291  
   292  	_, err = io.Copy(&escapeWriter{w}, bufio.NewReader(f))
   293  	return err
   294  }
   295  
   296  type escapeWriter struct {
   297  	w io.Writer
   298  }
   299  
   300  func (w *escapeWriter) Write(data []byte) (n int, err error) {
   301  	n = len(data)
   302  
   303  	for err == nil && len(data) > 0 {
   304  		// https://golang.org/ref/spec#String_literals: "Within the quotes, any
   305  		// character may appear except newline and unescaped double quote. The
   306  		// text between the quotes forms the value of the literal, with backslash
   307  		// escapes interpreted as they are in rune literals […]."
   308  		switch b := data[0]; b {
   309  		case '\\':
   310  			_, err = w.w.Write([]byte(`\\`))
   311  		case '"':
   312  			_, err = w.w.Write([]byte(`\"`))
   313  		case '\n':
   314  			_, err = w.w.Write([]byte(`\n`))
   315  
   316  		case '\x00':
   317  			// https://golang.org/ref/spec#Source_code_representation: "Implementation
   318  			// restriction: For compatibility with other tools, a compiler may
   319  			// disallow the NUL character (U+0000) in the source text."
   320  			_, err = w.w.Write([]byte(`\x00`))
   321  
   322  		default:
   323  			// https://golang.org/ref/spec#Source_code_representation: "Implementation
   324  			// restriction: […] A byte order mark may be disallowed anywhere else in
   325  			// the source."
   326  			const byteOrderMark = '\uFEFF'
   327  
   328  			if r, size := utf8.DecodeRune(data); r != utf8.RuneError && r != byteOrderMark {
   329  				_, err = w.w.Write(data[:size])
   330  				data = data[size:]
   331  				continue
   332  			}
   333  
   334  			_, err = fmt.Fprintf(w.w, `\x%02x`, b)
   335  		}
   336  		data = data[1:]
   337  	}
   338  
   339  	return n - len(data), err
   340  }