github.com/linchen2chris/hugo@v0.0.0-20230307053224-cec209389705/codegen/methods.go (about)

     1  // Copyright 2019 The Hugo Authors. All rights reserved.
     2  // Some functions in this file (see comments) is based on the Go source code,
     3  // copyright The Go Authors and  governed by a BSD-style license.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  // http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  // Package codegen contains helpers for code generation.
    17  package codegen
    18  
    19  import (
    20  	"fmt"
    21  	"go/ast"
    22  	"go/parser"
    23  	"go/token"
    24  	"os"
    25  	"path"
    26  	"path/filepath"
    27  	"reflect"
    28  	"regexp"
    29  	"sort"
    30  	"strings"
    31  	"sync"
    32  )
    33  
    34  // Make room for insertions
    35  const weightWidth = 1000
    36  
    37  // NewInspector creates a new Inspector given a source root.
    38  func NewInspector(root string) *Inspector {
    39  	return &Inspector{ProjectRootDir: root}
    40  }
    41  
    42  // Inspector provides methods to help code generation. It uses a combination
    43  // of reflection and source code AST to do the heavy lifting.
    44  type Inspector struct {
    45  	ProjectRootDir string
    46  
    47  	init sync.Once
    48  
    49  	// Determines method order. Go's reflect sorts lexicographically, so
    50  	// we must parse the source to preserve this order.
    51  	methodWeight map[string]map[string]int
    52  }
    53  
    54  // MethodsFromTypes create a method set from the include slice, excluding any
    55  // method in exclude.
    56  func (c *Inspector) MethodsFromTypes(include []reflect.Type, exclude []reflect.Type) Methods {
    57  	c.parseSource()
    58  
    59  	var methods Methods
    60  
    61  	excludes := make(map[string]bool)
    62  
    63  	if len(exclude) > 0 {
    64  		for _, m := range c.MethodsFromTypes(exclude, nil) {
    65  			excludes[m.Name] = true
    66  		}
    67  	}
    68  
    69  	// There may be overlapping interfaces in types. Do a simple check for now.
    70  	seen := make(map[string]bool)
    71  
    72  	nameAndPackage := func(t reflect.Type) (string, string) {
    73  		var name, pkg string
    74  
    75  		isPointer := t.Kind() == reflect.Ptr
    76  
    77  		if isPointer {
    78  			t = t.Elem()
    79  		}
    80  
    81  		pkgPrefix := ""
    82  		if pkgPath := t.PkgPath(); pkgPath != "" {
    83  			pkgPath = strings.TrimSuffix(pkgPath, "/")
    84  			_, shortPath := path.Split(pkgPath)
    85  			pkgPrefix = shortPath + "."
    86  			pkg = pkgPath
    87  		}
    88  
    89  		name = t.Name()
    90  		if name == "" {
    91  			// interface{}
    92  			name = t.String()
    93  		}
    94  
    95  		if isPointer {
    96  			pkgPrefix = "*" + pkgPrefix
    97  		}
    98  
    99  		name = pkgPrefix + name
   100  
   101  		return name, pkg
   102  	}
   103  
   104  	for _, t := range include {
   105  		for i := 0; i < t.NumMethod(); i++ {
   106  
   107  			m := t.Method(i)
   108  			if excludes[m.Name] || seen[m.Name] {
   109  				continue
   110  			}
   111  
   112  			seen[m.Name] = true
   113  
   114  			if m.PkgPath != "" {
   115  				// Not exported
   116  				continue
   117  			}
   118  
   119  			numIn := m.Type.NumIn()
   120  
   121  			ownerName, _ := nameAndPackage(t)
   122  
   123  			method := Method{Owner: t, OwnerName: ownerName, Name: m.Name}
   124  
   125  			for i := 0; i < numIn; i++ {
   126  				in := m.Type.In(i)
   127  
   128  				name, pkg := nameAndPackage(in)
   129  
   130  				if pkg != "" {
   131  					method.Imports = append(method.Imports, pkg)
   132  				}
   133  
   134  				method.In = append(method.In, name)
   135  			}
   136  
   137  			numOut := m.Type.NumOut()
   138  
   139  			if numOut > 0 {
   140  				for i := 0; i < numOut; i++ {
   141  					out := m.Type.Out(i)
   142  					name, pkg := nameAndPackage(out)
   143  
   144  					if pkg != "" {
   145  						method.Imports = append(method.Imports, pkg)
   146  					}
   147  
   148  					method.Out = append(method.Out, name)
   149  				}
   150  			}
   151  
   152  			methods = append(methods, method)
   153  		}
   154  	}
   155  
   156  	sort.SliceStable(methods, func(i, j int) bool {
   157  		mi, mj := methods[i], methods[j]
   158  
   159  		wi := c.methodWeight[mi.OwnerName][mi.Name]
   160  		wj := c.methodWeight[mj.OwnerName][mj.Name]
   161  
   162  		if wi == wj {
   163  			return mi.Name < mj.Name
   164  		}
   165  
   166  		return wi < wj
   167  	})
   168  
   169  	return methods
   170  }
   171  
   172  func (c *Inspector) parseSource() {
   173  	c.init.Do(func() {
   174  		if !strings.Contains(c.ProjectRootDir, "hugo") {
   175  			panic("dir must be set to the Hugo root")
   176  		}
   177  
   178  		c.methodWeight = make(map[string]map[string]int)
   179  		dirExcludes := regexp.MustCompile("docs|examples")
   180  		fileExcludes := regexp.MustCompile("autogen")
   181  		var filenames []string
   182  
   183  		filepath.Walk(c.ProjectRootDir, func(path string, info os.FileInfo, err error) error {
   184  			if info.IsDir() {
   185  				if dirExcludes.MatchString(info.Name()) {
   186  					return filepath.SkipDir
   187  				}
   188  			}
   189  
   190  			if !strings.HasSuffix(path, ".go") || fileExcludes.MatchString(path) {
   191  				return nil
   192  			}
   193  
   194  			filenames = append(filenames, path)
   195  
   196  			return nil
   197  		})
   198  
   199  		for _, filename := range filenames {
   200  
   201  			pkg := c.packageFromPath(filename)
   202  
   203  			fset := token.NewFileSet()
   204  			node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
   205  			if err != nil {
   206  				panic(err)
   207  			}
   208  
   209  			ast.Inspect(node, func(n ast.Node) bool {
   210  				switch t := n.(type) {
   211  				case *ast.TypeSpec:
   212  					if t.Name.IsExported() {
   213  						switch it := t.Type.(type) {
   214  						case *ast.InterfaceType:
   215  							iface := pkg + "." + t.Name.Name
   216  							methodNames := collectMethodsRecursive(pkg, it.Methods.List)
   217  							weights := make(map[string]int)
   218  							weight := weightWidth
   219  							for _, name := range methodNames {
   220  								weights[name] = weight
   221  								weight += weightWidth
   222  							}
   223  							c.methodWeight[iface] = weights
   224  						}
   225  					}
   226  				}
   227  				return true
   228  			})
   229  
   230  		}
   231  
   232  		// Complement
   233  		for _, v1 := range c.methodWeight {
   234  			for k2, w := range v1 {
   235  				if v, found := c.methodWeight[k2]; found {
   236  					for k3, v3 := range v {
   237  						v1[k3] = (v3 / weightWidth) + w
   238  					}
   239  				}
   240  			}
   241  		}
   242  	})
   243  }
   244  
   245  func (c *Inspector) packageFromPath(p string) string {
   246  	p = filepath.ToSlash(p)
   247  	base := path.Base(p)
   248  	if !strings.Contains(base, ".") {
   249  		return base
   250  	}
   251  	return path.Base(strings.TrimSuffix(p, base))
   252  }
   253  
   254  // Method holds enough information about it to recreate it.
   255  type Method struct {
   256  	// The interface we extracted this method from.
   257  	Owner reflect.Type
   258  
   259  	// String version of the above, on the form PACKAGE.NAME, e.g.
   260  	// page.Page
   261  	OwnerName string
   262  
   263  	// Method name.
   264  	Name string
   265  
   266  	// Imports needed to satisfy the method signature.
   267  	Imports []string
   268  
   269  	// Argument types, including any package prefix, e.g. string, int, interface{},
   270  	// net.Url
   271  	In []string
   272  
   273  	// Return types.
   274  	Out []string
   275  }
   276  
   277  // Declaration creates a method declaration (without any body) for the given receiver.
   278  func (m Method) Declaration(receiver string) string {
   279  	return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStr())
   280  }
   281  
   282  // DeclarationNamed creates a method declaration (without any body) for the given receiver
   283  // with named return values.
   284  func (m Method) DeclarationNamed(receiver string) string {
   285  	return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStrNamed())
   286  }
   287  
   288  // Delegate creates a delegate call string.
   289  func (m Method) Delegate(receiver, delegate string) string {
   290  	ret := ""
   291  	if len(m.Out) > 0 {
   292  		ret = "return "
   293  	}
   294  	return fmt.Sprintf("%s%s.%s.%s%s", ret, receiverShort(receiver), delegate, m.Name, m.inOutStr())
   295  }
   296  
   297  func (m Method) String() string {
   298  	return m.Name + m.inStr() + " " + m.outStr() + "\n"
   299  }
   300  
   301  func (m Method) inOutStr() string {
   302  	if len(m.In) == 0 {
   303  		return "()"
   304  	}
   305  
   306  	args := make([]string, len(m.In))
   307  	for i := 0; i < len(args); i++ {
   308  		args[i] = fmt.Sprintf("arg%d", i)
   309  	}
   310  	return "(" + strings.Join(args, ", ") + ")"
   311  }
   312  
   313  func (m Method) inStr() string {
   314  	if len(m.In) == 0 {
   315  		return "()"
   316  	}
   317  
   318  	args := make([]string, len(m.In))
   319  	for i := 0; i < len(args); i++ {
   320  		args[i] = fmt.Sprintf("arg%d %s", i, m.In[i])
   321  	}
   322  	return "(" + strings.Join(args, ", ") + ")"
   323  }
   324  
   325  func (m Method) outStr() string {
   326  	if len(m.Out) == 0 {
   327  		return ""
   328  	}
   329  	if len(m.Out) == 1 {
   330  		return m.Out[0]
   331  	}
   332  
   333  	return "(" + strings.Join(m.Out, ", ") + ")"
   334  }
   335  
   336  func (m Method) outStrNamed() string {
   337  	if len(m.Out) == 0 {
   338  		return ""
   339  	}
   340  
   341  	outs := make([]string, len(m.Out))
   342  	for i := 0; i < len(outs); i++ {
   343  		outs[i] = fmt.Sprintf("o%d %s", i, m.Out[i])
   344  	}
   345  
   346  	return "(" + strings.Join(outs, ", ") + ")"
   347  }
   348  
   349  // Methods represents a list of methods for one or more interfaces.
   350  // The order matches the defined order in their source file(s).
   351  type Methods []Method
   352  
   353  // Imports returns a sorted list of package imports needed to satisfy the
   354  // signatures of all methods.
   355  func (m Methods) Imports() []string {
   356  	var pkgImports []string
   357  	for _, method := range m {
   358  		pkgImports = append(pkgImports, method.Imports...)
   359  	}
   360  	if len(pkgImports) > 0 {
   361  		pkgImports = uniqueNonEmptyStrings(pkgImports)
   362  		sort.Strings(pkgImports)
   363  	}
   364  	return pkgImports
   365  }
   366  
   367  // ToMarshalJSON creates a MarshalJSON method for these methods. Any method name
   368  // matching any of the regexps in excludes will be ignored.
   369  func (m Methods) ToMarshalJSON(receiver, pkgPath string, excludes ...string) (string, []string) {
   370  	var sb strings.Builder
   371  
   372  	r := receiverShort(receiver)
   373  	what := firstToUpper(trimAsterisk(receiver))
   374  	pgkName := path.Base(pkgPath)
   375  
   376  	fmt.Fprintf(&sb, "func Marshal%sToJSON(%s %s) ([]byte, error) {\n", what, r, receiver)
   377  
   378  	var methods Methods
   379  	excludeRes := make([]*regexp.Regexp, len(excludes))
   380  
   381  	for i, exclude := range excludes {
   382  		excludeRes[i] = regexp.MustCompile(exclude)
   383  	}
   384  
   385  	for _, method := range m {
   386  		// Exclude methods with arguments and incompatible return values
   387  		if len(method.In) > 0 || len(method.Out) == 0 || len(method.Out) > 2 {
   388  			continue
   389  		}
   390  
   391  		if len(method.Out) == 2 {
   392  			if method.Out[1] != "error" {
   393  				continue
   394  			}
   395  		}
   396  
   397  		for _, re := range excludeRes {
   398  			if re.MatchString(method.Name) {
   399  				continue
   400  			}
   401  		}
   402  
   403  		methods = append(methods, method)
   404  	}
   405  
   406  	for _, method := range methods {
   407  		varn := varName(method.Name)
   408  		if len(method.Out) == 1 {
   409  			fmt.Fprintf(&sb, "\t%s := %s.%s()\n", varn, r, method.Name)
   410  		} else {
   411  			fmt.Fprintf(&sb, "\t%s, err := %s.%s()\n", varn, r, method.Name)
   412  			fmt.Fprint(&sb, "\tif err != nil {\n\t\treturn nil, err\n\t}\n")
   413  		}
   414  	}
   415  
   416  	fmt.Fprint(&sb, "\n\ts := struct {\n")
   417  
   418  	for _, method := range methods {
   419  		fmt.Fprintf(&sb, "\t\t%s %s\n", method.Name, typeName(method.Out[0], pgkName))
   420  	}
   421  
   422  	fmt.Fprint(&sb, "\n\t}{\n")
   423  
   424  	for _, method := range methods {
   425  		varn := varName(method.Name)
   426  		fmt.Fprintf(&sb, "\t\t%s: %s,\n", method.Name, varn)
   427  	}
   428  
   429  	fmt.Fprint(&sb, "\n\t}\n\n")
   430  	fmt.Fprint(&sb, "\treturn json.Marshal(&s)\n}")
   431  
   432  	pkgImports := append(methods.Imports(), "encoding/json")
   433  
   434  	if pkgPath != "" {
   435  		// Exclude self
   436  		for i, pkgImp := range pkgImports {
   437  			if pkgImp == pkgPath {
   438  				pkgImports = append(pkgImports[:i], pkgImports[i+1:]...)
   439  			}
   440  		}
   441  	}
   442  
   443  	return sb.String(), pkgImports
   444  }
   445  
   446  func collectMethodsRecursive(pkg string, f []*ast.Field) []string {
   447  	var methodNames []string
   448  	for _, m := range f {
   449  		if m.Names != nil {
   450  			methodNames = append(methodNames, m.Names[0].Name)
   451  			continue
   452  		}
   453  
   454  		if ident, ok := m.Type.(*ast.Ident); ok && ident.Obj != nil {
   455  			switch tt := ident.Obj.Decl.(*ast.TypeSpec).Type.(type) {
   456  			case *ast.InterfaceType:
   457  				// Embedded interface
   458  				methodNames = append(
   459  					methodNames,
   460  					collectMethodsRecursive(
   461  						pkg,
   462  						tt.Methods.List)...)
   463  			}
   464  
   465  		} else {
   466  			// Embedded, but in a different file/package. Return the
   467  			// package.Name and deal with that later.
   468  			name := packageName(m.Type)
   469  			if !strings.Contains(name, ".") {
   470  				// Assume current package
   471  				name = pkg + "." + name
   472  			}
   473  			methodNames = append(methodNames, name)
   474  		}
   475  	}
   476  
   477  	return methodNames
   478  }
   479  
   480  func firstToLower(name string) string {
   481  	return strings.ToLower(name[:1]) + name[1:]
   482  }
   483  
   484  func firstToUpper(name string) string {
   485  	return strings.ToUpper(name[:1]) + name[1:]
   486  }
   487  
   488  func packageName(e ast.Expr) string {
   489  	switch tp := e.(type) {
   490  	case *ast.Ident:
   491  		return tp.Name
   492  	case *ast.SelectorExpr:
   493  		return fmt.Sprintf("%s.%s", packageName(tp.X), packageName(tp.Sel))
   494  	}
   495  	return ""
   496  }
   497  
   498  func receiverShort(receiver string) string {
   499  	return strings.ToLower(trimAsterisk(receiver))[:1]
   500  }
   501  
   502  func trimAsterisk(name string) string {
   503  	return strings.TrimPrefix(name, "*")
   504  }
   505  
   506  func typeName(name, pkg string) string {
   507  	return strings.TrimPrefix(name, pkg+".")
   508  }
   509  
   510  func uniqueNonEmptyStrings(s []string) []string {
   511  	var unique []string
   512  	set := map[string]any{}
   513  	for _, val := range s {
   514  		if val == "" {
   515  			continue
   516  		}
   517  		if _, ok := set[val]; !ok {
   518  			unique = append(unique, val)
   519  			set[val] = val
   520  		}
   521  	}
   522  	return unique
   523  }
   524  
   525  func varName(name string) string {
   526  	name = firstToLower(name)
   527  
   528  	// Adjust some reserved keywords, see https://golang.org/ref/spec#Keywords
   529  	switch name {
   530  	case "type":
   531  		name = "typ"
   532  	case "package":
   533  		name = "pkg"
   534  		// Not reserved, but syntax highlighters has it as a keyword.
   535  	case "len":
   536  		name = "length"
   537  	}
   538  
   539  	return name
   540  }