github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/ast/ast_roundtrip_test.go (about)

     1  package ast_test
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  
    13  	"github.com/jhump/protocompile/ast"
    14  	"github.com/jhump/protocompile/parser"
    15  	"github.com/jhump/protocompile/reporter"
    16  )
    17  
    18  func TestASTRoundTrips(t *testing.T) {
    19  	err := filepath.Walk("../internal/testprotos", func(path string, info os.FileInfo, err error) error {
    20  		if err != nil {
    21  			return err
    22  		}
    23  		if filepath.Ext(path) == ".proto" {
    24  			t.Run(path, func(t *testing.T) {
    25  				data, err := ioutil.ReadFile(path)
    26  				if !assert.Nil(t, err, "%v", err) {
    27  					return
    28  				}
    29  				filename := filepath.Base(path)
    30  				root, err := parser.Parse(filename, bytes.NewReader(data), reporter.NewHandler(nil))
    31  				if !assert.Nil(t, err) {
    32  					return
    33  				}
    34  				var buf bytes.Buffer
    35  				err = printAST(&buf, root)
    36  				if assert.Nil(t, err, "%v", err) {
    37  					// see if file survived round trip!
    38  					assert.Equal(t, string(data), buf.String())
    39  				}
    40  			})
    41  		}
    42  		return nil
    43  	})
    44  	assert.Nil(t, err, "%v", err)
    45  }
    46  
    47  // printAST prints the given AST node to the given output. This operation
    48  // basically walks the AST and, for each TerminalNode, prints the node's
    49  // leading comments, leading whitespace, the node's raw text, and then
    50  // any trailing comments. If the given node is a *FileNode, it will then
    51  // also print the file's FinalComments and FinalWhitespace.
    52  func printAST(w io.Writer, file *ast.FileNode) error {
    53  	sw, ok := w.(stringWriter)
    54  	if !ok {
    55  		sw = &strWriter{w}
    56  	}
    57  	err := ast.Walk(file, &ast.SimpleVisitor{
    58  		DoVisitTerminalNode: func(token ast.TerminalNode) error {
    59  			info := file.NodeInfo(token)
    60  			if err := printComments(sw, info.LeadingComments()); err != nil {
    61  				return err
    62  			}
    63  
    64  			if _, err := sw.WriteString(info.LeadingWhitespace()); err != nil {
    65  				return err
    66  			}
    67  
    68  			if _, err := sw.WriteString(info.RawText()); err != nil {
    69  				return err
    70  			}
    71  
    72  			return printComments(sw, info.TrailingComments())
    73  		},
    74  	})
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	//err = printComments(sw, file.FinalComments)
    80  	//if err != nil {
    81  	//	return err
    82  	//}
    83  	//_, err = sw.WriteString(file.FinalWhitespace)
    84  	//return err
    85  
    86  	return nil
    87  }
    88  
    89  func printComments(sw stringWriter, comments ast.Comments) error {
    90  	for i := 0; i < comments.Len(); i++ {
    91  		comment := comments.Index(i)
    92  		if _, err := sw.WriteString(comment.LeadingWhitespace()); err != nil {
    93  			return err
    94  		}
    95  		if _, err := sw.WriteString(comment.RawText()); err != nil {
    96  			return err
    97  		}
    98  	}
    99  	return nil
   100  }
   101  
   102  // many io.Writer impls also provide a string-based method
   103  type stringWriter interface {
   104  	WriteString(s string) (n int, err error)
   105  }
   106  
   107  // adapter, in case the given writer does NOT provide a string-based method
   108  type strWriter struct {
   109  	io.Writer
   110  }
   111  
   112  func (s *strWriter) WriteString(str string) (int, error) {
   113  	if str == "" {
   114  		return 0, nil
   115  	}
   116  	return s.Write([]byte(str))
   117  }