github.com/rgonomic/rgo@v0.2.2-0.20220708095818-4747f0905d6e/internal/pkg/types.go (about) 1 // Copyright ©2019 The rgonomic Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package pkg 6 7 import ( 8 "errors" 9 "fmt" 10 "go/ast" 11 "go/types" 12 "log" 13 "regexp" 14 "sort" 15 "strings" 16 "unicode" 17 18 "golang.org/x/tools/go/packages" 19 ) 20 21 // Info holds information about the functions in a package. 22 type Info struct { 23 Funcs []FuncInfo 24 25 Unpackers unpackers 26 Packers packers 27 } 28 29 func (p *Info) Pkg() *types.Package { 30 if len(p.Funcs) == 0 { 31 return nil 32 } 33 return p.Funcs[0].Pkg() 34 } 35 36 // FuncInfo holds type and syntax information about a function. 37 type FuncInfo struct { 38 *types.Func 39 *ast.FuncDecl 40 } 41 42 func (f FuncInfo) Signature() *types.Signature { 43 return f.Func.Type().(*types.Signature) 44 } 45 46 func Analyse(path, allowed string, verbose bool) (*Info, error) { 47 if strings.HasSuffix(path, "...") { 48 return nil, errors.New("pkg: invalid use of ... suffix") 49 } 50 51 cfg := &packages.Config{ 52 Mode: packages.NeedFiles | 53 packages.NeedSyntax | 54 packages.NeedTypes | 55 packages.NeedTypesInfo, 56 } 57 58 pkgs, err := packages.Load(cfg, "pattern="+path) 59 if err != nil { 60 return nil, err 61 } 62 if packages.PrintErrors(pkgs) != 0 { 63 return nil, errors.New("package errors") 64 } 65 var pkg *packages.Package 66 switch len(pkgs) { 67 case 0: 68 return nil, errors.New("pkg: no package analysed") 69 case 1: 70 pkg = pkgs[0] 71 default: 72 return nil, errors.New("pkg: more than one package analysed") 73 } 74 75 allow, err := regexp.Compile(allowed) 76 if err != nil { 77 return nil, err 78 } 79 80 log.Printf("wrapping: %s", pkg.ID) 81 if verbose { 82 log.Println("files:", pkg.GoFiles) 83 } 84 var funcs []FuncInfo 85 needUnpack := make(unpackers) 86 needPack := make(packers) 87 for _, f := range pkg.Syntax { 88 for _, decl := range f.Decls { 89 fd, ok := decl.(*ast.FuncDecl) 90 if !ok { 91 continue 92 } 93 94 fn := pkg.TypesInfo.Defs[fd.Name].(*types.Func) 95 if !fn.Exported() { 96 if verbose { 97 log.Printf("skipping %s: unexported function", fn.Name()) 98 } 99 continue 100 } 101 if !allow.MatchString(fn.Name()) { 102 if verbose { 103 log.Printf("skipping %s: not allowed name", fn.Name()) 104 } 105 continue 106 } 107 sig := fn.Type().(*types.Signature) 108 if sig.Recv() != nil { 109 if verbose { 110 log.Printf("skipping %s.%s: method function", sig.Recv().Type(), fn.Name()) 111 } 112 continue 113 } 114 115 par := sig.Params() 116 err := checkType(par, par, true) 117 if err != nil { 118 if verbose { 119 log.Printf("skipping %s: %v", fn.Name(), err) 120 } 121 continue 122 } 123 res := sig.Results() 124 err = checkType(res, res, false) 125 if err != nil { 126 if verbose { 127 log.Printf("skipping %s: %v", fn.Name(), err) 128 } 129 continue 130 } 131 funcs = append(funcs, FuncInfo{ 132 Func: fn, 133 FuncDecl: fd, 134 }) 135 136 walk(needUnpack, par, par) 137 walk(needPack, res, res) 138 } 139 140 } 141 142 // Check for mangled name collisions. 143 seen := make(map[string]types.Type) 144 for _, typ := range needUnpack { 145 if typ2, ok := seen[Mangle(typ)]; ok { 146 panic(fmt.Sprintf("mangled name collision in SEXP unwrapper: %s hits %s", typ, typ2)) 147 } 148 } 149 seen = make(map[string]types.Type) 150 for _, typ := range needPack { 151 if typ2, ok := seen[Mangle(typ)]; ok { 152 panic(fmt.Sprintf("mangled name collision in SEXP builder: %s hits %s", typ, typ2)) 153 } 154 } 155 156 return &Info{Funcs: funcs, Unpackers: needUnpack, Packers: needPack}, nil 157 } 158 159 // TODO(kortschak): Handle recursive type definitions correctly. 160 161 func checkType(typ, named types.Type, parameters bool) error { 162 switch typ := typ.(type) { 163 case *types.Named: 164 return checkType(typ.Underlying(), typ, parameters) 165 166 case *types.Array: 167 elem := typ.Elem() 168 return checkType(elem, elem, parameters) 169 170 case *types.Basic: 171 switch kind := typ.Kind(); kind { 172 case types.Uint8, types.Int32: 173 // Dealias rune and byte. 174 *typ = *types.Typ[kind] 175 case types.Int64, types.Uint64: 176 if typ == named { 177 return fmt.Errorf("unhandled integer type %s", typ) 178 } 179 return fmt.Errorf("unhandled integer type %s (%s)", named, typ) 180 } 181 182 case *types.Chan: 183 if typ == named { 184 return fmt.Errorf("unhandled chan type %s", typ) 185 } 186 return fmt.Errorf("unhandled chan type %s (%s)", named, typ) 187 188 case *types.Interface: 189 if !IsError(named) { 190 return fmt.Errorf("unhandled interface type %s", named) 191 } 192 193 case *types.Map: 194 if !types.Identical(typ.Key().Underlying(), types.Typ[types.String]) { 195 if typ == named { 196 return fmt.Errorf("unhandled non-string keyed map type %s", typ) 197 } 198 return fmt.Errorf("unhandled non-string keyed map type %s (%s)", named, typ) 199 } 200 elem := typ.Elem() 201 err := checkType(elem, elem, parameters) 202 if err != nil { 203 return err 204 } 205 206 case *types.Pointer: 207 elem := typ.Elem() 208 err := checkType(elem, elem, parameters) 209 if err != nil { 210 return err 211 } 212 if !parameters { 213 break 214 } 215 if typ == named { 216 log.Printf("warning: pointer type %s", typ) 217 } else { 218 log.Printf("warning: pointer type %s (%s)", named, typ) 219 } 220 221 case *types.Signature: 222 if typ == named { 223 return fmt.Errorf("unhandled function type with signature %s", typ) 224 } 225 return fmt.Errorf("unhandled function type with signature %s (%s)", named, typ) 226 227 case *types.Slice: 228 elem := typ.Elem() 229 err := checkType(elem, elem, parameters) 230 if err != nil { 231 return err 232 } 233 if !parameters { 234 break 235 } 236 if typ == named { 237 log.Printf("warning: slice type %s", typ) 238 } else { 239 log.Printf("warning: slice type %s (%s)", named, typ) 240 } 241 242 case *types.Struct: 243 for i := 0; i < typ.NumFields(); i++ { 244 f := typ.Field(i).Type() 245 err := checkType(f, f, parameters) 246 if err != nil { 247 return err 248 } 249 } 250 251 case *types.Tuple: 252 for i := 0; i < typ.Len(); i++ { 253 f := typ.At(i) 254 if parameters && (f.Name() == "" || f.Name() == "_") { 255 return fmt.Errorf("unhandled function with blank parameter name %q", f) 256 } 257 typ := f.Type() 258 err := checkType(typ, typ, parameters) 259 if err != nil { 260 return err 261 } 262 } 263 } 264 265 return nil 266 } 267 268 type unpackers map[string]types.Type 269 270 func (v unpackers) visit(typ types.Type) { 271 if _, ok := typ.Underlying().(*types.Interface); ok { 272 panic(fmt.Sprintf("unhandled input parameter type: %q", typ)) 273 } 274 s := typ.String() 275 if _, ok := v[s]; !ok { 276 v[s] = typ 277 } 278 if also := v.also(typ); also != types.Invalid { 279 v.visit(types.Typ[also]) 280 } 281 } 282 283 func (v unpackers) also(typ types.Type) types.BasicKind { 284 if typ, ok := typ.(*types.Basic); ok { 285 // Make sure we have a complex128 value we can convert to complex64. 286 if typ.Kind() == types.Complex64 { 287 return types.Complex128 288 } 289 } 290 return types.Invalid 291 } 292 293 func (v unpackers) Types() []types.Type { 294 typs := make([]types.Type, 0, len(v)) 295 for _, typ := range v { 296 typs = append(typs, typ) 297 } 298 sort.Sort(byName(typs)) 299 return typs 300 } 301 302 type packers map[string]types.Type 303 304 func (v packers) visit(typ types.Type) { 305 s := typ.String() 306 if _, ok := v[s]; !ok { 307 v[s] = typ 308 } 309 } 310 311 func (v packers) NeedList() bool { 312 for _, typ := range v { 313 switch typ.(type) { 314 case *types.Struct, *types.Map: 315 return true 316 } 317 } 318 return false 319 } 320 321 func (v packers) Types() []types.Type { 322 typs := make([]types.Type, 0, len(v)) 323 for _, typ := range v { 324 typs = append(typs, typ) 325 } 326 sort.Sort(byName(typs)) 327 return typs 328 } 329 330 type byName []types.Type 331 332 func (t byName) Len() int { return len(t) } 333 func (t byName) Less(i, j int) bool { return Mangle(t[i]) < Mangle(t[j]) } 334 func (t byName) Swap(i, j int) { t[i], t[j] = t[j], t[i] } 335 336 type visitor interface { 337 visit(typ types.Type) 338 } 339 340 func walk(v visitor, typ, named types.Type) { 341 switch typ := typ.(type) { 342 case *types.Named: 343 v.visit(typ) 344 walk(v, typ.Underlying(), typ) 345 346 case *types.Array: 347 elem := typ.Elem() 348 v.visit(typ) 349 v.visit(types.NewSlice(elem)) // This will visit the element in the slice walk. 350 351 case *types.Basic: 352 switch typ.Kind() { 353 case types.Int64, types.Uint64: 354 if typ == named { 355 panic(fmt.Sprintf("unhandled integer type %s", typ)) 356 } 357 panic(fmt.Sprintf("unhandled integer type %s (%s)", named, typ)) 358 } 359 v.visit(typ) 360 361 case *types.Chan: 362 if typ == named { 363 panic(fmt.Sprintf("unhandled chan type %s", typ)) 364 } 365 panic(fmt.Sprintf("unhandled chan type %s (%s)", named, typ)) 366 367 case *types.Interface: 368 if !IsError(named) { 369 panic(fmt.Sprintf("unhandled interface type %s", named)) 370 } 371 v.visit(types.Typ[types.String]) 372 v.visit(named) 373 374 case *types.Map: 375 if !types.Identical(typ.Key().Underlying(), types.Typ[types.String]) { 376 if typ == named { 377 panic(fmt.Sprintf("unhandled non-string keyed map type %s", typ)) 378 } 379 panic(fmt.Sprintf("unhandled non-string keyed map type %s (%s)", named, typ)) 380 } 381 key := typ.Key() 382 walk(v, key, key) 383 elem := typ.Elem() 384 v.visit(typ) 385 walk(v, elem, elem) 386 387 case *types.Pointer: 388 elem := typ.Elem() 389 v.visit(typ) 390 walk(v, elem, elem) 391 392 case *types.Signature: 393 if typ == named { 394 panic(fmt.Sprintf("unhandled function type %s", typ)) 395 } 396 panic(fmt.Sprintf("unhandled function type %s (%s)", named, typ)) 397 398 case *types.Slice: 399 elem := typ.Elem() 400 v.visit(typ) 401 if _, ok := elem.Underlying().(*types.Basic); !ok { 402 walk(v, elem, elem) 403 } 404 405 case *types.Struct: 406 for i := 0; i < typ.NumFields(); i++ { 407 f := typ.Field(i).Type() 408 walk(v, f, f) 409 } 410 v.visit(typ) 411 412 case *types.Tuple: 413 for i := 0; i < typ.Len(); i++ { 414 f := typ.At(i).Type() 415 walk(v, f, f) 416 } 417 } 418 } 419 420 func IsError(typ types.Type) bool { 421 return types.Identical(typ, types.Universe.Lookup("error").Type()) 422 } 423 424 func Mangle(typ types.Type) string { 425 // FIXME(kortschak): This may lead to name collisions for complex unnamed types. 426 runes := []rune(fmt.Sprintf("%T_%[1]s", typ)) 427 for i, r := range runes { 428 if !unicode.IsLetter(r) && !unicode.IsDigit(r) { 429 runes[i] = '_' 430 } 431 } 432 return string(runes) 433 }