github.com/rgonomic/rgo@v0.2.2-0.20220708095818-4747f0905d6e/internal/pkg/types.go (about)

     1  // Copyright ©2019 The rgonomic Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package pkg
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/types"
    12  	"log"
    13  	"regexp"
    14  	"sort"
    15  	"strings"
    16  	"unicode"
    17  
    18  	"golang.org/x/tools/go/packages"
    19  )
    20  
    21  // Info holds information about the functions in a package.
    22  type Info struct {
    23  	Funcs []FuncInfo
    24  
    25  	Unpackers unpackers
    26  	Packers   packers
    27  }
    28  
    29  func (p *Info) Pkg() *types.Package {
    30  	if len(p.Funcs) == 0 {
    31  		return nil
    32  	}
    33  	return p.Funcs[0].Pkg()
    34  }
    35  
    36  // FuncInfo holds type and syntax information about a function.
    37  type FuncInfo struct {
    38  	*types.Func
    39  	*ast.FuncDecl
    40  }
    41  
    42  func (f FuncInfo) Signature() *types.Signature {
    43  	return f.Func.Type().(*types.Signature)
    44  }
    45  
    46  func Analyse(path, allowed string, verbose bool) (*Info, error) {
    47  	if strings.HasSuffix(path, "...") {
    48  		return nil, errors.New("pkg: invalid use of ... suffix")
    49  	}
    50  
    51  	cfg := &packages.Config{
    52  		Mode: packages.NeedFiles |
    53  			packages.NeedSyntax |
    54  			packages.NeedTypes |
    55  			packages.NeedTypesInfo,
    56  	}
    57  
    58  	pkgs, err := packages.Load(cfg, "pattern="+path)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	if packages.PrintErrors(pkgs) != 0 {
    63  		return nil, errors.New("package errors")
    64  	}
    65  	var pkg *packages.Package
    66  	switch len(pkgs) {
    67  	case 0:
    68  		return nil, errors.New("pkg: no package analysed")
    69  	case 1:
    70  		pkg = pkgs[0]
    71  	default:
    72  		return nil, errors.New("pkg: more than one package analysed")
    73  	}
    74  
    75  	allow, err := regexp.Compile(allowed)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	log.Printf("wrapping: %s", pkg.ID)
    81  	if verbose {
    82  		log.Println("files:", pkg.GoFiles)
    83  	}
    84  	var funcs []FuncInfo
    85  	needUnpack := make(unpackers)
    86  	needPack := make(packers)
    87  	for _, f := range pkg.Syntax {
    88  		for _, decl := range f.Decls {
    89  			fd, ok := decl.(*ast.FuncDecl)
    90  			if !ok {
    91  				continue
    92  			}
    93  
    94  			fn := pkg.TypesInfo.Defs[fd.Name].(*types.Func)
    95  			if !fn.Exported() {
    96  				if verbose {
    97  					log.Printf("skipping %s: unexported function", fn.Name())
    98  				}
    99  				continue
   100  			}
   101  			if !allow.MatchString(fn.Name()) {
   102  				if verbose {
   103  					log.Printf("skipping %s: not allowed name", fn.Name())
   104  				}
   105  				continue
   106  			}
   107  			sig := fn.Type().(*types.Signature)
   108  			if sig.Recv() != nil {
   109  				if verbose {
   110  					log.Printf("skipping %s.%s: method function", sig.Recv().Type(), fn.Name())
   111  				}
   112  				continue
   113  			}
   114  
   115  			par := sig.Params()
   116  			err := checkType(par, par, true)
   117  			if err != nil {
   118  				if verbose {
   119  					log.Printf("skipping %s: %v", fn.Name(), err)
   120  				}
   121  				continue
   122  			}
   123  			res := sig.Results()
   124  			err = checkType(res, res, false)
   125  			if err != nil {
   126  				if verbose {
   127  					log.Printf("skipping %s: %v", fn.Name(), err)
   128  				}
   129  				continue
   130  			}
   131  			funcs = append(funcs, FuncInfo{
   132  				Func:     fn,
   133  				FuncDecl: fd,
   134  			})
   135  
   136  			walk(needUnpack, par, par)
   137  			walk(needPack, res, res)
   138  		}
   139  
   140  	}
   141  
   142  	// Check for mangled name collisions.
   143  	seen := make(map[string]types.Type)
   144  	for _, typ := range needUnpack {
   145  		if typ2, ok := seen[Mangle(typ)]; ok {
   146  			panic(fmt.Sprintf("mangled name collision in SEXP unwrapper: %s hits %s", typ, typ2))
   147  		}
   148  	}
   149  	seen = make(map[string]types.Type)
   150  	for _, typ := range needPack {
   151  		if typ2, ok := seen[Mangle(typ)]; ok {
   152  			panic(fmt.Sprintf("mangled name collision in SEXP builder: %s hits %s", typ, typ2))
   153  		}
   154  	}
   155  
   156  	return &Info{Funcs: funcs, Unpackers: needUnpack, Packers: needPack}, nil
   157  }
   158  
   159  // TODO(kortschak): Handle recursive type definitions correctly.
   160  
   161  func checkType(typ, named types.Type, parameters bool) error {
   162  	switch typ := typ.(type) {
   163  	case *types.Named:
   164  		return checkType(typ.Underlying(), typ, parameters)
   165  
   166  	case *types.Array:
   167  		elem := typ.Elem()
   168  		return checkType(elem, elem, parameters)
   169  
   170  	case *types.Basic:
   171  		switch kind := typ.Kind(); kind {
   172  		case types.Uint8, types.Int32:
   173  			// Dealias rune and byte.
   174  			*typ = *types.Typ[kind]
   175  		case types.Int64, types.Uint64:
   176  			if typ == named {
   177  				return fmt.Errorf("unhandled integer type %s", typ)
   178  			}
   179  			return fmt.Errorf("unhandled integer type %s (%s)", named, typ)
   180  		}
   181  
   182  	case *types.Chan:
   183  		if typ == named {
   184  			return fmt.Errorf("unhandled chan type %s", typ)
   185  		}
   186  		return fmt.Errorf("unhandled chan type %s (%s)", named, typ)
   187  
   188  	case *types.Interface:
   189  		if !IsError(named) {
   190  			return fmt.Errorf("unhandled interface type %s", named)
   191  		}
   192  
   193  	case *types.Map:
   194  		if !types.Identical(typ.Key().Underlying(), types.Typ[types.String]) {
   195  			if typ == named {
   196  				return fmt.Errorf("unhandled non-string keyed map type %s", typ)
   197  			}
   198  			return fmt.Errorf("unhandled non-string keyed map type %s (%s)", named, typ)
   199  		}
   200  		elem := typ.Elem()
   201  		err := checkType(elem, elem, parameters)
   202  		if err != nil {
   203  			return err
   204  		}
   205  
   206  	case *types.Pointer:
   207  		elem := typ.Elem()
   208  		err := checkType(elem, elem, parameters)
   209  		if err != nil {
   210  			return err
   211  		}
   212  		if !parameters {
   213  			break
   214  		}
   215  		if typ == named {
   216  			log.Printf("warning: pointer type %s", typ)
   217  		} else {
   218  			log.Printf("warning: pointer type %s (%s)", named, typ)
   219  		}
   220  
   221  	case *types.Signature:
   222  		if typ == named {
   223  			return fmt.Errorf("unhandled function type with signature %s", typ)
   224  		}
   225  		return fmt.Errorf("unhandled function type with signature %s (%s)", named, typ)
   226  
   227  	case *types.Slice:
   228  		elem := typ.Elem()
   229  		err := checkType(elem, elem, parameters)
   230  		if err != nil {
   231  			return err
   232  		}
   233  		if !parameters {
   234  			break
   235  		}
   236  		if typ == named {
   237  			log.Printf("warning: slice type %s", typ)
   238  		} else {
   239  			log.Printf("warning: slice type %s (%s)", named, typ)
   240  		}
   241  
   242  	case *types.Struct:
   243  		for i := 0; i < typ.NumFields(); i++ {
   244  			f := typ.Field(i).Type()
   245  			err := checkType(f, f, parameters)
   246  			if err != nil {
   247  				return err
   248  			}
   249  		}
   250  
   251  	case *types.Tuple:
   252  		for i := 0; i < typ.Len(); i++ {
   253  			f := typ.At(i)
   254  			if parameters && (f.Name() == "" || f.Name() == "_") {
   255  				return fmt.Errorf("unhandled function with blank parameter name %q", f)
   256  			}
   257  			typ := f.Type()
   258  			err := checkType(typ, typ, parameters)
   259  			if err != nil {
   260  				return err
   261  			}
   262  		}
   263  	}
   264  
   265  	return nil
   266  }
   267  
   268  type unpackers map[string]types.Type
   269  
   270  func (v unpackers) visit(typ types.Type) {
   271  	if _, ok := typ.Underlying().(*types.Interface); ok {
   272  		panic(fmt.Sprintf("unhandled input parameter type: %q", typ))
   273  	}
   274  	s := typ.String()
   275  	if _, ok := v[s]; !ok {
   276  		v[s] = typ
   277  	}
   278  	if also := v.also(typ); also != types.Invalid {
   279  		v.visit(types.Typ[also])
   280  	}
   281  }
   282  
   283  func (v unpackers) also(typ types.Type) types.BasicKind {
   284  	if typ, ok := typ.(*types.Basic); ok {
   285  		// Make sure we have a complex128 value we can convert to complex64.
   286  		if typ.Kind() == types.Complex64 {
   287  			return types.Complex128
   288  		}
   289  	}
   290  	return types.Invalid
   291  }
   292  
   293  func (v unpackers) Types() []types.Type {
   294  	typs := make([]types.Type, 0, len(v))
   295  	for _, typ := range v {
   296  		typs = append(typs, typ)
   297  	}
   298  	sort.Sort(byName(typs))
   299  	return typs
   300  }
   301  
   302  type packers map[string]types.Type
   303  
   304  func (v packers) visit(typ types.Type) {
   305  	s := typ.String()
   306  	if _, ok := v[s]; !ok {
   307  		v[s] = typ
   308  	}
   309  }
   310  
   311  func (v packers) NeedList() bool {
   312  	for _, typ := range v {
   313  		switch typ.(type) {
   314  		case *types.Struct, *types.Map:
   315  			return true
   316  		}
   317  	}
   318  	return false
   319  }
   320  
   321  func (v packers) Types() []types.Type {
   322  	typs := make([]types.Type, 0, len(v))
   323  	for _, typ := range v {
   324  		typs = append(typs, typ)
   325  	}
   326  	sort.Sort(byName(typs))
   327  	return typs
   328  }
   329  
   330  type byName []types.Type
   331  
   332  func (t byName) Len() int           { return len(t) }
   333  func (t byName) Less(i, j int) bool { return Mangle(t[i]) < Mangle(t[j]) }
   334  func (t byName) Swap(i, j int)      { t[i], t[j] = t[j], t[i] }
   335  
   336  type visitor interface {
   337  	visit(typ types.Type)
   338  }
   339  
   340  func walk(v visitor, typ, named types.Type) {
   341  	switch typ := typ.(type) {
   342  	case *types.Named:
   343  		v.visit(typ)
   344  		walk(v, typ.Underlying(), typ)
   345  
   346  	case *types.Array:
   347  		elem := typ.Elem()
   348  		v.visit(typ)
   349  		v.visit(types.NewSlice(elem)) // This will visit the element in the slice walk.
   350  
   351  	case *types.Basic:
   352  		switch typ.Kind() {
   353  		case types.Int64, types.Uint64:
   354  			if typ == named {
   355  				panic(fmt.Sprintf("unhandled integer type %s", typ))
   356  			}
   357  			panic(fmt.Sprintf("unhandled integer type %s (%s)", named, typ))
   358  		}
   359  		v.visit(typ)
   360  
   361  	case *types.Chan:
   362  		if typ == named {
   363  			panic(fmt.Sprintf("unhandled chan type %s", typ))
   364  		}
   365  		panic(fmt.Sprintf("unhandled chan type %s (%s)", named, typ))
   366  
   367  	case *types.Interface:
   368  		if !IsError(named) {
   369  			panic(fmt.Sprintf("unhandled interface type %s", named))
   370  		}
   371  		v.visit(types.Typ[types.String])
   372  		v.visit(named)
   373  
   374  	case *types.Map:
   375  		if !types.Identical(typ.Key().Underlying(), types.Typ[types.String]) {
   376  			if typ == named {
   377  				panic(fmt.Sprintf("unhandled non-string keyed map type %s", typ))
   378  			}
   379  			panic(fmt.Sprintf("unhandled non-string keyed map type %s (%s)", named, typ))
   380  		}
   381  		key := typ.Key()
   382  		walk(v, key, key)
   383  		elem := typ.Elem()
   384  		v.visit(typ)
   385  		walk(v, elem, elem)
   386  
   387  	case *types.Pointer:
   388  		elem := typ.Elem()
   389  		v.visit(typ)
   390  		walk(v, elem, elem)
   391  
   392  	case *types.Signature:
   393  		if typ == named {
   394  			panic(fmt.Sprintf("unhandled function type %s", typ))
   395  		}
   396  		panic(fmt.Sprintf("unhandled function type %s (%s)", named, typ))
   397  
   398  	case *types.Slice:
   399  		elem := typ.Elem()
   400  		v.visit(typ)
   401  		if _, ok := elem.Underlying().(*types.Basic); !ok {
   402  			walk(v, elem, elem)
   403  		}
   404  
   405  	case *types.Struct:
   406  		for i := 0; i < typ.NumFields(); i++ {
   407  			f := typ.Field(i).Type()
   408  			walk(v, f, f)
   409  		}
   410  		v.visit(typ)
   411  
   412  	case *types.Tuple:
   413  		for i := 0; i < typ.Len(); i++ {
   414  			f := typ.At(i).Type()
   415  			walk(v, f, f)
   416  		}
   417  	}
   418  }
   419  
   420  func IsError(typ types.Type) bool {
   421  	return types.Identical(typ, types.Universe.Lookup("error").Type())
   422  }
   423  
   424  func Mangle(typ types.Type) string {
   425  	// FIXME(kortschak): This may lead to name collisions for complex unnamed types.
   426  	runes := []rune(fmt.Sprintf("%T_%[1]s", typ))
   427  	for i, r := range runes {
   428  		if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
   429  			runes[i] = '_'
   430  		}
   431  	}
   432  	return string(runes)
   433  }