github.com/onsi/ginkgo@v1.16.6-0.20211118180735-4e1925ba4c95/internal/node.go (about)

     1  package internal
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sort"
     7  
     8  	"sync"
     9  
    10  	"github.com/onsi/ginkgo/types"
    11  )
    12  
    13  var _global_node_id_counter = uint(0)
    14  var _global_id_mutex = &sync.Mutex{}
    15  
    16  func UniqueNodeID() uint {
    17  	//There's a reace in the internal integration tests if we don't make
    18  	//accessing _global_node_id_counter safe across goroutines.
    19  	_global_id_mutex.Lock()
    20  	defer _global_id_mutex.Unlock()
    21  	_global_node_id_counter += 1
    22  	return _global_node_id_counter
    23  }
    24  
    25  type Node struct {
    26  	ID       uint
    27  	NodeType types.NodeType
    28  
    29  	Text         string
    30  	Body         func()
    31  	CodeLocation types.CodeLocation
    32  	NestingLevel int
    33  
    34  	SynchronizedBeforeSuiteProc1Body    func() []byte
    35  	SynchronizedBeforeSuiteAllProcsBody func([]byte)
    36  
    37  	SynchronizedAfterSuiteAllProcsBody func()
    38  	SynchronizedAfterSuiteProc1Body    func()
    39  
    40  	ReportEachBody       func(types.SpecReport)
    41  	ReportAfterSuiteBody func(types.Report)
    42  
    43  	MarkedFocus   bool
    44  	MarkedPending bool
    45  	MarkedSerial  bool
    46  	MarkedOrdered bool
    47  	FlakeAttempts int
    48  	Labels        Labels
    49  
    50  	NodeIDWhereCleanupWasGenerated uint
    51  }
    52  
    53  // Decoration Types
    54  type focusType bool
    55  type pendingType bool
    56  type serialType bool
    57  type orderedType bool
    58  
    59  const Focus = focusType(true)
    60  const Pending = pendingType(true)
    61  const Serial = serialType(true)
    62  const Ordered = orderedType(true)
    63  
    64  type FlakeAttempts uint
    65  type Offset uint
    66  type Done chan<- interface{} // Deprecated Done Channel for asynchronous testing
    67  type Labels []string
    68  
    69  func UnionOfLabels(labels ...Labels) Labels {
    70  	out := Labels{}
    71  	seen := map[string]bool{}
    72  	for _, labelSet := range labels {
    73  		for _, label := range labelSet {
    74  			if !seen[label] {
    75  				seen[label] = true
    76  				out = append(out, label)
    77  			}
    78  		}
    79  	}
    80  	return out
    81  }
    82  
    83  func PartitionDecorations(args ...interface{}) ([]interface{}, []interface{}) {
    84  	decorations := []interface{}{}
    85  	remainingArgs := []interface{}{}
    86  	for _, arg := range args {
    87  		if isDecoration(arg) {
    88  			decorations = append(decorations, arg)
    89  		} else {
    90  			remainingArgs = append(remainingArgs, arg)
    91  		}
    92  	}
    93  	return decorations, remainingArgs
    94  }
    95  
    96  func isDecoration(arg interface{}) bool {
    97  	switch t := reflect.TypeOf(arg); {
    98  	case t == nil:
    99  		return false
   100  	case t == reflect.TypeOf(Offset(0)):
   101  		return true
   102  	case t == reflect.TypeOf(types.CodeLocation{}):
   103  		return true
   104  	case t == reflect.TypeOf(Focus):
   105  		return true
   106  	case t == reflect.TypeOf(Pending):
   107  		return true
   108  	case t == reflect.TypeOf(Serial):
   109  		return true
   110  	case t == reflect.TypeOf(Ordered):
   111  		return true
   112  	case t == reflect.TypeOf(FlakeAttempts(0)):
   113  		return true
   114  	case t == reflect.TypeOf(Labels{}):
   115  		return true
   116  	case t.Kind() == reflect.Slice && isSliceOfDecorations(arg):
   117  		return true
   118  	default:
   119  		return false
   120  	}
   121  }
   122  
   123  func isSliceOfDecorations(slice interface{}) bool {
   124  	vSlice := reflect.ValueOf(slice)
   125  	if vSlice.Len() == 0 {
   126  		return false
   127  	}
   128  	for i := 0; i < vSlice.Len(); i++ {
   129  		if !isDecoration(vSlice.Index(i).Interface()) {
   130  			return false
   131  		}
   132  	}
   133  	return true
   134  }
   135  
   136  func NewNode(deprecationTracker *types.DeprecationTracker, nodeType types.NodeType, text string, args ...interface{}) (Node, []error) {
   137  	baseOffset := 2
   138  	node := Node{
   139  		ID:           UniqueNodeID(),
   140  		NodeType:     nodeType,
   141  		Text:         text,
   142  		Labels:       Labels{},
   143  		CodeLocation: types.NewCodeLocation(baseOffset),
   144  		NestingLevel: -1,
   145  	}
   146  	errors := []error{}
   147  	appendError := func(err error) {
   148  		if err != nil {
   149  			errors = append(errors, err)
   150  		}
   151  	}
   152  
   153  	args = unrollInterfaceSlice(args)
   154  
   155  	remainingArgs := []interface{}{}
   156  	//First get the CodeLocation up-to-date
   157  	for _, arg := range args {
   158  		switch t := reflect.TypeOf(arg); {
   159  		case t == reflect.TypeOf(Offset(0)):
   160  			node.CodeLocation = types.NewCodeLocation(baseOffset + int(arg.(Offset)))
   161  		case t == reflect.TypeOf(types.CodeLocation{}):
   162  			node.CodeLocation = arg.(types.CodeLocation)
   163  		default:
   164  			remainingArgs = append(remainingArgs, arg)
   165  		}
   166  	}
   167  
   168  	labelsSeen := map[string]bool{}
   169  	trackedFunctionError := false
   170  	args = remainingArgs
   171  	remainingArgs = []interface{}{}
   172  	//now process the rest of the args
   173  	for _, arg := range args {
   174  		switch t := reflect.TypeOf(arg); {
   175  		case t == reflect.TypeOf(float64(0)):
   176  			break //ignore deprecated timeouts
   177  		case t == reflect.TypeOf(Focus):
   178  			node.MarkedFocus = bool(arg.(focusType))
   179  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   180  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Focus"))
   181  			}
   182  		case t == reflect.TypeOf(Pending):
   183  			node.MarkedPending = bool(arg.(pendingType))
   184  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   185  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Pending"))
   186  			}
   187  		case t == reflect.TypeOf(Serial):
   188  			node.MarkedSerial = bool(arg.(serialType))
   189  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   190  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Serial"))
   191  			}
   192  		case t == reflect.TypeOf(Ordered):
   193  			node.MarkedOrdered = bool(arg.(orderedType))
   194  			if !nodeType.Is(types.NodeTypeContainer) {
   195  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Ordered"))
   196  			}
   197  		case t == reflect.TypeOf(FlakeAttempts(0)):
   198  			node.FlakeAttempts = int(arg.(FlakeAttempts))
   199  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   200  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "FlakeAttempts"))
   201  			}
   202  		case t == reflect.TypeOf(Labels{}):
   203  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   204  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Label"))
   205  			}
   206  			for _, label := range arg.(Labels) {
   207  				if !labelsSeen[label] {
   208  					labelsSeen[label] = true
   209  					label, err := types.ValidateAndCleanupLabel(label, node.CodeLocation)
   210  					node.Labels = append(node.Labels, label)
   211  					appendError(err)
   212  				}
   213  			}
   214  		case t.Kind() == reflect.Func:
   215  			if node.Body != nil {
   216  				appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   217  				trackedFunctionError = true
   218  				break
   219  			}
   220  			isValid := (t.NumOut() == 0) && (t.NumIn() <= 1) && (t.NumIn() == 0 || t.In(0) == reflect.TypeOf(make(Done)))
   221  			if !isValid {
   222  				appendError(types.GinkgoErrors.InvalidBodyType(t, node.CodeLocation, nodeType))
   223  				trackedFunctionError = true
   224  				break
   225  			}
   226  			if t.NumIn() == 0 {
   227  				node.Body = arg.(func())
   228  			} else {
   229  				deprecationTracker.TrackDeprecation(types.Deprecations.Async(), node.CodeLocation)
   230  				deprecatedAsyncBody := arg.(func(Done))
   231  				node.Body = func() { deprecatedAsyncBody(make(Done)) }
   232  			}
   233  		default:
   234  			remainingArgs = append(remainingArgs, arg)
   235  		}
   236  	}
   237  
   238  	//validations
   239  	if node.MarkedPending && node.MarkedFocus {
   240  		appendError(types.GinkgoErrors.InvalidDeclarationOfFocusedAndPending(node.CodeLocation, nodeType))
   241  	}
   242  
   243  	if node.Body == nil && !node.MarkedPending && !trackedFunctionError {
   244  		appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
   245  	}
   246  	for _, arg := range remainingArgs {
   247  		appendError(types.GinkgoErrors.UnknownDecorator(node.CodeLocation, nodeType, arg))
   248  	}
   249  
   250  	if len(errors) > 0 {
   251  		return Node{}, errors
   252  	}
   253  
   254  	return node, errors
   255  }
   256  
   257  func NewSynchronizedBeforeSuiteNode(proc1Body func() []byte, allProcsBody func([]byte), codeLocation types.CodeLocation) (Node, []error) {
   258  	return Node{
   259  		ID:                                  UniqueNodeID(),
   260  		NodeType:                            types.NodeTypeSynchronizedBeforeSuite,
   261  		SynchronizedBeforeSuiteProc1Body:    proc1Body,
   262  		SynchronizedBeforeSuiteAllProcsBody: allProcsBody,
   263  		CodeLocation:                        codeLocation,
   264  	}, nil
   265  }
   266  
   267  func NewSynchronizedAfterSuiteNode(allProcsBody func(), proc1Body func(), codeLocation types.CodeLocation) (Node, []error) {
   268  	return Node{
   269  		ID:                                 UniqueNodeID(),
   270  		NodeType:                           types.NodeTypeSynchronizedAfterSuite,
   271  		SynchronizedAfterSuiteAllProcsBody: allProcsBody,
   272  		SynchronizedAfterSuiteProc1Body:    proc1Body,
   273  		CodeLocation:                       codeLocation,
   274  	}, nil
   275  }
   276  
   277  func NewReportBeforeEachNode(body func(types.SpecReport), codeLocation types.CodeLocation) (Node, []error) {
   278  	return Node{
   279  		ID:             UniqueNodeID(),
   280  		NodeType:       types.NodeTypeReportBeforeEach,
   281  		ReportEachBody: body,
   282  		CodeLocation:   codeLocation,
   283  		NestingLevel:   -1,
   284  	}, nil
   285  }
   286  
   287  func NewReportAfterEachNode(body func(types.SpecReport), codeLocation types.CodeLocation) (Node, []error) {
   288  	return Node{
   289  		ID:             UniqueNodeID(),
   290  		NodeType:       types.NodeTypeReportAfterEach,
   291  		ReportEachBody: body,
   292  		CodeLocation:   codeLocation,
   293  		NestingLevel:   -1,
   294  	}, nil
   295  }
   296  
   297  func NewReportAfterSuiteNode(text string, body func(types.Report), codeLocation types.CodeLocation) (Node, []error) {
   298  	return Node{
   299  		ID:                   UniqueNodeID(),
   300  		Text:                 text,
   301  		NodeType:             types.NodeTypeReportAfterSuite,
   302  		ReportAfterSuiteBody: body,
   303  		CodeLocation:         codeLocation,
   304  	}, nil
   305  }
   306  
   307  func NewCleanupNode(fail func(string, types.CodeLocation), args ...interface{}) (Node, []error) {
   308  	baseOffset := 2
   309  	node := Node{
   310  		ID:           UniqueNodeID(),
   311  		NodeType:     types.NodeTypeCleanupInvalid,
   312  		CodeLocation: types.NewCodeLocation(baseOffset),
   313  		NestingLevel: -1,
   314  	}
   315  	remainingArgs := []interface{}{}
   316  	for _, arg := range args {
   317  		switch t := reflect.TypeOf(arg); {
   318  		case t == reflect.TypeOf(Offset(0)):
   319  			node.CodeLocation = types.NewCodeLocation(baseOffset + int(arg.(Offset)))
   320  		case t == reflect.TypeOf(types.CodeLocation{}):
   321  			node.CodeLocation = arg.(types.CodeLocation)
   322  		default:
   323  			remainingArgs = append(remainingArgs, arg)
   324  		}
   325  	}
   326  
   327  	if len(remainingArgs) == 0 {
   328  		return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(node.CodeLocation)}
   329  	}
   330  	callback := reflect.ValueOf(remainingArgs[0])
   331  	if !(callback.Kind() == reflect.Func && callback.Type().NumOut() <= 1) {
   332  		return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(node.CodeLocation)}
   333  	}
   334  	callArgs := []reflect.Value{}
   335  	for _, arg := range remainingArgs[1:] {
   336  		callArgs = append(callArgs, reflect.ValueOf(arg))
   337  	}
   338  	cl := node.CodeLocation
   339  	node.Body = func() {
   340  		out := callback.Call(callArgs)
   341  		if len(out) == 1 && !out[0].IsNil() {
   342  			fail(fmt.Sprintf("DeferCleanup callback returned error: %v", out[0]), cl)
   343  		}
   344  	}
   345  
   346  	return node, nil
   347  }
   348  
   349  func (n Node) IsZero() bool {
   350  	return n.ID == 0
   351  }
   352  
   353  /* Nodes */
   354  type Nodes []Node
   355  
   356  func (n Nodes) CopyAppend(nodes ...Node) Nodes {
   357  	numN := len(n)
   358  	out := make(Nodes, numN+len(nodes))
   359  	for i, node := range n {
   360  		out[i] = node
   361  	}
   362  	for j, node := range nodes {
   363  		out[numN+j] = node
   364  	}
   365  	return out
   366  }
   367  
   368  func (n Nodes) SplitAround(pivot Node) (Nodes, Nodes) {
   369  	pivotIdx := len(n)
   370  	for i := range n {
   371  		if n[i].ID == pivot.ID {
   372  			pivotIdx = i
   373  			break
   374  		}
   375  	}
   376  	left := n[:pivotIdx]
   377  	right := Nodes{}
   378  	if pivotIdx+1 < len(n) {
   379  		right = n[pivotIdx+1:]
   380  	}
   381  
   382  	return left, right
   383  }
   384  
   385  func (n Nodes) FirstNodeWithType(nodeTypes types.NodeType) Node {
   386  	for i := range n {
   387  		if n[i].NodeType.Is(nodeTypes) {
   388  			return n[i]
   389  		}
   390  	}
   391  	return Node{}
   392  }
   393  
   394  func (n Nodes) WithType(nodeTypes types.NodeType) Nodes {
   395  	count := 0
   396  	for i := range n {
   397  		if n[i].NodeType.Is(nodeTypes) {
   398  			count++
   399  		}
   400  	}
   401  
   402  	out, j := make(Nodes, count), 0
   403  	for i := range n {
   404  		if n[i].NodeType.Is(nodeTypes) {
   405  			out[j] = n[i]
   406  			j++
   407  		}
   408  	}
   409  	return out
   410  }
   411  
   412  func (n Nodes) WithoutType(nodeTypes types.NodeType) Nodes {
   413  	count := 0
   414  	for i := range n {
   415  		if !n[i].NodeType.Is(nodeTypes) {
   416  			count++
   417  		}
   418  	}
   419  
   420  	out, j := make(Nodes, count), 0
   421  	for i := range n {
   422  		if !n[i].NodeType.Is(nodeTypes) {
   423  			out[j] = n[i]
   424  			j++
   425  		}
   426  	}
   427  	return out
   428  }
   429  
   430  func (n Nodes) WithoutNode(nodeToExclude Node) Nodes {
   431  	idxToExclude := len(n)
   432  	for i := range n {
   433  		if n[i].ID == nodeToExclude.ID {
   434  			idxToExclude = i
   435  			break
   436  		}
   437  	}
   438  	if idxToExclude == len(n) {
   439  		return n
   440  	}
   441  	out, j := make(Nodes, len(n)-1), 0
   442  	for i := range n {
   443  		if i == idxToExclude {
   444  			continue
   445  		}
   446  		out[j] = n[i]
   447  		j++
   448  	}
   449  	return out
   450  }
   451  
   452  func (n Nodes) Filter(filter func(Node) bool) Nodes {
   453  	trufa, count := make([]bool, len(n)), 0
   454  	for i := range n {
   455  		if filter(n[i]) {
   456  			trufa[i] = true
   457  			count += 1
   458  		}
   459  	}
   460  	out, j := make(Nodes, count), 0
   461  	for i := range n {
   462  		if trufa[i] {
   463  			out[j] = n[i]
   464  			j++
   465  		}
   466  	}
   467  	return out
   468  }
   469  
   470  func (n Nodes) WithinNestingLevel(deepestNestingLevel int) Nodes {
   471  	count := 0
   472  	for i := range n {
   473  		if n[i].NestingLevel <= deepestNestingLevel {
   474  			count++
   475  		}
   476  	}
   477  	out, j := make(Nodes, count), 0
   478  	for i := range n {
   479  		if n[i].NestingLevel <= deepestNestingLevel {
   480  			out[j] = n[i]
   481  			j++
   482  		}
   483  	}
   484  	return out
   485  }
   486  
   487  func (n Nodes) SortedByDescendingNestingLevel() Nodes {
   488  	out := make(Nodes, len(n))
   489  	copy(out, n)
   490  	sort.SliceStable(out, func(i int, j int) bool {
   491  		return out[i].NestingLevel > out[j].NestingLevel
   492  	})
   493  
   494  	return out
   495  }
   496  
   497  func (n Nodes) SortedByAscendingNestingLevel() Nodes {
   498  	out := make(Nodes, len(n))
   499  	copy(out, n)
   500  	sort.SliceStable(out, func(i int, j int) bool {
   501  		return out[i].NestingLevel < out[j].NestingLevel
   502  	})
   503  
   504  	return out
   505  }
   506  
   507  func (n Nodes) Reverse() Nodes {
   508  	out := make(Nodes, len(n))
   509  	for i := range n {
   510  		out[len(n)-1-i] = n[i]
   511  	}
   512  	return out
   513  }
   514  
   515  func (n Nodes) Texts() []string {
   516  	out := make([]string, len(n))
   517  	for i := range n {
   518  		out[i] = n[i].Text
   519  	}
   520  	return out
   521  }
   522  
   523  func (n Nodes) Labels() [][]string {
   524  	out := make([][]string, len(n))
   525  	for i := range n {
   526  		if n[i].Labels == nil {
   527  			out[i] = []string{}
   528  		} else {
   529  			out[i] = []string(n[i].Labels)
   530  		}
   531  	}
   532  	return out
   533  }
   534  
   535  func (n Nodes) UnionOfLabels() []string {
   536  	out := []string{}
   537  	seen := map[string]bool{}
   538  	for i := range n {
   539  		for _, label := range n[i].Labels {
   540  			if !seen[label] {
   541  				seen[label] = true
   542  				out = append(out, label)
   543  			}
   544  		}
   545  	}
   546  	return out
   547  }
   548  
   549  func (n Nodes) CodeLocations() []types.CodeLocation {
   550  	out := make([]types.CodeLocation, len(n))
   551  	for i := range n {
   552  		out[i] = n[i].CodeLocation
   553  	}
   554  	return out
   555  }
   556  
   557  func (n Nodes) BestTextFor(node Node) string {
   558  	if node.Text != "" {
   559  		return node.Text
   560  	}
   561  	parentNestingLevel := node.NestingLevel - 1
   562  	for i := range n {
   563  		if n[i].Text != "" && n[i].NestingLevel == parentNestingLevel {
   564  			return n[i].Text
   565  		}
   566  	}
   567  
   568  	return ""
   569  }
   570  
   571  func (n Nodes) ContainsNodeID(id uint) bool {
   572  	for i := range n {
   573  		if n[i].ID == id {
   574  			return true
   575  		}
   576  	}
   577  	return false
   578  }
   579  
   580  func (n Nodes) HasNodeMarkedPending() bool {
   581  	for i := range n {
   582  		if n[i].MarkedPending {
   583  			return true
   584  		}
   585  	}
   586  	return false
   587  }
   588  
   589  func (n Nodes) HasNodeMarkedFocus() bool {
   590  	for i := range n {
   591  		if n[i].MarkedFocus {
   592  			return true
   593  		}
   594  	}
   595  	return false
   596  }
   597  
   598  func (n Nodes) HasNodeMarkedSerial() bool {
   599  	for i := range n {
   600  		if n[i].MarkedSerial {
   601  			return true
   602  		}
   603  	}
   604  	return false
   605  }
   606  
   607  func (n Nodes) FirstNodeMarkedOrdered() Node {
   608  	for i := range n {
   609  		if n[i].MarkedOrdered {
   610  			return n[i]
   611  		}
   612  	}
   613  	return Node{}
   614  }
   615  
   616  func unrollInterfaceSlice(args interface{}) []interface{} {
   617  	v := reflect.ValueOf(args)
   618  	if v.Kind() != reflect.Slice {
   619  		return []interface{}{args}
   620  	}
   621  	out := []interface{}{}
   622  	for i := 0; i < v.Len(); i++ {
   623  		el := reflect.ValueOf(v.Index(i).Interface())
   624  		if el.Kind() == reflect.Slice && el.Type() != reflect.TypeOf(Labels{}) {
   625  			out = append(out, unrollInterfaceSlice(el.Interface())...)
   626  		} else {
   627  			out = append(out, v.Index(i).Interface())
   628  		}
   629  	}
   630  	return out
   631  }