github.com/brycereitano/goa@v0.0.0-20170315073847-8ffa6c85e265/dslengine/runner.go (about)

     1  package dslengine
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"path/filepath"
     7  	"reflect"
     8  	"runtime"
     9  	"strings"
    10  )
    11  
    12  var (
    13  	// Errors contains the DSL execution errors if any.
    14  	Errors MultiError
    15  
    16  	// Global DSL evaluation stack
    17  	ctxStack contextStack
    18  
    19  	// Registered DSL roots
    20  	roots []Root
    21  
    22  	// DSL package paths used to compute error locations (skip the frames in these packages)
    23  	dslPackages map[string]bool
    24  )
    25  
    26  type (
    27  	// Error represents an error that occurred while running the API DSL.
    28  	// It contains the name of the file and line number of where the error
    29  	// occurred as well as the original Go error.
    30  	Error struct {
    31  		GoError error
    32  		File    string
    33  		Line    int
    34  	}
    35  
    36  	// MultiError collects all DSL errors. It implements error.
    37  	MultiError []*Error
    38  
    39  	// DSL evaluation contexts stack
    40  	contextStack []Definition
    41  )
    42  
    43  func init() {
    44  	dslPackages = map[string]bool{
    45  		"github.com/goadesign/goa/":            true,
    46  		"github.com/goadesign/goa/middleware/": true,
    47  		"github.com/goadesign/goa/encoding/":   true,
    48  		"github.com/goadesign/goa/logging/":    true,
    49  	}
    50  }
    51  
    52  // Register adds a DSL Root to be executed by Run.
    53  func Register(r Root) {
    54  	for _, o := range roots {
    55  		if r.DSLName() == o.DSLName() {
    56  			fmt.Fprintf(os.Stderr, "goagen: duplicate DSL %s", r.DSLName())
    57  			os.Exit(1)
    58  		}
    59  	}
    60  	t := reflect.TypeOf(r)
    61  	if t.Kind() == reflect.Ptr {
    62  		t = t.Elem()
    63  	}
    64  	dslPackages[t.PkgPath()] = true
    65  	roots = append(roots, r)
    66  }
    67  
    68  // Reset uses the registered RootFuncs to re-initialize the DSL roots.
    69  // This is useful to tests.
    70  func Reset() {
    71  	for _, r := range roots {
    72  		r.Reset()
    73  	}
    74  	Errors = nil
    75  }
    76  
    77  // Run runs the given root definitions. It iterates over the definition sets
    78  // multiple times to first execute the DSL, the validate the resulting
    79  // definitions and finally finalize them. The executed DSL may register new
    80  // roots to have them be executed (last) in the same run.
    81  func Run() error {
    82  	if len(roots) == 0 {
    83  		return nil
    84  	}
    85  	roots, err := SortRoots()
    86  	if err != nil {
    87  		return err
    88  	}
    89  	Errors = nil
    90  	executed := 0
    91  	recursed := 0
    92  	for executed < len(roots) {
    93  		recursed++
    94  		start := executed
    95  		executed = len(roots)
    96  		for _, root := range roots[start:] {
    97  			root.IterateSets(runSet)
    98  		}
    99  		if recursed > 100 {
   100  			// Let's cross that bridge once we get there
   101  			return fmt.Errorf("too many generated roots, infinite loop?")
   102  		}
   103  	}
   104  	if Errors != nil {
   105  		return Errors
   106  	}
   107  	for _, root := range roots {
   108  		root.IterateSets(validateSet)
   109  	}
   110  	if Errors != nil {
   111  		return Errors
   112  	}
   113  	for _, root := range roots {
   114  		root.IterateSets(finalizeSet)
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  // Execute runs the given DSL to initialize the given definition. It returns true on success.
   121  // It returns false and appends to Errors on failure.
   122  // Note that `Run` takes care of calling `Execute` on all definitions that implement Source.
   123  // This function is intended for use by definitions that run the DSL at declaration time rather than
   124  // store the DSL for execution by the dsl engine (usually simple independent definitions).
   125  // The DSL should use ReportError to record DSL execution errors.
   126  func Execute(dsl func(), def Definition) bool {
   127  	if dsl == nil {
   128  		return true
   129  	}
   130  	initCount := len(Errors)
   131  	ctxStack = append(ctxStack, def)
   132  	dsl()
   133  	ctxStack = ctxStack[:len(ctxStack)-1]
   134  	return len(Errors) <= initCount
   135  }
   136  
   137  // CurrentDefinition returns the definition whose initialization DSL is currently being executed.
   138  func CurrentDefinition() Definition {
   139  	current := ctxStack.Current()
   140  	if current == nil {
   141  		return &TopLevelDefinition{}
   142  	}
   143  	return current
   144  }
   145  
   146  // IsTopLevelDefinition returns true if the currently evaluated DSL is a root
   147  // DSL (i.e. is not being run in the context of another definition).
   148  func IsTopLevelDefinition() bool {
   149  	_, ok := CurrentDefinition().(*TopLevelDefinition)
   150  	return ok
   151  }
   152  
   153  // TopLevelDefinition represents the top-level file definitions, done
   154  // with `var _ = `.  An instance of this object is returned by
   155  // `CurrentDefinition()` when at the top-level.
   156  type TopLevelDefinition struct{}
   157  
   158  // Context tells the DSL engine which context we're in when showing
   159  // errors.
   160  func (t *TopLevelDefinition) Context() string { return "top-level" }
   161  
   162  // ReportError records a DSL error for reporting post DSL execution.
   163  func ReportError(fm string, vals ...interface{}) {
   164  	var suffix string
   165  	if cur := ctxStack.Current(); cur != nil {
   166  		if ctx := cur.Context(); ctx != "" {
   167  			suffix = fmt.Sprintf(" in %s", ctx)
   168  		}
   169  	} else {
   170  		suffix = " (top level)"
   171  	}
   172  	err := fmt.Errorf(fm+suffix, vals...)
   173  	file, line := computeErrorLocation()
   174  	Errors = append(Errors, &Error{
   175  		GoError: err,
   176  		File:    file,
   177  		Line:    line,
   178  	})
   179  }
   180  
   181  // FailOnError will exit with code 1 if `err != nil`. This function
   182  // will handle properly the MultiError this dslengine provides.
   183  func FailOnError(err error) {
   184  	if merr, ok := err.(MultiError); ok {
   185  		if len(merr) == 0 {
   186  			return
   187  		}
   188  		fmt.Fprintf(os.Stderr, merr.Error())
   189  		os.Exit(1)
   190  	}
   191  	if err != nil {
   192  		fmt.Fprintf(os.Stderr, err.Error())
   193  		os.Exit(1)
   194  	}
   195  }
   196  
   197  // PrintFilesOrFail will print the file list. Use it with a
   198  // generator's `Generate()` function to output the generated list of
   199  // files or quit on error.
   200  func PrintFilesOrFail(files []string, err error) {
   201  	FailOnError(err)
   202  	fmt.Println(strings.Join(files, "\n"))
   203  }
   204  
   205  // IncompatibleDSL should be called by DSL functions when they are
   206  // invoked in an incorrect context (e.g. "Params" in "Resource").
   207  func IncompatibleDSL() {
   208  	elems := strings.Split(caller(), ".")
   209  	ReportError("invalid use of %s", elems[len(elems)-1])
   210  }
   211  
   212  // InvalidArgError records an invalid argument error.
   213  // It is used by DSL functions that take dynamic arguments.
   214  func InvalidArgError(expected string, actual interface{}) {
   215  	ReportError("cannot use %#v (type %s) as type %s",
   216  		actual, reflect.TypeOf(actual), expected)
   217  }
   218  
   219  // Error returns the error message.
   220  func (m MultiError) Error() string {
   221  	msgs := make([]string, len(m))
   222  	for i, de := range m {
   223  		msgs[i] = de.Error()
   224  	}
   225  	return strings.Join(msgs, "\n")
   226  }
   227  
   228  // Error returns the underlying error message.
   229  func (de *Error) Error() string {
   230  	if err := de.GoError; err != nil {
   231  		if de.File == "" {
   232  			return err.Error()
   233  		}
   234  		return fmt.Sprintf("[%s:%d] %s", de.File, de.Line, err.Error())
   235  	}
   236  	return ""
   237  }
   238  
   239  // Current evaluation context, i.e. object being currently built by DSL
   240  func (s contextStack) Current() Definition {
   241  	if len(s) == 0 {
   242  		return nil
   243  	}
   244  	return s[len(s)-1]
   245  }
   246  
   247  // computeErrorLocation implements a heuristic to find the location in the user
   248  // code where the error occurred. It walks back the callstack until the file
   249  // doesn't match "/goa/design/*.go" or one of the DSL package paths.
   250  // When successful it returns the file name and line number, empty string and
   251  // 0 otherwise.
   252  func computeErrorLocation() (file string, line int) {
   253  	skipFunc := func(file string) bool {
   254  		if strings.HasSuffix(file, "_test.go") { // Be nice with tests
   255  			return false
   256  		}
   257  		file = filepath.ToSlash(file)
   258  		for pkg := range dslPackages {
   259  			if strings.Contains(file, pkg) {
   260  				return true
   261  			}
   262  		}
   263  		return false
   264  	}
   265  	depth := 2
   266  	_, file, line, _ = runtime.Caller(depth)
   267  	for skipFunc(file) {
   268  		depth++
   269  		_, file, line, _ = runtime.Caller(depth)
   270  	}
   271  	wd, err := os.Getwd()
   272  	if err != nil {
   273  		return
   274  	}
   275  	wd, err = filepath.Abs(wd)
   276  	if err != nil {
   277  		return
   278  	}
   279  	f, err := filepath.Rel(wd, file)
   280  	if err != nil {
   281  		return
   282  	}
   283  	file = f
   284  	return
   285  }
   286  
   287  // runSet executes the DSL for all definitions in the given set. The definition DSLs may append to
   288  // the set as they execute.
   289  func runSet(set DefinitionSet) error {
   290  	executed := 0
   291  	recursed := 0
   292  	for executed < len(set) {
   293  		recursed++
   294  		for _, def := range set[executed:] {
   295  			executed++
   296  			if source, ok := def.(Source); ok {
   297  				if dsl := source.DSL(); dsl != nil {
   298  					Execute(dsl, source)
   299  				}
   300  			}
   301  		}
   302  		if recursed > 100 {
   303  			return fmt.Errorf("too many generated definitions, infinite loop?")
   304  		}
   305  	}
   306  	return nil
   307  }
   308  
   309  // validateSet runs the validation on all the set definitions that define one.
   310  func validateSet(set DefinitionSet) error {
   311  	errors := &ValidationErrors{}
   312  	for _, def := range set {
   313  		if validate, ok := def.(Validate); ok {
   314  			if err := validate.Validate(); err != nil {
   315  				errors.AddError(def, err)
   316  			}
   317  		}
   318  	}
   319  	err := errors.AsError()
   320  	if err != nil {
   321  		Errors = append(Errors, &Error{GoError: err})
   322  	}
   323  	return err
   324  }
   325  
   326  // finalizeSet runs the validation on all the set definitions that define one.
   327  func finalizeSet(set DefinitionSet) error {
   328  	for _, def := range set {
   329  		if finalize, ok := def.(Finalize); ok {
   330  			finalize.Finalize()
   331  		}
   332  	}
   333  	return nil
   334  }
   335  
   336  // SortRoots orders the DSL roots making sure dependencies are last. It returns an error if there
   337  // is a dependency cycle.
   338  func SortRoots() ([]Root, error) {
   339  	if len(roots) == 0 {
   340  		return nil, nil
   341  	}
   342  	// First flatten dependencies for each root
   343  	rootDeps := make(map[string][]Root, len(roots))
   344  	rootByName := make(map[string]Root, len(roots))
   345  	for _, r := range roots {
   346  		sorted := sortDependencies(r, func(r Root) []Root { return r.DependsOn() })
   347  		length := len(sorted)
   348  		for i := 0; i < length/2; i++ {
   349  			sorted[i], sorted[length-i-1] = sorted[length-i-1], sorted[i]
   350  		}
   351  		rootDeps[r.DSLName()] = sorted
   352  		rootByName[r.DSLName()] = r
   353  	}
   354  	// Now check for cycles
   355  	for name, deps := range rootDeps {
   356  		root := rootByName[name]
   357  		for otherName, otherdeps := range rootDeps {
   358  			other := rootByName[otherName]
   359  			if root.DSLName() == other.DSLName() {
   360  				continue
   361  			}
   362  			dependsOnOther := false
   363  			for _, dep := range deps {
   364  				if dep.DSLName() == other.DSLName() {
   365  					dependsOnOther = true
   366  					break
   367  				}
   368  			}
   369  			if dependsOnOther {
   370  				for _, dep := range otherdeps {
   371  					if dep.DSLName() == root.DSLName() {
   372  						return nil, fmt.Errorf("dependency cycle: %s and %s depend on each other (directly or not)",
   373  							root.DSLName(), other.DSLName())
   374  					}
   375  				}
   376  			}
   377  		}
   378  	}
   379  	// Now sort top level DSLs
   380  	var sorted []Root
   381  	for _, r := range roots {
   382  		s := sortDependencies(r, func(r Root) []Root { return rootDeps[r.DSLName()] })
   383  		for _, s := range s {
   384  			found := false
   385  			for _, r := range sorted {
   386  				if r.DSLName() == s.DSLName() {
   387  					found = true
   388  					break
   389  				}
   390  			}
   391  			if !found {
   392  				sorted = append(sorted, s)
   393  			}
   394  		}
   395  	}
   396  	return sorted, nil
   397  }
   398  
   399  // sortDependencies sorts the depencies of the given root in the given slice.
   400  func sortDependencies(root Root, depFunc func(Root) []Root) []Root {
   401  	seen := make(map[string]bool, len(roots))
   402  	var sorted []Root
   403  	sortDependenciesR(root, seen, &sorted, depFunc)
   404  	return sorted
   405  }
   406  
   407  // sortDependenciesR sorts the depencies of the given root in the given slice.
   408  func sortDependenciesR(root Root, seen map[string]bool, sorted *[]Root, depFunc func(Root) []Root) {
   409  	for _, dep := range depFunc(root) {
   410  		if !seen[dep.DSLName()] {
   411  			seen[root.DSLName()] = true
   412  			sortDependenciesR(dep, seen, sorted, depFunc)
   413  		}
   414  	}
   415  	*sorted = append(*sorted, root)
   416  }
   417  
   418  // caller returns the name of calling function.
   419  func caller() string {
   420  	pc, file, _, ok := runtime.Caller(2)
   421  	if ok && filepath.Base(file) == "current.go" {
   422  		pc, _, _, ok = runtime.Caller(3)
   423  	}
   424  	if !ok {
   425  		return "<unknown>"
   426  	}
   427  
   428  	return runtime.FuncForPC(pc).Name()
   429  }