github.com/visualfc/goembed@v0.3.3/embed.go (about)

     1  package goembed
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/printer"
     8  	"go/token"
     9  	"sort"
    10  	"strings"
    11  
    12  	embedparser "github.com/visualfc/goembed/parser"
    13  )
    14  
// Kind classifies the declared type of a variable annotated with go:embed.
// The classification is syntactic: it is derived from the type expression in
// the AST (see embedKind), not from type checking.
type Kind int

const (
	// EmbedUnknown is the zero value; embedKind returns it for any type
	// expression it does not recognize as a possible embed target.
	EmbedUnknown Kind = iota
	// EmbedBytes marks a variable declared as []byte.
	EmbedBytes
	// EmbedString marks a variable declared as string.
	EmbedString
	// EmbedFiles marks a variable declared as <import>.FS, or as a bare FS
	// when the import name is ".".
	EmbedFiles
	// EmbedMaybeAlias marks a variable declared with some other identifier,
	// or a slice of some other identifier, which may turn out to be an alias
	// of string or []byte.
	EmbedMaybeAlias // may be alias string or []byte
)
    25  
// Embed describes one go:embed directive group together with the variable
// declaration it was matched to.
type Embed struct {
	Name     string         // name of the annotated variable; filled in by findEmbed
	Kind     Kind           // syntactic classification of the variable's declared type
	Patterns []string       // patterns merged from the same line or directly following lines
	Pos      token.Position // position of the last pattern merged into this group
	Spec     *ast.ValueSpec // matched var spec; stays nil if no declaration was found
}
    34  
    35  // embedPos is go:embed start postion
    36  func (e *Embed) embedPos() (pos token.Position) {
    37  	pos = e.Pos
    38  	pos.Column -= 9
    39  	return
    40  }
    41  
// embedPattern is a single occurrence of a go:embed pattern: one pattern
// string paired with one of the positions recorded for it.
type embedPattern struct {
	Patterns string         // the pattern text (a single pattern, despite the plural name)
	Pos      token.Position // where this occurrence appears in the source
}
    46  
    47  // CheckEmbed lookup go:embed vars for embedPatternPos
    48  func CheckEmbed(embedPatternPos map[string][]token.Position, fset *token.FileSet, files []*ast.File) ([]*Embed, error) {
    49  	if len(embedPatternPos) == 0 {
    50  		return nil, nil
    51  	}
    52  	fmap := make(map[string]bool)
    53  	var ep []*embedPattern
    54  	for k, v := range embedPatternPos {
    55  		for _, pos := range v {
    56  			fmap[pos.Filename] = true
    57  			ep = append(ep, &embedPattern{k, pos})
    58  		}
    59  	}
    60  	sort.SliceStable(ep, func(i, j int) bool {
    61  		n := strings.Compare(ep[i].Pos.Filename, ep[j].Pos.Filename)
    62  		if n == 0 {
    63  			return ep[i].Pos.Offset < ep[j].Pos.Offset
    64  		}
    65  		return n < 0
    66  	})
    67  	var eps []*Embed
    68  	last := &Embed{Patterns: []string{ep[0].Patterns}, Pos: ep[0].Pos}
    69  	eps = append(eps, last)
    70  	for i := 1; i < len(ep); i++ {
    71  		e := ep[i]
    72  		if e.Pos.Filename == last.Pos.Filename &&
    73  			(e.Pos.Line == last.Pos.Line || e.Pos.Line == last.Pos.Line+1) {
    74  			last.Patterns = append(last.Patterns, e.Patterns)
    75  			last.Pos = e.Pos
    76  		} else {
    77  			last = &Embed{Patterns: []string{e.Patterns}, Pos: e.Pos}
    78  			eps = append(eps, last)
    79  		}
    80  	}
    81  	for _, file := range files {
    82  		if fmap[fset.Position(file.Package).Filename] {
    83  			err := findEmbed(fset, file, eps)
    84  			if err != nil {
    85  				return nil, err
    86  			}
    87  		}
    88  	}
    89  	for _, e := range eps {
    90  		if e.Spec == nil {
    91  			return nil, fmt.Errorf("%v: misplaced go:embed directive", e.embedPos())
    92  		}
    93  	}
    94  	return eps, nil
    95  }
    96  
    97  func checkIdent(v ast.Expr, name string) bool {
    98  	if ident, ok := v.(*ast.Ident); ok && ident.Name == name {
    99  		return true
   100  	}
   101  	return false
   102  }
   103  
   104  func embedKind(typ ast.Expr, importName string) Kind {
   105  	switch v := typ.(type) {
   106  	case *ast.Ident:
   107  		switch v.Name {
   108  		case "string":
   109  			return EmbedString
   110  		case "FS":
   111  			if importName == "." {
   112  				return EmbedFiles
   113  			}
   114  		}
   115  		return EmbedMaybeAlias
   116  	case *ast.ArrayType:
   117  		if v.Len != nil {
   118  			break
   119  		}
   120  		if ident, ok := v.Elt.(*ast.Ident); ok {
   121  			if ident.Name == "byte" {
   122  				return EmbedBytes
   123  			}
   124  			return EmbedMaybeAlias
   125  		}
   126  	case *ast.SelectorExpr:
   127  		if checkIdent(v.X, importName) && checkIdent(v.Sel, "FS") {
   128  			return EmbedFiles
   129  		}
   130  	}
   131  	return EmbedUnknown
   132  }
   133  
   134  func findEmbed(fset *token.FileSet, file *ast.File, eps []*Embed) error {
   135  	importName, err := embedparser.FindEmbedImportName(file)
   136  	if err != nil {
   137  		return err
   138  	}
   139  	for _, decl := range file.Decls {
   140  		if d, ok := decl.(*ast.GenDecl); ok && d.Tok == token.VAR {
   141  			for _, spec := range d.Specs {
   142  				vs, ok := spec.(*ast.ValueSpec)
   143  				if !ok {
   144  					continue
   145  				}
   146  				name := vs.Names[0]
   147  				pos := fset.Position(name.NamePos)
   148  				for _, e := range eps {
   149  					if pos.Filename == e.Pos.Filename &&
   150  						pos.Line == e.Pos.Line+1 {
   151  						if len(vs.Names) != 1 {
   152  							return fmt.Errorf("%v: go:embed cannot apply to multiple vars", e.embedPos())
   153  						}
   154  						if len(vs.Values) > 0 {
   155  							return fmt.Errorf("%v: go:embed cannot apply to var with initializer", e.embedPos())
   156  						}
   157  						kind := embedKind(vs.Type, importName)
   158  						if kind == EmbedUnknown {
   159  							var buf bytes.Buffer
   160  							printer.Fprint(&buf, fset, vs.Type)
   161  							return fmt.Errorf("%v: go:embed cannot apply to var of type %v", pos, buf.String())
   162  						}
   163  						e.Name = name.Name
   164  						e.Kind = kind
   165  						e.Spec = vs
   166  					}
   167  				}
   168  			}
   169  		}
   170  	}
   171  	return nil
   172  }