
     1  /*
     2  Copyright 2017-2018 Mirantis
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    17  // Some changes are merged back from
    18  //
    20  package ginkgoext
    22  import (
    23  	"bytes"
    24  	"flag"
    25  	"os"
    26  	"reflect"
    27  	"regexp"
    28  	"strings"
    29  	"sync/atomic"
    31  	""
    32  	""
    33  )
    35  type scope struct {
    36  	parent        *scope
    37  	children      []*scope
    38  	counter       int32
    39  	before        []func()
    40  	after         []func()
    41  	afterEach     []func()
    42  	justAfterEach []func()
    43  	afterFail     []func()
    44  	started       int32
    45  	failed        bool
    46  	normalTests   int
    47  	focusedTests  int
    48  	focused       bool
    49  }
    51  var (
    52  	currentScope        = &scope{}
    53  	rootScope           = currentScope
    54  	countersInitialized bool
    56  	Context                               = wrapContextFunc(ginkgo.Context, false)
    57  	FContext                              = wrapContextFunc(ginkgo.FContext, true)
    58  	PContext                              = wrapNilContextFunc(ginkgo.PContext)
    59  	XContext                              = wrapNilContextFunc(ginkgo.XContext)
    60  	Describe                              = wrapContextFunc(ginkgo.Describe, false)
    61  	FDescribe                             = wrapContextFunc(ginkgo.FDescribe, true)
    62  	PDescribe                             = wrapNilContextFunc(ginkgo.PDescribe)
    63  	XDescribe                             = wrapNilContextFunc(ginkgo.XDescribe)
    64  	It                                    = wrapItFunc(ginkgo.It, false)
    65  	FIt                                   = wrapItFunc(ginkgo.FIt, true)
    66  	PIt                                   = ginkgo.PIt
    67  	XIt                                   = ginkgo.XIt
    68  	By                                    = ginkgo.By
    69  	JustBeforeEach                        = ginkgo.JustBeforeEach
    70  	BeforeSuite                           = ginkgo.BeforeSuite
    71  	AfterSuite                            = ginkgo.AfterSuite
    72  	Skip                                  = ginkgo.Skip
    73  	Fail                                  = ginkgo.Fail
    74  	CurrentGinkgoTestDescription          = ginkgo.CurrentGinkgoTestDescription
    75  	GinkgoRecover                         = ginkgo.GinkgoRecover
    76  	GinkgoT                               = ginkgo.GinkgoT
    77  	RunSpecs                              = ginkgo.RunSpecs
    78  	RunSpecsWithCustomReporters           = ginkgo.RunSpecsWithCustomReporters
    79  	RunSpecsWithDefaultAndCustomReporters = ginkgo.RunSpecsWithDefaultAndCustomReporters
    80  )
    82  type Done ginkgo.Done
    84  func init() {
    85  	// Only use the Ginkgo options and discard all other options
    86  	args := []string{}
    87  	for _, arg := range os.Args[1:] {
    88  		if strings.Contains(arg, "-ginkgo") {
    89  			args = append(args, arg)
    90  		}
    91  	}
    93  	//Get GinkgoConfig flags
    94  	commandFlags := flag.NewFlagSet("ginkgo", flag.ContinueOnError)
    95  	commandFlags.SetOutput(new(bytes.Buffer))
    97  	config.Flags(commandFlags, "ginkgo", true)
    98  	commandFlags.Parse(args)
    99  }
   101  // BeforeAll runs the function once before any test in context
   102  func BeforeAll(body func()) bool {
   103  	if currentScope != nil {
   104  		if body == nil {
   105  			currentScope.before = nil
   106  			return true
   107  		}
   108  		currentScope.before = append(currentScope.before, body)
   109  		return BeforeEach(func() {})
   110  	}
   111  	return true
   112  }
   114  // AfterAll runs the function once after any test in context
   115  func AfterAll(body func()) bool {
   116  	if currentScope != nil {
   117  		if body == nil {
   118  			currentScope.before = nil
   119  			return true
   120  		}
   121  		currentScope.after = append(currentScope.after, body)
   122  		return AfterEach(func() {})
   123  	}
   124  	return true
   125  }
   127  // JustAfterEach runs the function just after each test, before all AfterEach,
   128  // AfterFailed and AfterAll
   129  func JustAfterEach(body func()) bool {
   130  	if currentScope != nil {
   131  		if body == nil {
   132  			currentScope.before = nil
   133  			return true
   134  		}
   135  		currentScope.justAfterEach = append(currentScope.justAfterEach, body)
   136  		return AfterEach(func() {})
   137  	}
   138  	return true
   139  }
   141  // JustAfterFailed runs the function after test and JustAfterEach if the test
   142  // has failed and before all AfterEach
   143  func AfterFailed(body func()) bool {
   144  	if currentScope != nil {
   145  		if body == nil {
   146  			currentScope.before = nil
   147  			return true
   148  		}
   149  		currentScope.afterFail = append(currentScope.afterFail, body)
   150  		return AfterEach(func() {})
   151  	}
   152  	return true
   153  }
   155  // justAfterEachStatus map to store what `justAfterEach` functions have been
   156  // already executed for the given test
   157  var justAfterEachStatus map[string]bool = map[string]bool{}
   159  // runAllJustAfterEach runs all the `scope.justAfterEach` functions for the
   160  // given scope and parent scopes. This function make sure that all the
   161  // `JustAfterEach` functions are called before AfterEach functions.
   162  func runAllJustAfterEach(cs *scope, testName string) {
   163  	if _, ok := justAfterEachStatus[testName]; ok {
   164  		// JustAfterEach calls are already executed in the children
   165  		return
   166  	}
   168  	for _, body := range cs.justAfterEach {
   169  		body()
   170  	}
   172  	if cs.parent != nil {
   173  		runAllJustAfterEach(cs.parent, testName)
   174  	}
   175  }
   177  // afterFailedStatus map to store what `AfterFail` functions have been
   178  // already executed for the given test.
   179  var afterFailedStatus map[string]bool = map[string]bool{}
   181  // runAllAfterFail runs all the afterFail functions for the given
   182  // scope and parent scopes. This function make sure that all the `AfterFail`
   183  // functions are called before AfterEach.
   184  func runAllAfterFail(cs *scope, testName string) {
   185  	if _, ok := afterFailedStatus[testName]; ok {
   186  		// AfterFailcalls are already executed in the children
   187  		return
   188  	}
   190  	for _, body := range cs.afterFail {
   191  		if ginkgo.CurrentGinkgoTestDescription().Failed {
   192  			body()
   193  		}
   194  	}
   196  	if cs.parent != nil {
   197  		runAllAfterFail(cs.parent, testName)
   198  	}
   199  }
   201  // RunAfterEach is a wrapper that executes all AfterEach functions that are
   202  // stored in cs.afterEach array.
   203  func RunAfterEach(cs *scope) {
   204  	if cs == nil {
   205  		return
   206  	}
   207  	testName := ginkgo.CurrentGinkgoTestDescription().FullTestText
   208  	runAllJustAfterEach(cs, testName)
   209  	justAfterEachStatus[testName] = true
   211  	runAllAfterFail(cs, testName)
   212  	afterFailedStatus[testName] = true
   214  	for _, body := range cs.afterEach {
   215  		body()
   216  	}
   218  	// Only run afterAll when all the counters are 0 and all afterEach are executed
   219  	after := func() {
   220  		if cs.counter == 0 && cs.after != nil {
   221  			for _, after := range cs.after {
   222  				after()
   223  			}
   224  		}
   225  	}
   226  	after()
   227  }
   229  // AfterEach runs the function after each test in context
   230  func AfterEach(body func(), timeout ...float64) bool {
   231  	if currentScope == nil {
   232  		return ginkgo.AfterEach(body, timeout...)
   233  	}
   234  	cs := currentScope
   235  	result := true
   236  	if cs.afterEach == nil {
   237  		// If no scope, register only one AfterEach in the scope, after that
   238  		// RunAfterEeach will run all afterEach functions registered in the
   239  		// scope.
   240  		fn := func() {
   241  			RunAfterEach(cs)
   242  		}
   243  		result = ginkgo.AfterEach(fn, timeout...)
   244  	}
   245  	cs.afterEach = append(cs.afterEach, body)
   246  	return result
   247  }
   249  // BeforeEach runs the function before each test in context
   250  func BeforeEach(body interface{}, timeout ...float64) bool {
   251  	if currentScope == nil {
   252  		return ginkgo.BeforeEach(body, timeout...)
   253  	}
   254  	cs := currentScope
   255  	before := func() {
   256  		if atomic.CompareAndSwapInt32(&cs.started, 0, 1) && cs.before != nil {
   257  			defer func() {
   258  				if r := recover(); r != nil {
   259  					cs.failed = true
   260  					panic(r)
   261  				}
   262  			}()
   263  			for _, before := range cs.before {
   264  				before()
   265  			}
   266  		} else if cs.failed {
   267  			Fail("failed due to BeforeAll failure")
   268  		}
   269  	}
   270  	return ginkgo.BeforeEach(applyAdvice(body, before, nil), timeout...)
   271  }
   273  func wrapContextFunc(fn func(string, func()) bool, focused bool) func(string, func()) bool {
   274  	return func(text string, body func()) bool {
   275  		if currentScope == nil {
   276  			return fn(text, body)
   277  		}
   278  		newScope := &scope{parent: currentScope, focused: focused}
   279  		currentScope.children = append(currentScope.children, newScope)
   280  		currentScope = newScope
   281  		res := fn(text, body)
   282  		currentScope = currentScope.parent
   283  		return res
   284  	}
   285  }
   287  func wrapNilContextFunc(fn func(string, func()) bool) func(string, func()) bool {
   288  	return func(text string, body func()) bool {
   289  		oldScope := currentScope
   290  		currentScope = nil
   291  		res := fn(text, body)
   292  		currentScope = oldScope
   293  		return res
   294  	}
   295  }
   297  func wrapItFunc(fn func(string, interface{}, ...float64) bool, focused bool) func(string, interface{}, ...float64) bool {
   298  	if !countersInitialized {
   299  		countersInitialized = true
   300  		BeforeSuite(func() {
   301  			calculateCounters(rootScope, false)
   302  		})
   303  	}
   304  	return func(text string, body interface{}, timeout ...float64) bool {
   305  		if currentScope == nil {
   306  			return fn(text, body, timeout...)
   307  		}
   308  		if focused || isTestFocused(text) {
   309  			currentScope.focusedTests++
   310  		} else {
   311  			currentScope.normalTests++
   312  		}
   313  		return fn(text, wrapTest(body), timeout...)
   314  	}
   315  }
   317  // isTestFocused checks the value of FocusString and return true if the given
   318  // text name is focussed, returns false if the test is not focussed.
   319  func isTestFocused(text string) bool {
   320  	if config.GinkgoConfig.FocusString == "" && config.GinkgoConfig.SkipString == "" {
   321  		return false
   322  	}
   324  	skipFilter := regexp.MustCompile(config.GinkgoConfig.SkipString)
   325  	if skipFilter.Match([]byte(text)) {
   326  		return false
   327  	}
   329  	focusFilter := regexp.MustCompile(config.GinkgoConfig.FocusString)
   330  	return focusFilter.Match([]byte(text))
   331  }
   333  func applyAdvice(f interface{}, before, after func()) interface{} {
   334  	fn := reflect.ValueOf(f)
   335  	template := func(in []reflect.Value) []reflect.Value {
   336  		if before != nil {
   337  			before()
   338  		}
   339  		if after != nil {
   340  			defer after()
   341  		}
   342  		return fn.Call(in)
   343  	}
   344  	v := reflect.MakeFunc(fn.Type(), template)
   345  	return v.Interface()
   346  }
   348  func wrapTest(f interface{}) interface{} {
   349  	cs := currentScope
   350  	after := func() {
   351  		for cs != nil {
   352  			atomic.AddInt32(&cs.counter, -1)
   353  			cs = cs.parent
   354  		}
   355  	}
   356  	return applyAdvice(f, nil, after)
   357  }
// calculateCounters computes, recursively from the leaves up, how many tests
// under s are counted, taking focus into account. It stores that number in
// s.counter; RunAfterEach runs AfterAll when the counter reaches zero. It
// returns the count and whether any focused test exists in s or its
// descendants.
//
// focusedOnly is true when the caller has already found a focused test
// outside s; in that case the normal tests of s are not counted. Child scopes
// declared with FContext/FDescribe (child.focused) are always evaluated with
// focusedOnly=false, and their whole count is added as focused.
func calculateCounters(s *scope, focusedOnly bool) (int, bool) {
	count := s.focusedTests
	haveFocused := s.focusedTests > 0
	// Focused child scopes: every test they count is treated as focused.
	var focusedChildren int
	for _, child := range s.children {
		if child.focused {
			c, _ := calculateCounters(child, false)
			focusedChildren += c
		}
	}
	if focusedChildren > 0 {
		haveFocused = true
		count += focusedChildren
	}
	// Unfocused child scopes: add their count directly if they contain
	// focused tests; otherwise keep it aside as a normal-test count.
	var normalChildren int
	for _, child := range s.children {
		if !child.focused {
			c, f := calculateCounters(child, focusedOnly || haveFocused)
			if f {
				haveFocused = true
				count += c
			} else {
				normalChildren += c
			}
		}
	}
	// Normal tests are only counted when no focused test is known.
	if !focusedOnly && !haveFocused {
		count += s.normalTests + normalChildren
	}
	s.counter = int32(count)
	return count, haveFocused
}