github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/sem/tree/pretty_test.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tree_test
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	"flag"
    17  	"fmt"
    18  	"io/ioutil"
    19  	"path/filepath"
    20  	"runtime"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/cockroachdb/cockroach/pkg/sql/parser"
    25  	_ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    28  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    29  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    30  	"github.com/cockroachdb/cockroach/pkg/util/pretty"
    31  	"golang.org/x/sync/errgroup"
    32  )
    33  
    34  var (
    35  	flagWritePretty = flag.Bool("rewrite-pretty", false, "rewrite pretty test outputs")
    36  	testPrettyCfg   = func() tree.PrettyCfg {
    37  		cfg := tree.DefaultPrettyCfg()
    38  		cfg.JSONFmt = true
    39  		return cfg
    40  	}()
    41  )
    42  
    43  // TestPrettyData reads in a single SQL statement from a file, formats
    44  // it at 40 characters width, and compares that output to a known-good
    45  // output file. It is most useful when changing or implementing the
    46  // doc interface for a node, and should be used to compare and verify
    47  // the changed output.
    48  func TestPrettyDataShort(t *testing.T) {
    49  	defer leaktest.AfterTest(t)()
    50  	matches, err := filepath.Glob(filepath.Join("testdata", "pretty", "*.sql"))
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  	if *flagWritePretty {
    55  		t.Log("WARNING: do not forget to run TestPrettyData with build flag 'nightly' and the -rewrite-pretty flag too!")
    56  	}
    57  	cfg := testPrettyCfg
    58  	cfg.Align = tree.PrettyNoAlign
    59  	t.Run("ref", func(t *testing.T) {
    60  		runTestPrettyData(t, "ref", cfg, matches, true /*short*/)
    61  	})
    62  	cfg.Align = tree.PrettyAlignAndDeindent
    63  	t.Run("align-deindent", func(t *testing.T) {
    64  		runTestPrettyData(t, "align-deindent", cfg, matches, true /*short*/)
    65  	})
    66  	cfg.Align = tree.PrettyAlignOnly
    67  	t.Run("align-only", func(t *testing.T) {
    68  		runTestPrettyData(t, "align-only", cfg, matches, true /*short*/)
    69  	})
    70  }
    71  
    72  func runTestPrettyData(
    73  	t *testing.T, prefix string, cfg tree.PrettyCfg, matches []string, short bool,
    74  ) {
    75  	for _, m := range matches {
    76  		m := m
    77  		t.Run(filepath.Base(m), func(t *testing.T) {
    78  			sql, err := ioutil.ReadFile(m)
    79  			if err != nil {
    80  				t.Fatal(err)
    81  			}
    82  			stmt, err := parser.ParseOne(string(sql))
    83  			if err != nil {
    84  				t.Fatal(err)
    85  			}
    86  
    87  			// We have a statement, now we need to format it at all possible line
    88  			// lengths. We use the length of the string + 10 as the upper bound to try to
    89  			// find what happens at the longest line length. Preallocate a result slice and
    90  			// work chan, then fire off a bunch of workers to compute all of the variants.
    91  			var res []string
    92  			if short {
    93  				res = []string{""}
    94  			} else {
    95  				res = make([]string, len(sql)+10)
    96  			}
    97  			type param struct{ idx, numCols int }
    98  			work := make(chan param, len(res))
    99  			if short {
   100  				work <- param{0, 40}
   101  			} else {
   102  				for i := range res {
   103  					work <- param{i, i + 1}
   104  				}
   105  			}
   106  			close(work)
   107  			g, _ := errgroup.WithContext(context.Background())
   108  			worker := func() error {
   109  				for p := range work {
   110  					thisCfg := cfg
   111  					thisCfg.LineWidth = p.numCols
   112  					res[p.idx] = thisCfg.Pretty(stmt.AST)
   113  				}
   114  				return nil
   115  			}
   116  			for i := 0; i < runtime.NumCPU(); i++ {
   117  				g.Go(worker)
   118  			}
   119  			if err := g.Wait(); err != nil {
   120  				t.Fatal(err)
   121  			}
   122  			var sb strings.Builder
   123  			for i, s := range res {
   124  				// Only write each new result to the output, along with a small header
   125  				// indicating the line length.
   126  				if i == 0 || s != res[i-1] {
   127  					fmt.Fprintf(&sb, "%d:\n%s\n%s\n\n", i+1, strings.Repeat("-", i+1), s)
   128  				}
   129  			}
   130  			var gotB bytes.Buffer
   131  			gotB.WriteString("// Code generated by TestPretty. DO NOT EDIT.\n")
   132  			gotB.WriteString("// GENERATED FILE DO NOT EDIT\n")
   133  			gotB.WriteString(sb.String())
   134  			gotB.WriteByte('\n')
   135  			got := gotB.String()
   136  
   137  			ext := filepath.Ext(m)
   138  			outfile := m[:len(m)-len(ext)] + "." + prefix + ".golden"
   139  			if short {
   140  				outfile = outfile + ".short"
   141  			}
   142  
   143  			if *flagWritePretty {
   144  				if err := ioutil.WriteFile(outfile, []byte(got), 0666); err != nil {
   145  					t.Fatal(err)
   146  				}
   147  				return
   148  			}
   149  
   150  			expect, err := ioutil.ReadFile(outfile)
   151  			if err != nil {
   152  				t.Fatal(err)
   153  			}
   154  			if string(expect) != got {
   155  				t.Fatalf("expected:\n%s\ngot:\n%s", expect, got)
   156  			}
   157  
   158  			sqlutils.VerifyStatementPrettyRoundtrip(t, string(sql))
   159  		})
   160  	}
   161  }
   162  
   163  func TestPrettyVerify(t *testing.T) {
   164  	defer leaktest.AfterTest(t)()
   165  	tests := map[string]string{
   166  		// Verify that INTERVAL is maintained.
   167  		`SELECT interval '-2µs'`: `SELECT '-00:00:00.000002':::INTERVAL`,
   168  	}
   169  	for orig, pretty := range tests {
   170  		t.Run(orig, func(t *testing.T) {
   171  			sqlutils.VerifyStatementPrettyRoundtrip(t, orig)
   172  
   173  			stmt, err := parser.ParseOne(orig)
   174  			if err != nil {
   175  				t.Fatal(err)
   176  			}
   177  			got := tree.Pretty(stmt.AST)
   178  			if pretty != got {
   179  				t.Fatalf("got: %s\nexpected: %s", got, pretty)
   180  			}
   181  		})
   182  	}
   183  }
   184  
   185  func BenchmarkPrettyData(b *testing.B) {
   186  	matches, err := filepath.Glob(filepath.Join("testdata", "pretty", "*.sql"))
   187  	if err != nil {
   188  		b.Fatal(err)
   189  	}
   190  	var docs []pretty.Doc
   191  	cfg := tree.DefaultPrettyCfg()
   192  	for _, m := range matches {
   193  		sql, err := ioutil.ReadFile(m)
   194  		if err != nil {
   195  			b.Fatal(err)
   196  		}
   197  		stmt, err := parser.ParseOne(string(sql))
   198  		if err != nil {
   199  			b.Fatal(err)
   200  		}
   201  		docs = append(docs, cfg.Doc(stmt.AST))
   202  	}
   203  
   204  	b.ResetTimer()
   205  	for i := 0; i < b.N; i++ {
   206  		for _, doc := range docs {
   207  			for _, w := range []int{1, 30, 80} {
   208  				pretty.Pretty(doc, w, true /*useTabs*/, 4 /*tabWidth*/, nil /* keywordTransform */)
   209  			}
   210  		}
   211  	}
   212  }
   213  
   214  func TestPrettyExprs(t *testing.T) {
   215  	defer leaktest.AfterTest(t)()
   216  	tests := map[tree.Expr]string{
   217  		&tree.CastExpr{
   218  			Expr: tree.NewDString("foo"),
   219  			Type: types.MakeCollatedString(types.String, "en"),
   220  		}: `CAST('foo':::STRING AS STRING) COLLATE en`,
   221  	}
   222  
   223  	for expr, pretty := range tests {
   224  		got := tree.Pretty(expr)
   225  		if pretty != got {
   226  			t.Fatalf("got: %s\nexpected: %s", got, pretty)
   227  		}
   228  	}
   229  }