github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/ast/ast_roundtrip_test.go (about) 1 package ast_test 2 3 import ( 4 "bytes" 5 "io" 6 "io/ioutil" 7 "os" 8 "path/filepath" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 13 "github.com/jhump/protocompile/ast" 14 "github.com/jhump/protocompile/parser" 15 "github.com/jhump/protocompile/reporter" 16 ) 17 18 func TestASTRoundTrips(t *testing.T) { 19 err := filepath.Walk("../internal/testprotos", func(path string, info os.FileInfo, err error) error { 20 if err != nil { 21 return err 22 } 23 if filepath.Ext(path) == ".proto" { 24 t.Run(path, func(t *testing.T) { 25 data, err := ioutil.ReadFile(path) 26 if !assert.Nil(t, err, "%v", err) { 27 return 28 } 29 filename := filepath.Base(path) 30 root, err := parser.Parse(filename, bytes.NewReader(data), reporter.NewHandler(nil)) 31 if !assert.Nil(t, err) { 32 return 33 } 34 var buf bytes.Buffer 35 err = printAST(&buf, root) 36 if assert.Nil(t, err, "%v", err) { 37 // see if file survived round trip! 38 assert.Equal(t, string(data), buf.String()) 39 } 40 }) 41 } 42 return nil 43 }) 44 assert.Nil(t, err, "%v", err) 45 } 46 47 // printAST prints the given AST node to the given output. This operation 48 // basically walks the AST and, for each TerminalNode, prints the node's 49 // leading comments, leading whitespace, the node's raw text, and then 50 // any trailing comments. If the given node is a *FileNode, it will then 51 // also print the file's FinalComments and FinalWhitespace. 52 func printAST(w io.Writer, file *ast.FileNode) error { 53 sw, ok := w.(stringWriter) 54 if !ok { 55 sw = &strWriter{w} 56 } 57 err := ast.Walk(file, &ast.SimpleVisitor{ 58 DoVisitTerminalNode: func(token ast.TerminalNode) error { 59 info := file.NodeInfo(token) 60 if err := printComments(sw, info.LeadingComments()); err != nil { 61 return err 62 } 63 64 if _, err := sw.WriteString(info.LeadingWhitespace()); err != nil { 65 return err 66 } 67 68 if _, err := sw.WriteString(info.RawText()); err != nil { 69 return err 70 } 71 72 return printComments(sw, info.TrailingComments()) 73 }, 74 }) 75 if err != nil { 76 return err 77 } 78 79 //err = printComments(sw, file.FinalComments) 80 //if err != nil { 81 // return err 82 //} 83 //_, err = sw.WriteString(file.FinalWhitespace) 84 //return err 85 86 return nil 87 } 88 89 func printComments(sw stringWriter, comments ast.Comments) error { 90 for i := 0; i < comments.Len(); i++ { 91 comment := comments.Index(i) 92 if _, err := sw.WriteString(comment.LeadingWhitespace()); err != nil { 93 return err 94 } 95 if _, err := sw.WriteString(comment.RawText()); err != nil { 96 return err 97 } 98 } 99 return nil 100 } 101 102 // many io.Writer impls also provide a string-based method 103 type stringWriter interface { 104 WriteString(s string) (n int, err error) 105 } 106 107 // adapter, in case the given writer does NOT provide a string-based method 108 type strWriter struct { 109 io.Writer 110 } 111 112 func (s *strWriter) WriteString(str string) (int, error) { 113 if str == "" { 114 return 0, nil 115 } 116 return s.Write([]byte(str)) 117 }