github.com/willyham/dosa@v2.3.1-0.20171024181418-1e446d37ee71+incompatible/finder.go (about)

     1  // Copyright (c) 2017 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package dosa
    22  
    23  import (
    24  	"fmt"
    25  	"go/ast"
    26  	"go/parser"
    27  	"go/token"
    28  	"os"
    29  	"path/filepath"
    30  	"reflect"
    31  	"strings"
    32  	"unicode"
    33  	"unicode/utf8"
    34  
    35  	"github.com/pkg/errors"
    36  )
    37  
    38  // FindEntities finds all entities in the given file paths. An error is
    39  // returned if there are naming colisions, otherwise, return a slice of
    40  // warnings (or nil).
    41  func FindEntities(paths, excludes []string) ([]*Table, []error, error) {
    42  	var entities []*Table
    43  	var warnings []error
    44  	for _, path := range paths {
    45  		fileSet := token.NewFileSet()
    46  		packages, err := parser.ParseDir(fileSet, path, func(fileInfo os.FileInfo) bool {
    47  			if len(excludes) == 0 {
    48  				return true
    49  			}
    50  			for _, exclude := range excludes {
    51  				if matched, _ := filepath.Match(exclude, fileInfo.Name()); matched {
    52  					return false
    53  				}
    54  			}
    55  			return true
    56  		}, 0)
    57  		if err != nil {
    58  			return nil, nil, err
    59  		}
    60  		erv := new(EntityRecordingVisitor)
    61  		for _, pkg := range packages { // go through all the packages
    62  			for _, file := range pkg.Files { // go through all the files
    63  				packagePrefix, hasDosa := findDosaPackage(file)
    64  				//if erv.PackageName != "" { // skip packages that don't import 'dosa'
    65  				if hasDosa {
    66  					erv.PackagePrefix = packagePrefix
    67  					for _, decl := range file.Decls { // go through all the declarations
    68  						ast.Walk(erv, decl)
    69  					}
    70  				}
    71  			}
    72  		}
    73  		entities = append(entities, erv.Entities...)
    74  		warnings = append(warnings, erv.Warnings...)
    75  	}
    76  
    77  	return entities, warnings, nil
    78  }
    79  
    80  // DosaPackageName is the name of the dosa package, fully qualified and quoted
    81  const DosaPackageName = `"github.com/uber-go/dosa"`
    82  
    83  func findDosaPackage(file *ast.File) (string, bool) {
    84  	// look for the case where we import dosa
    85  	for _, impspec := range file.Imports {
    86  		if impspec.Path.Value == DosaPackageName {
    87  			// impspec.Name is nil when not renamed,
    88  			// so we use the default "dosa"
    89  			if impspec.Name == nil {
    90  				return "dosa", true
    91  			}
    92  			// renamed case
    93  			return impspec.Name.Name, true
    94  		}
    95  	}
    96  	if file.Name.Name == "dosa" {
    97  		// special case: our package is 'dosa' so no prefix is required
    98  		return "", true
    99  	}
   100  	// this file doesn't have any references to dosa, so skip it
   101  	return "", false
   102  }
   103  
