github.com/philpearl/plenc@v0.0.15/cmd/plenctag/main.go (about)

     1  // plenctag adds plenc tags to your structs
     2  package main
     3  
     4  import (
     5  	"bytes"
     6  	"flag"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/format"
    10  	"go/parser"
    11  	"go/token"
    12  	"os"
    13  	"strconv"
    14  	"unicode"
    15  	"unicode/utf8"
    16  
    17  	"github.com/fatih/structtag"
    18  )
    19  
// config defines how tags should be modified. The boolean options are bound
// to command-line flags in run; fset is per-file state shared between parse,
// rewrite and format.
type config struct {
	// write is bound to the -w flag: when set, format writes the result back
	// to the source file; otherwise the result is printed to stdout.
	write            bool
	// excludeJSONMinus is bound to the -json flag: when set, a field tagged
	// json:"-" that has no plenc tag is given plenc:"-" instead of a number.
	excludeJSONMinus bool
	// excludeSQLMinus is bound to the -sql flag: when set, a field tagged
	// sql:"-" that has no plenc tag is given plenc:"-" instead of a number.
	excludeSQLMinus  bool
	// excludePrivate is bound to the -private flag: when set, rewrite leaves
	// fields whose name starts with a lower-case letter untouched.
	excludePrivate   bool

	// fset is the file set for the file currently being processed. parse
	// replaces it with a fresh set for every file; rewrite uses it to report
	// error positions and format uses it to print the node.
	fset *token.FileSet
}
    29  
    30  func main() {
    31  	if err := run(); err != nil {
    32  		fmt.Fprintln(os.Stderr, err.Error())
    33  		os.Exit(1)
    34  	}
    35  }
    36  
    37  func run() error {
    38  	var cfg config
    39  
    40  	flag.BoolVar(&cfg.write, "w", true, "Write result to (source) file instead of stdout")
    41  	flag.BoolVar(&cfg.excludeJSONMinus, "json", false, "Exclude json:\"-\"")
    42  	flag.BoolVar(&cfg.excludeSQLMinus, "sql", true, "Exclude sql:\"-\"")
    43  	flag.BoolVar(&cfg.excludePrivate, "private", true, "Exclude private fields (starting with lower case letter)")
    44  
    45  	flag.Parse()
    46  
    47  	if flag.NArg() == 0 {
    48  		fmt.Fprintln(os.Stderr, "no files specified")
    49  		flag.Usage()
    50  		os.Exit(1)
    51  	}
    52  
    53  	for _, filename := range flag.Args() {
    54  		node, err := cfg.parse(filename)
    55  		if err != nil {
    56  			return err
    57  		}
    58  
    59  		rewrittenNode, err := cfg.rewrite(node)
    60  		if err != nil {
    61  			return err
    62  		}
    63  
    64  		if err := cfg.format(rewrittenNode, filename); err != nil {
    65  			return err
    66  		}
    67  	}
    68  
    69  	return nil
    70  }
    71  
    72  func (c *config) parse(filename string) (ast.Node, error) {
    73  	c.fset = token.NewFileSet()
    74  	return parser.ParseFile(c.fset, filename, nil, parser.ParseComments)
    75  }
    76  
    77  func (c *config) format(file ast.Node, filename string) error {
    78  	var buf bytes.Buffer
    79  	err := format.Node(&buf, c.fset, file)
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	if c.write {
    85  		err = os.WriteFile(filename, buf.Bytes(), 0)
    86  		if err != nil {
    87  			return err
    88  		}
    89  	} else {
    90  		fmt.Println(buf.String())
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  // rewrite rewrites the node for structs
    97  func (c *config) rewrite(node ast.Node) (ast.Node, error) {
    98  	var errs rewriteErrors
    99  
   100  	recordError := func(f *ast.Field, err error) {
   101  		errs.Append(fmt.Errorf("%s:%d:%d:%s",
   102  			c.fset.Position(f.Pos()).Filename,
   103  			c.fset.Position(f.Pos()).Line,
   104  			c.fset.Position(f.Pos()).Column,
   105  			err))
   106  	}
   107  
   108  	rewriteFunc := func(n ast.Node) bool {
   109  		x, ok := n.(*ast.StructType)
   110  		if !ok {
   111  			return true
   112  		}
   113  
   114  		// We make two passes through the fields. First we find the maximum existing plenc tag value. In the
   115  		// second pass we add plenc tags starting after this max value and skipping fields that match filters
   116  		var maxPlenc int
   117  		for _, f := range x.Fields.List {
   118  			if f.Tag == nil {
   119  				continue
   120  			}
   121  
   122  			pl, err := plencValue(f.Tag.Value)
   123  			if err != nil {
   124  				recordError(f, err)
   125  				continue
   126  			}
   127  
   128  			if pl > maxPlenc {
   129  				maxPlenc = pl
   130  			}
   131  		}
   132  
   133  		// Now we make updates
   134  		for _, f := range x.Fields.List {
   135  			if c.excludePrivate {
   136  				r, _ := utf8.DecodeRuneInString(f.Names[0].Name)
   137  				if unicode.IsLower(r) {
   138  					continue
   139  				}
   140  			}
   141  			if f.Tag == nil {
   142  				f.Tag = &ast.BasicLit{}
   143  			}
   144  
   145  			tags, err := extractTags(f.Tag.Value)
   146  			if err != nil {
   147  				recordError(f, err)
   148  				continue
   149  			}
   150  
   151  			if _, err := tags.Get("plenc"); err == nil {
   152  				// This field has a plenc tag, so we can leave it alone
   153  				continue
   154  			}
   155  
   156  			// No plenc tag. Either we explicitly exclude it `plenc:"-"`, or we give it a number `plenc:"12"`
   157  			tag := structtag.Tag{Key: "plenc"}
   158  			if c.isExcluded(tags) {
   159  				tag.Name = "-"
   160  			} else {
   161  				maxPlenc++
   162  				tag.Name = strconv.Itoa(maxPlenc)
   163  
   164  			}
   165  			tags.Set(&tag)
   166  
   167  			f.Tag.Value = quote(tags.String())
   168  		}
   169  
   170  		return true
   171  	}
   172  
   173  	ast.Inspect(node, rewriteFunc)
   174  
   175  	if errs != nil {
   176  		return node, errs
   177  	}
   178  
   179  	return node, nil
   180  }
   181  
   182  func (c *config) isExcluded(tags *structtag.Tags) bool {
   183  	if c.excludeSQLMinus {
   184  		tag, err := tags.Get("sql")
   185  		if err == nil && tag.Name == "-" {
   186  			return true
   187  		}
   188  	}
   189  	if c.excludeJSONMinus {
   190  		tag, err := tags.Get("json")
   191  		if err == nil && tag.Name == "-" {
   192  			return true
   193  		}
   194  	}
   195  	return false
   196  }
   197  
   198  func extractTags(tag string) (*structtag.Tags, error) {
   199  	if tag == "" {
   200  		return &structtag.Tags{}, nil
   201  	}
   202  	var err error
   203  	tag, err = strconv.Unquote(tag)
   204  	if err != nil {
   205  		return nil, fmt.Errorf("could not unquote tags. %w", err)
   206  	}
   207  
   208  	return structtag.Parse(tag)
   209  }
   210  
   211  func plencValue(tag string) (int, error) {
   212  	tags, err := extractTags(tag)
   213  	if err != nil {
   214  		return 0, err
   215  	}
   216  
   217  	if tags == nil {
   218  		return 0, err
   219  	}
   220  
   221  	tagg, err := tags.Get("plenc")
   222  	if err != nil {
   223  		// Only error is if it isn't present
   224  		return 0, nil
   225  	}
   226  
   227  	if tagg.Name == "-" {
   228  		// explicitly excluded
   229  		return 0, err
   230  	}
   231  
   232  	return strconv.Atoi(tagg.Name)
   233  }
   234  
   235  func quote(tag string) string {
   236  	return "`" + tag + "`"
   237  }
   238  
   239  type rewriteErrors []error
   240  
   241  func (r rewriteErrors) Error() string {
   242  	var buf bytes.Buffer
   243  	for _, e := range r {
   244  		buf.WriteString(fmt.Sprintf("%s\n", e.Error()))
   245  	}
   246  	return buf.String()
   247  }
   248  
   249  func (r *rewriteErrors) Append(err error) {
   250  	if err == nil {
   251  		return
   252  	}
   253  	*r = append(*r, err)
   254  }