// Source: src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/parse/check_ast_test.go

package parse

import (
	"fmt"
	"reflect"
	"strings"
	"unicode"
	"unicode/utf8"
)

// AST checking utilities. Used in test cases.

// ast is an AST specification. The name part identifies the type of the Node;
// for instance, "Chunk" specifies a Chunk. The fields part specifies the
// fields to check; see the documentation of fs.
//
// When a Node contains exactly one child, it can be coalesced with its child
// by adding "/ChildName" to the name part. For instance, "Chunk/Pipeline"
// specifies a Chunk that contains exactly one Pipeline. In this case, the
// fields part specifies the fields of the Pipeline instead of the Chunk
// (which has no additional interesting fields anyway). Multi-level coalescence
// like "Chunk/Pipeline/Form" is also allowed.
//
// The dynamic type of the Node being checked is assumed to be a pointer to a
// struct that embeds the "node" struct.
type ast struct {
	name   string // type name(s) of the Node, "/"-separated when coalesced
	fields fs     // fields of the last (innermost) named Node to check
}

// fs specifies fields of a Node to check. For the value of field $f in the
// Node ("found value"), fs[$f] ("wanted value") is used to check against it.
//
// If the key is "text", the SourceText of the Node is checked. It doesn't
// involve a found value.
//
// If the wanted value is nil, the found value is checked against nil.
//
// If the found value implements Node, then the wanted value must be either an
// ast, where the checking algorithm of ast applies, or a string, where the
// source text of the found value is checked.
//
// If the found value is a slice whose elements implement Node, then the wanted
// value must be a slice where checking is then done recursively.
//
// If the found value satisfies none of the above conditions, it is checked
// against the wanted value using reflect.DeepEqual.
48 type fs map[string]any 49 50 // checkAST checks an AST against a specification. 51 func checkAST(n Node, want ast) error { 52 wantnames := strings.Split(want.name, "/") 53 // Check coalesced levels 54 for i, wantname := range wantnames { 55 name := reflect.TypeOf(n).Elem().Name() 56 if wantname != name { 57 return fmt.Errorf("want %s, got %s (%s)", wantname, name, summary(n)) 58 } 59 if i == len(wantnames)-1 { 60 break 61 } 62 fields := Children(n) 63 if len(fields) != 1 { 64 return fmt.Errorf("want exactly 1 child, got %d (%s)", len(fields), summary(n)) 65 } 66 n = fields[0] 67 } 68 69 ntype := reflect.TypeOf(n).Elem() 70 nvalue := reflect.ValueOf(n).Elem() 71 72 for i := 0; i < ntype.NumField(); i++ { 73 fieldname := ntype.Field(i).Name 74 if !exported(fieldname) { 75 // Unexported field 76 continue 77 } 78 got := nvalue.Field(i).Interface() 79 want, ok := want.fields[fieldname] 80 if ok { 81 err := checkField(got, want, "field "+fieldname+" of: "+summary(n)) 82 if err != nil { 83 return err 84 } 85 } else { 86 // Not specified. Check if got is a zero value of its type. 87 zero := reflect.Zero(reflect.TypeOf(got)).Interface() 88 if !reflect.DeepEqual(got, zero) { 89 return fmt.Errorf("want %v, got %v (field %s of: %s)", zero, got, fieldname, summary(n)) 90 } 91 } 92 } 93 94 return nil 95 } 96 97 // checkField checks a field against a field specification. 98 func checkField(got any, want any, ctx string) error { 99 // Want nil. 100 if want == nil { 101 if !reflect.ValueOf(got).IsNil() { 102 return fmt.Errorf("want nil, got %v (%s)", got, ctx) 103 } 104 return nil 105 } 106 107 if got, ok := got.(Node); ok { 108 // Got a Node. 109 return checkNodeInField(got, want) 110 } 111 tgot := reflect.TypeOf(got) 112 if tgot.Kind() == reflect.Slice && tgot.Elem().Implements(nodeType) { 113 // Got a slice of Nodes. 
114 vgot := reflect.ValueOf(got) 115 vwant := reflect.ValueOf(want) 116 if vgot.Len() != vwant.Len() { 117 return fmt.Errorf("want %d, got %d (%s)", vwant.Len(), vgot.Len(), ctx) 118 } 119 for i := 0; i < vgot.Len(); i++ { 120 err := checkNodeInField(vgot.Index(i).Interface().(Node), 121 vwant.Index(i).Interface()) 122 if err != nil { 123 return err 124 } 125 } 126 return nil 127 } 128 129 if !reflect.DeepEqual(want, got) { 130 return fmt.Errorf("want %v, got %v (%s)", want, got, ctx) 131 } 132 return nil 133 } 134 135 func checkNodeInField(got Node, want any) error { 136 switch want := want.(type) { 137 case string: 138 text := SourceText(got) 139 if want != text { 140 return fmt.Errorf("want %q, got %q (%s)", want, text, summary(got)) 141 } 142 return nil 143 case ast: 144 return checkAST(got, want) 145 default: 146 panic(fmt.Sprintf("bad want type %T (%s)", want, summary(got))) 147 } 148 } 149 150 func exported(name string) bool { 151 r, _ := utf8.DecodeRuneInString(name) 152 return unicode.IsUpper(r) 153 }