// EntityRecordingVisitor is an ast.Visitor that records the dosa entities it
// finds while walking declarations.
//
// Entities collects the tables built from structs that pass the basic "looks
// like a DOSA object" test (see isDosaEntity to understand that test).
// Warnings collects the error for every struct that passed that test but
// could not be turned into a table. PackagePrefix is the qualifier under which
// dosa identifiers appear in the file currently being walked; it is handed to
// tableFromStructType and must be set before walking each file.
type EntityRecordingVisitor struct {
	Entities      []*Table
	Warnings      []error
	PackagePrefix string
}
   112  
   113  // Visit records all the entities seen into the EntityRecordingVisitor structure
   114  func (f *EntityRecordingVisitor) Visit(n ast.Node) ast.Visitor {
   115  	switch n := n.(type) {
   116  	case *ast.File, *ast.Package, *ast.BlockStmt, *ast.DeclStmt, *ast.FuncDecl, *ast.GenDecl:
   117  		return f
   118  	case *ast.TypeSpec:
   119  		if structType, ok := n.Type.(*ast.StructType); ok {
   120  			// look for a Entity with a dosa annotation
   121  			if isDosaEntity(structType) {
   122  				table, err := tableFromStructType(n.Name.Name, structType, f.PackagePrefix)
   123  				if err == nil {
   124  					f.Entities = append(f.Entities, table)
   125  				} else {
   126  					f.Warnings = append(f.Warnings, err)
   127  				}
   128  			}
   129  		}
   130  	}
   131  	return nil
   132  }
   133  
   134  // isDosaEntity is a sanity check so that only objects that are probably supposed to be dosa
   135  // annotated objects will generate warnings. The rules for that are:
   136  //  - must have some fields
   137  //  - the first field should be of type Entity
   138  //    TODO: Really any field could be type Entity, but we currently do not have this case
   139  
   140  func isDosaEntity(structType *ast.StructType) bool {
   141  	// structures with no fields cannot be dosa entities
   142  	if len(structType.Fields.List) < 1 {
   143  		return false
   144  	}
   145  
   146  	// the first field should be a DOSA Entity type
   147  	candidateEntityField := structType.Fields.List[0]
   148  	if identifier, ok := candidateEntityField.Type.(*ast.Ident); ok {
   149  		if identifier.Name != entityName {
   150  			return false
   151  		}
   152  	}
   153  
   154  	// and should have a DOSA tag
   155  	if candidateEntityField.Tag == nil || candidateEntityField.Tag.Kind != token.STRING {
   156  		return false
   157  	}
   158  	entityTag := reflect.StructTag(strings.Trim(candidateEntityField.Tag.Value, "`"))
   159  	if entityTag.Get(dosaTagKey) == "" {
   160  		return false
   161  	}
   162  
   163  	return true
   164  }
   165  
   166  func parseASTType(expr ast.Expr) (string, error) {
   167  	var kind string
   168  	var err error
   169  	switch typeName := expr.(type) {
   170  	case *ast.Ident:
   171  		kind = typeName.Name
   172  		// not an Entity type, perhaps another primitive type
   173  	case *ast.ArrayType:
   174  		// only dosa allowed array type is []byte
   175  		if typeName, ok := typeName.Elt.(*ast.Ident); ok {
   176  			if typeName.Name == "byte" {
   177  				kind = "[]byte"
   178  			}
   179  		}
   180  	case *ast.SelectorExpr:
   181  		// only dosa allowed selector is time.Time
   182  		if innerName, ok := typeName.X.(*ast.Ident); ok {
   183  			kind = innerName.Name + "." + typeName.Sel.Name
   184  		}
   185  	case *ast.StarExpr:
   186  		// pointer types
   187  		// need to recursively parse the type
   188  		kind, err = parseASTType(typeName.X)
   189  		kind = "*" + kind
   190  	default:
   191  		err = fmt.Errorf("Unexpected field type: %v", typeName)
   192  	}
   193  
   194  	return kind, err
   195  }
   196  
   197  // tableFromStructType takes an ast StructType and converts it into a Table object
   198  func tableFromStructType(structName string, structType *ast.StructType, packagePrefix string) (*Table, error) {
   199  	normalizedName, err := NormalizeName(structName)
   200  	if err != nil {
   201  		// TODO: This isn't correct, someone could override the name later
   202  		return nil, errors.Wrapf(err, "struct name is invalid")
   203  	}
   204  
   205  	t := &Table{
   206  		StructName: structName,
   207  		EntityDefinition: EntityDefinition{
   208  			Name:    normalizedName,
   209  			Columns: []*ColumnDefinition{},
   210  			Indexes: map[string]*IndexDefinition{},
   211  		},
   212  		ColToField: map[string]string{},
   213  		FieldToCol: map[string]string{},
   214  	}
   215  	for _, field := range structType.Fields.List {
   216  		var dosaTag string
   217  		if field.Tag != nil {
   218  			entityTag := reflect.StructTag(strings.Trim(field.Tag.Value, "`"))
   219  			dosaTag = strings.TrimSpace(entityTag.Get(dosaTagKey))
   220  		}
   221  		if dosaTag == "-" { // skip explicitly ignored fields
   222  			continue
   223  		}
   224  
   225  		kind, err := parseASTType(field.Type)
   226  		if err != nil {
   227  			return nil, err
   228  		}
   229  
   230  		if kind == packagePrefix+"."+entityName || (packagePrefix == "" && kind == entityName) {
   231  			var err error
   232  			if t.EntityDefinition.Name, t.Key, err = parseEntityTag(structName, dosaTag); err != nil {
   233  				return nil, err
   234  			}
   235  		} else {
   236  			for _, fieldName := range field.Names {
   237  				name := fieldName.Name
   238  				if kind == packagePrefix+"."+indexName || (packagePrefix == "" && kind == indexName) {
   239  					indexName, indexKey, err := parseIndexTag(name, dosaTag)
   240  					if err != nil {
   241  						return nil, err
   242  					}
   243  					if _, exist := t.Indexes[indexName]; exist {
   244  						return nil, errors.Errorf("index name is duplicated: %s", indexName)
   245  					}
   246  					t.Indexes[indexName] = &IndexDefinition{Key: indexKey}
   247  				} else {
   248  					firstRune, _ := utf8.DecodeRuneInString(name)
   249  					if unicode.IsLower(firstRune) {
   250  						// skip unexported fields
   251  						continue
   252  					}
   253  					typ, isPointer := stringToDosaType(kind, packagePrefix)
   254  					if typ == Invalid {
   255  						return nil, fmt.Errorf("Column %q has invalid type %q", name, kind)
   256  					}
   257  					cd, err := parseField(typ, isPointer, name, dosaTag)
   258  					if err != nil {
   259  						return nil, errors.Wrapf(err, "column %q", name)
   260  					}
   261  					t.Columns = append(t.Columns, cd)
   262  					t.ColToField[cd.Name] = name
   263  					t.FieldToCol[name] = cd.Name
   264  				}
   265  			}
   266  
   267  			if len(field.Names) == 0 {
   268  				if kind == packagePrefix+"."+indexName || (packagePrefix == "" && kind == indexName) {
   269  					indexName, indexKey, err := parseIndexTag("", dosaTag)
   270  					if err != nil {
   271  						return nil, err
   272  					}
   273  					if _, exist := t.Indexes[indexName]; exist {
   274  						return nil, errors.Errorf("index name is duplicated: %s", indexName)
   275  					}
   276  					t.Indexes[indexName] = &IndexDefinition{Key: indexKey}
   277  				}
   278  			}
   279  		}
   280  	}
   281  
   282  	if t.Key == nil {
   283  		return nil, errors.Errorf("cannot find dosa.Entity in object %s", t.StructName)
   284  	}
   285  
   286  	translateKeyName(t)
   287  	if err := t.EnsureValid(); err != nil {
   288  		return nil, errors.Wrap(err, "failed to parse dosa object")
   289  	}
   290  	return t, nil
   291  }
   292  
   293  func stringToDosaType(inType, pkg string) (Type, bool) {
   294  
   295  	// Append a dot if the package suffix doesn't already have one.
   296  	if pkg != "" && !strings.HasSuffix(pkg, ".") {
   297  		pkg += "."
   298  	}
   299  
   300  	switch inType {
   301  	case "string":
   302  		return String, false
   303  	case "[]byte":
   304  		return Blob, false
   305  	case "bool":
   306  		return Bool, false
   307  	case "int32":
   308  		return Int32, false
   309  	case "int64":
   310  		return Int64, false
   311  	case "float64":
   312  		return Double, false
   313  	case "time.Time":
   314  		return Timestamp, false
   315  	case "UUID", pkg + "UUID":
   316  		return TUUID, false
   317  	case "*string":
   318  		return String, true
   319  	case "*bool":
   320  		return Bool, true
   321  	case "*int32":
   322  		return Int32, true
   323  	case "*int64":
   324  		return Int64, true
   325  	case "*float64":
   326  		return Double, true
   327  	case "*time.Time":
   328  		return Timestamp, true
   329  	case "*UUID", "*" + pkg + "UUID":
   330  		return TUUID, true
   331  	default:
   332  		return Invalid, false
   333  	}
   334  }