github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/colexec/execgen/util.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package execgen
    12  
    13  import (
    14  	"fmt"
    15  	"strings"
    16  
    17  	"github.com/dave/dst"
    18  	"github.com/dave/dst/decorator"
    19  )
    20  
    21  func prettyPrintStmts(stmts ...dst.Stmt) string {
    22  	if len(stmts) == 0 {
    23  		return ""
    24  	}
    25  	f := &dst.File{
    26  		Name: dst.NewIdent("main"),
    27  		Decls: []dst.Decl{
    28  			&dst.FuncDecl{
    29  				Name: dst.NewIdent("test"),
    30  				Type: &dst.FuncType{},
    31  				Body: &dst.BlockStmt{
    32  					List: stmts,
    33  				},
    34  			},
    35  		},
    36  	}
    37  	var ret strings.Builder
    38  	_ = decorator.Fprint(&ret, f)
    39  	prelude := `package main
    40  
    41  func test() {
    42  `
    43  	postlude := `}
    44  `
    45  	s := ret.String()
    46  	return strings.TrimSpace(s[len(prelude) : len(s)-len(postlude)])
    47  }
    48  
    49  func prettyPrintExprs(exprs ...dst.Expr) string {
    50  	stmts := make([]dst.Stmt, len(exprs))
    51  	for i := range exprs {
    52  		stmts[i] = &dst.ExprStmt{X: exprs[i]}
    53  	}
    54  	return prettyPrintStmts(stmts...)
    55  }
    56  
    57  func parseStmt(stmt string) (dst.Stmt, error) {
    58  	f, err := decorator.Parse(fmt.Sprintf(
    59  		`package main
    60  func test() {
    61  	%s
    62  }`, stmt))
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	return f.Decls[0].(*dst.FuncDecl).Body.List[0], nil
    67  }
    68  
    69  func mustParseStmt(stmt string) dst.Stmt {
    70  	ret, err := parseStmt(stmt)
    71  	if err != nil {
    72  		panic(err)
    73  	}
    74  	return ret
    75  }