github.com/hugelgupf/u-root@v0.0.0-20191023214958-4807c632154c/cmds/core/elvish/parse/check_ast_test.go (about)

     1  package parse
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  	"unicode"
     8  	"unicode/utf8"
     9  )
    10  
    11  // AST checking utilities. Used in test cases.
    12  
    13  // ast is an AST specification. The name part identifies the type of the Node;
    14  // for instance, "Chunk" specifies a Chunk. The fields part is specifies children
    15  // to check; see document of fs.
    16  //
    17  // When a Node contains exactly one child, It can be coalesced with its child
    18  // by adding "/ChildName" in the name part. For instance, "Chunk/Pipeline"
    19  // specifies a Chunk that contains exactly one Pipeline. In this case, the
    20  // fields part specified the children of the Pipeline instead of the Chunk
    21  // (which has no additional interesting fields anyway). Multi-level coalescence
    22  // like "Chunk/Pipeline/Form" is also allowed.
    23  //
    24  // The dynamic type of the Node being checked is assumed to be a pointer to a
    25  // struct that embeds the "node" struct.
    26  type ast struct {
    27  	name   string
    28  	fields fs
    29  }
    30  
    31  // fs specifies fields of a Node to check. For the value of field $f in the
    32  // Node ("found value"), fs[$f] ("wanted value") is used to check against it.
    33  //
    34  // If the key is "text", the SourceText of the Node is checked. It doesn't
    35  // involve a found value.
    36  //
    37  // If the wanted value is nil, the found value is checked against nil.
    38  //
    39  // If the found value implements Node, then the wanted value must be either an
    40  // ast, where the checking algorithm of ast applies, or a string, where the
    41  // source text of the found value is checked.
    42  //
    43  // If the found value is a slice whose elements implement Node, then the wanted
    44  // value must be a slice where checking is then done recursively.
    45  //
    46  // If the found value satisfied none of the above conditions, it is checked
    47  // against the wanted value using reflect.DeepEqual.
    48  type fs map[string]interface{}
    49  
    50  // checkAST checks an AST against a specification.
    51  func checkAST(n Node, want ast) error {
    52  	wantnames := strings.Split(want.name, "/")
    53  	// Check coalesced levels
    54  	for i, wantname := range wantnames {
    55  		name := reflect.TypeOf(n).Elem().Name()
    56  		if wantname != name {
    57  			return fmt.Errorf("want %s, got %s (%s)", wantname, name, summary(n))
    58  		}
    59  		if i == len(wantnames)-1 {
    60  			break
    61  		}
    62  		fields := n.Children()
    63  		if len(fields) != 1 {
    64  			return fmt.Errorf("want exactly 1 child, got %d (%s)", len(fields), summary(n))
    65  		}
    66  		n = fields[0]
    67  	}
    68  
    69  	ntype := reflect.TypeOf(n).Elem()
    70  	nvalue := reflect.ValueOf(n).Elem()
    71  
    72  	for i := 0; i < ntype.NumField(); i++ {
    73  		fieldname := ntype.Field(i).Name
    74  		if !exported(fieldname) {
    75  			// Unexported field
    76  			continue
    77  		}
    78  		got := nvalue.Field(i).Interface()
    79  		want, ok := want.fields[fieldname]
    80  		if ok {
    81  			err := checkField(got, want, "field "+fieldname+" of: "+summary(n))
    82  			if err != nil {
    83  				return err
    84  			}
    85  		} else {
    86  			// Not specified. Check if got is a zero value of its type.
    87  			if !reflect.DeepEqual(got, reflect.Zero(reflect.TypeOf(got)).Interface()) {
    88  				return fmt.Errorf("want zero, got %v (field %s of: %s)", got, fieldname, summary(n))
    89  			}
    90  		}
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  var nodeType = reflect.TypeOf((*Node)(nil)).Elem()
    97  
    98  // checkField checks a field against a field specification.
    99  func checkField(got interface{}, want interface{}, ctx string) error {
   100  	// Want nil.
   101  	if want == nil {
   102  		if !reflect.ValueOf(got).IsNil() {
   103  			return fmt.Errorf("want nil, got %v (%s)", got, ctx)
   104  		}
   105  		return nil
   106  	}
   107  
   108  	if got, ok := got.(Node); ok {
   109  		// Got a Node.
   110  		return checkNodeInField(got.(Node), want)
   111  	}
   112  	tgot := reflect.TypeOf(got)
   113  	if tgot.Kind() == reflect.Slice && tgot.Elem().Implements(nodeType) {
   114  		// Got a slice of Nodes.
   115  		vgot := reflect.ValueOf(got)
   116  		vwant := reflect.ValueOf(want)
   117  		if vgot.Len() != vwant.Len() {
   118  			return fmt.Errorf("want %d, got %d (%s)", vwant.Len(), vgot.Len(), ctx)
   119  		}
   120  		for i := 0; i < vgot.Len(); i++ {
   121  			err := checkNodeInField(vgot.Index(i).Interface().(Node),
   122  				vwant.Index(i).Interface())
   123  			if err != nil {
   124  				return err
   125  			}
   126  		}
   127  		return nil
   128  	}
   129  
   130  	if !reflect.DeepEqual(want, got) {
   131  		return fmt.Errorf("want %v, got %v (%s)", want, got, ctx)
   132  	}
   133  	return nil
   134  }
   135  
   136  func checkNodeInField(got Node, want interface{}) error {
   137  	switch want := want.(type) {
   138  	case string:
   139  		text := got.SourceText()
   140  		if want != text {
   141  			return fmt.Errorf("want %q, got %q (%s)", want, text, summary(got))
   142  		}
   143  		return nil
   144  	case ast:
   145  		return checkAST(got, want)
   146  	default:
   147  		panic(fmt.Sprintf("bad want type %T (%s)", want, summary(got)))
   148  	}
   149  }
   150  
   151  func exported(name string) bool {
   152  	r, _ := utf8.DecodeRuneInString(name)
   153  	return unicode.IsUpper(r)
   154  }