github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/expect/extract.go (about) 1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package expect 6 7 import ( 8 "fmt" 9 "go/ast" 10 "go/parser" 11 "go/token" 12 "path/filepath" 13 "regexp" 14 "strconv" 15 "strings" 16 "text/scanner" 17 18 "golang.org/x/mod/modfile" 19 ) 20 21 const commentStart = "@" 22 const commentStartLen = len(commentStart) 23 24 // Identifier is the type for an identifier in an Note argument list. 25 type Identifier string 26 27 // Parse collects all the notes present in a file. 28 // If content is nil, the filename specified is read and parsed, otherwise the 29 // content is used and the filename is used for positions and error messages. 30 // Each comment whose text starts with @ is parsed as a comma-separated 31 // sequence of notes. 32 // See the package documentation for details about the syntax of those 33 // notes. 34 func Parse(fset *token.FileSet, filename string, content []byte) ([]*Note, error) { 35 var src interface{} 36 if content != nil { 37 src = content 38 } 39 switch filepath.Ext(filename) { 40 case ".go": 41 // TODO: We should write this in terms of the scanner. 42 // there are ways you can break the parser such that it will not add all the 43 // comments to the ast, which may result in files where the tests are silently 44 // not run. 45 file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors) 46 if file == nil { 47 return nil, err 48 } 49 return ExtractGo(fset, file) 50 case ".mod": 51 file, err := modfile.Parse(filename, content, nil) 52 if err != nil { 53 return nil, err 54 } 55 f := fset.AddFile(filename, -1, len(content)) 56 f.SetLinesForContent(content) 57 notes, err := extractMod(fset, file) 58 if err != nil { 59 return nil, err 60 } 61 // Since modfile.Parse does not return an *ast, we need to add the offset 62 // within the file's contents to the file's base relative to the fileset. 63 for _, note := range notes { 64 note.Pos += token.Pos(f.Base()) 65 } 66 return notes, nil 67 } 68 return nil, nil 69 } 70 71 // extractMod collects all the notes present in a go.mod file. 72 // Each comment whose text starts with @ is parsed as a comma-separated 73 // sequence of notes. 74 // See the package documentation for details about the syntax of those 75 // notes. 76 // Only allow notes to appear with the following format: "//@mark()" or // @mark() 77 func extractMod(fset *token.FileSet, file *modfile.File) ([]*Note, error) { 78 var notes []*Note 79 for _, stmt := range file.Syntax.Stmt { 80 comment := stmt.Comment() 81 if comment == nil { 82 continue 83 } 84 // Handle the case for markers of `// indirect` to be on the line before 85 // the require statement. 86 // TODO(golang/go#36894): have a more intuitive approach for // indirect 87 for _, cmt := range comment.Before { 88 text, adjust := getAdjustedNote(cmt.Token) 89 if text == "" { 90 continue 91 } 92 parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text) 93 if err != nil { 94 return nil, err 95 } 96 notes = append(notes, parsed...) 97 } 98 // Handle the normal case for markers on the same line. 99 for _, cmt := range comment.Suffix { 100 text, adjust := getAdjustedNote(cmt.Token) 101 if text == "" { 102 continue 103 } 104 parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text) 105 if err != nil { 106 return nil, err 107 } 108 notes = append(notes, parsed...) 109 } 110 } 111 return notes, nil 112 } 113 114 // ExtractGo collects all the notes present in an AST. 115 // Each comment whose text starts with @ is parsed as a comma-separated 116 // sequence of notes. 117 // See the package documentation for details about the syntax of those 118 // notes. 119 func ExtractGo(fset *token.FileSet, file *ast.File) ([]*Note, error) { 120 var notes []*Note 121 for _, g := range file.Comments { 122 for _, c := range g.List { 123 text, adjust := getAdjustedNote(c.Text) 124 if text == "" { 125 continue 126 } 127 parsed, err := parse(fset, token.Pos(int(c.Pos())+adjust), text) 128 if err != nil { 129 return nil, err 130 } 131 notes = append(notes, parsed...) 132 } 133 } 134 return notes, nil 135 } 136 137 func getAdjustedNote(text string) (string, int) { 138 if strings.HasPrefix(text, "/*") { 139 text = strings.TrimSuffix(text, "*/") 140 } 141 text = text[2:] // remove "//" or "/*" prefix 142 143 // Allow notes to appear within comments. 144 // For example: 145 // "// //@mark()" is valid. 146 // "// @mark()" is not valid. 147 // "// /*@mark()*/" is not valid. 148 var adjust int 149 if i := strings.Index(text, commentStart); i > 2 { 150 // Get the text before the commentStart. 151 pre := text[i-2 : i] 152 if pre != "//" { 153 return "", 0 154 } 155 text = text[i:] 156 adjust = i 157 } 158 if !strings.HasPrefix(text, commentStart) { 159 return "", 0 160 } 161 text = text[commentStartLen:] 162 return text, commentStartLen + adjust + 1 163 } 164 165 const invalidToken rune = 0 166 167 type tokens struct { 168 scanner scanner.Scanner 169 current rune 170 err error 171 base token.Pos 172 } 173 174 func (t *tokens) Init(base token.Pos, text string) *tokens { 175 t.base = base 176 t.scanner.Init(strings.NewReader(text)) 177 t.scanner.Mode = scanner.GoTokens 178 t.scanner.Whitespace ^= 1 << '\n' // don't skip new lines 179 t.scanner.Error = func(s *scanner.Scanner, msg string) { 180 t.Errorf("%v", msg) 181 } 182 return t 183 } 184 185 func (t *tokens) Consume() string { 186 t.current = invalidToken 187 return t.scanner.TokenText() 188 } 189 190 func (t *tokens) Token() rune { 191 if t.err != nil { 192 return scanner.EOF 193 } 194 if t.current == invalidToken { 195 t.current = t.scanner.Scan() 196 } 197 return t.current 198 } 199 200 func (t *tokens) Skip(r rune) int { 201 i := 0 202 for t.Token() == '\n' { 203 t.Consume() 204 i++ 205 } 206 return i 207 } 208 209 func (t *tokens) TokenString() string { 210 return scanner.TokenString(t.Token()) 211 } 212 213 func (t *tokens) Pos() token.Pos { 214 return t.base + token.Pos(t.scanner.Position.Offset) 215 } 216 217 func (t *tokens) Errorf(msg string, args ...interface{}) { 218 if t.err != nil { 219 return 220 } 221 t.err = fmt.Errorf(msg, args...) 222 } 223 224 func parse(fset *token.FileSet, base token.Pos, text string) ([]*Note, error) { 225 t := new(tokens).Init(base, text) 226 notes := parseComment(t) 227 if t.err != nil { 228 return nil, fmt.Errorf("%v:%s", fset.Position(t.Pos()), t.err) 229 } 230 return notes, nil 231 } 232 233 func parseComment(t *tokens) []*Note { 234 var notes []*Note 235 for { 236 t.Skip('\n') 237 switch t.Token() { 238 case scanner.EOF: 239 return notes 240 case scanner.Ident: 241 notes = append(notes, parseNote(t)) 242 default: 243 t.Errorf("unexpected %s parsing comment, expect identifier", t.TokenString()) 244 return nil 245 } 246 switch t.Token() { 247 case scanner.EOF: 248 return notes 249 case ',', '\n': 250 t.Consume() 251 default: 252 t.Errorf("unexpected %s parsing comment, expect separator", t.TokenString()) 253 return nil 254 } 255 } 256 } 257 258 func parseNote(t *tokens) *Note { 259 n := &Note{ 260 Pos: t.Pos(), 261 Name: t.Consume(), 262 } 263 264 switch t.Token() { 265 case ',', '\n', scanner.EOF: 266 // no argument list present 267 return n 268 case '(': 269 n.Args = parseArgumentList(t) 270 return n 271 default: 272 t.Errorf("unexpected %s parsing note", t.TokenString()) 273 return nil 274 } 275 } 276 277 func parseArgumentList(t *tokens) []interface{} { 278 args := []interface{}{} // @name() is represented by a non-nil empty slice. 279 t.Consume() // '(' 280 t.Skip('\n') 281 for t.Token() != ')' { 282 args = append(args, parseArgument(t)) 283 if t.Token() != ',' { 284 break 285 } 286 t.Consume() 287 t.Skip('\n') 288 } 289 if t.Token() != ')' { 290 t.Errorf("unexpected %s parsing argument list", t.TokenString()) 291 return nil 292 } 293 t.Consume() // ')' 294 return args 295 } 296 297 func parseArgument(t *tokens) interface{} { 298 switch t.Token() { 299 case scanner.Ident: 300 v := t.Consume() 301 switch v { 302 case "true": 303 return true 304 case "false": 305 return false 306 case "nil": 307 return nil 308 case "re": 309 if t.Token() != scanner.String && t.Token() != scanner.RawString { 310 t.Errorf("re must be followed by string, got %s", t.TokenString()) 311 return nil 312 } 313 pattern, _ := strconv.Unquote(t.Consume()) // can't fail 314 re, err := regexp.Compile(pattern) 315 if err != nil { 316 t.Errorf("invalid regular expression %s: %v", pattern, err) 317 return nil 318 } 319 return re 320 default: 321 return Identifier(v) 322 } 323 324 case scanner.String, scanner.RawString: 325 v, _ := strconv.Unquote(t.Consume()) // can't fail 326 return v 327 328 case scanner.Int: 329 s := t.Consume() 330 v, err := strconv.ParseInt(s, 0, 0) 331 if err != nil { 332 t.Errorf("cannot convert %v to int: %v", s, err) 333 } 334 return v 335 336 case scanner.Float: 337 s := t.Consume() 338 v, err := strconv.ParseFloat(s, 64) 339 if err != nil { 340 t.Errorf("cannot convert %v to float: %v", s, err) 341 } 342 return v 343 344 case scanner.Char: 345 t.Errorf("unexpected char literal %s", t.Consume()) 346 return nil 347 348 default: 349 t.Errorf("unexpected %s parsing argument", t.TokenString()) 350 return nil 351 } 352 }