github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/sem/plpgsqltree/statements.go (about)

     1  // Copyright 2023 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package plpgsqltree
    12  
    13  import (
    14  	"fmt"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
    19  	"github.com/cockroachdb/cockroachdb-parser/pkg/util/errorutil/unimplemented"
    20  )
    21  
    22  type Expr = tree.Expr
    23  
    24  type Statement interface {
    25  	tree.NodeFormatter
    26  	GetLineNo() int
    27  	GetStmtID() uint
    28  	plpgsqlStmt()
    29  	WalkStmt(StatementVisitor) (newStmt Statement, changed bool)
    30  }
    31  
    32  type TaggedStatement interface {
    33  	PlpgSQLStatementTag() string
    34  }
    35  
    36  type StatementImpl struct {
    37  	// TODO(Chengxiong): figure out how to get line number from scanner.
    38  	LineNo int
    39  	/*
    40  	 * Unique statement ID in this function (starting at 1; 0 is invalid/not
    41  	 * set).  This can be used by a profiler as the index for an array of
    42  	 * per-statement metrics.
    43  	 */
    44  	// TODO(Chengxiong): figure out how to get statement id from parser.
    45  	StmtID uint
    46  }
    47  
    48  func (s *StatementImpl) GetLineNo() int {
    49  	return s.LineNo
    50  }
    51  
    52  func (s *StatementImpl) GetStmtID() uint {
    53  	return s.StmtID
    54  }
    55  
    56  func (s *StatementImpl) plpgsqlStmt() {}
    57  
    58  // pl_block
    59  type Block struct {
    60  	StatementImpl
    61  	Label      string
    62  	Decls      []Statement
    63  	Body       []Statement
    64  	Exceptions []Exception
    65  }
    66  
    67  func (s *Block) CopyNode() *Block {
    68  	copyNode := *s
    69  	copyNode.Decls = append([]Statement(nil), copyNode.Decls...)
    70  	copyNode.Body = append([]Statement(nil), copyNode.Body...)
    71  	copyNode.Exceptions = append([]Exception(nil), copyNode.Exceptions...)
    72  	return &copyNode
    73  }
    74  
    75  // TODO(drewk): format Label and Exceptions fields.
    76  func (s *Block) Format(ctx *tree.FmtCtx) {
    77  	if s.Decls != nil {
    78  		ctx.WriteString("DECLARE\n")
    79  		for _, dec := range s.Decls {
    80  			ctx.FormatNode(dec)
    81  		}
    82  	}
    83  	// TODO(drewk): Make sure the child statement is pretty printed correctly
    84  	// with indents.
    85  	ctx.WriteString("BEGIN\n")
    86  	for _, childStmt := range s.Body {
    87  		ctx.FormatNode(childStmt)
    88  	}
    89  	if s.Exceptions != nil {
    90  		ctx.WriteString("EXCEPTION\n")
    91  		for _, e := range s.Exceptions {
    92  			ctx.FormatNode(&e)
    93  		}
    94  	}
    95  	ctx.WriteString("END\n")
    96  }
    97  
    98  func (s *Block) PlpgSQLStatementTag() string {
    99  	return "stmt_block"
   100  }
   101  
   102  func (s *Block) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   103  	newStmt, changed = visitor.Visit(s)
   104  	for i, stmt := range s.Decls {
   105  		ns, ch := stmt.WalkStmt(visitor)
   106  		if ch {
   107  			changed = true
   108  			if newStmt == s {
   109  				newStmt = s.CopyNode()
   110  			}
   111  			newStmt.(*Block).Decls[i] = ns
   112  		}
   113  	}
   114  	for i, stmt := range s.Body {
   115  		ns, ch := stmt.WalkStmt(visitor)
   116  		if ch {
   117  			changed = true
   118  			if newStmt == s {
   119  				newStmt = s.CopyNode()
   120  			}
   121  			newStmt.(*Block).Body[i] = ns
   122  		}
   123  	}
   124  	for i, stmt := range s.Exceptions {
   125  		ns, ch := stmt.WalkStmt(visitor)
   126  		if ch {
   127  			changed = true
   128  			if newStmt == s {
   129  				newStmt = s.CopyNode()
   130  			}
   131  			newStmt.(*Block).Exceptions[i] = *(ns.(*Exception))
   132  		}
   133  	}
   134  	return newStmt, changed
   135  }
   136  
   137  // decl_stmt
   138  type Declaration struct {
   139  	StatementImpl
   140  	Var      Variable
   141  	Constant bool
   142  	Typ      tree.ResolvableTypeReference
   143  	Collate  string
   144  	NotNull  bool
   145  	Expr     Expr
   146  }
   147  
   148  func (s *Declaration) CopyNode() *Declaration {
   149  	copyNode := *s
   150  	return &copyNode
   151  }
   152  
   153  func (s *Declaration) Format(ctx *tree.FmtCtx) {
   154  	ctx.WriteString(string(s.Var))
   155  	if s.Constant {
   156  		ctx.WriteString(" CONSTANT")
   157  	}
   158  	ctx.WriteString(" ")
   159  	ctx.FormatTypeReference(s.Typ)
   160  	if s.Collate != "" {
   161  		ctx.WriteString(" COLLATE ")
   162  		ctx.FormatNameP(&s.Collate)
   163  	}
   164  	if s.NotNull {
   165  		ctx.WriteString(" NOT NULL")
   166  	}
   167  	if s.Expr != nil {
   168  		ctx.WriteString(" := ")
   169  		ctx.FormatNode(s.Expr)
   170  	}
   171  	ctx.WriteString(";\n")
   172  }
   173  
   174  func (s *Declaration) PlpgSQLStatementTag() string {
   175  	return "decl_stmt"
   176  }
   177  
   178  func (s *Declaration) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   179  	newStmt, changed = visitor.Visit(s)
   180  	return newStmt, changed
   181  }
   182  
   183  type CursorDeclaration struct {
   184  	StatementImpl
   185  	Name   Variable
   186  	Scroll tree.CursorScrollOption
   187  	Query  tree.Statement
   188  }
   189  
   190  func (s *CursorDeclaration) CopyNode() *CursorDeclaration {
   191  	copyNode := *s
   192  	return &copyNode
   193  }
   194  
   195  func (s *CursorDeclaration) Format(ctx *tree.FmtCtx) {
   196  	ctx.WriteString(string(s.Name))
   197  	switch s.Scroll {
   198  	case tree.Scroll:
   199  		ctx.WriteString(" SCROLL")
   200  	case tree.NoScroll:
   201  		ctx.WriteString(" NO SCROLL")
   202  	}
   203  	ctx.WriteString(" CURSOR FOR ")
   204  	ctx.FormatNode(s.Query)
   205  	ctx.WriteString(";\n")
   206  }
   207  
   208  func (s *CursorDeclaration) PlpgSQLStatementTag() string {
   209  	return "decl_cursor_stmt"
   210  }
   211  
   212  func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   213  	newStmt, changed = visitor.Visit(s)
   214  	return newStmt, changed
   215  }
   216  
   217  // stmt_assign
   218  type Assignment struct {
   219  	Statement
   220  	Var   Variable
   221  	Value Expr
   222  }
   223  
   224  func (s *Assignment) CopyNode() *Assignment {
   225  	copyNode := *s
   226  	return &copyNode
   227  }
   228  
   229  func (s *Assignment) PlpgSQLStatementTag() string {
   230  	return "stmt_assign"
   231  }
   232  
   233  func (s *Assignment) Format(ctx *tree.FmtCtx) {
   234  	ctx.FormatNode(&s.Var)
   235  	ctx.WriteString(" := ")
   236  	ctx.FormatNode(s.Value)
   237  	ctx.WriteString(";\n")
   238  }
   239  
   240  func (s *Assignment) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   241  	newStmt, changed = visitor.Visit(s)
   242  	return newStmt, changed
   243  }
   244  
   245  // stmt_if
   246  type If struct {
   247  	StatementImpl
   248  	Condition  Expr
   249  	ThenBody   []Statement
   250  	ElseIfList []ElseIf
   251  	ElseBody   []Statement
   252  }
   253  
   254  func (s *If) CopyNode() *If {
   255  	copyNode := *s
   256  	copyNode.ThenBody = append([]Statement(nil), copyNode.ThenBody...)
   257  	copyNode.ElseBody = append([]Statement(nil), copyNode.ElseBody...)
   258  	copyNode.ElseIfList = make([]ElseIf, len(s.ElseIfList))
   259  	for i, ei := range s.ElseIfList {
   260  		copyNode.ElseIfList[i] = ei
   261  		copyNode.ElseIfList[i].Stmts = append([]Statement(nil), copyNode.ElseIfList[i].Stmts...)
   262  	}
   263  	return &copyNode
   264  }
   265  
   266  func (s *If) Format(ctx *tree.FmtCtx) {
   267  	ctx.WriteString("IF ")
   268  	ctx.FormatNode(s.Condition)
   269  	ctx.WriteString(" THEN\n")
   270  	for _, stmt := range s.ThenBody {
   271  		// TODO(drewk): Pretty Print with spaces, not tabs.
   272  		ctx.WriteString("\t")
   273  		ctx.FormatNode(stmt)
   274  	}
   275  	for _, elsifStmt := range s.ElseIfList {
   276  		ctx.FormatNode(&elsifStmt)
   277  	}
   278  	for i, elseStmt := range s.ElseBody {
   279  		if i == 0 {
   280  			ctx.WriteString("ELSE\n")
   281  		}
   282  		ctx.WriteString("\t")
   283  		ctx.FormatNode(elseStmt)
   284  	}
   285  	ctx.WriteString("END IF;\n")
   286  }
   287  
   288  func (s *If) PlpgSQLStatementTag() string {
   289  	return "stmt_if"
   290  }
   291  
   292  func (s *If) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   293  	newStmt, changed = visitor.Visit(s)
   294  
   295  	for i, thenStmt := range s.ThenBody {
   296  		ns, ch := thenStmt.WalkStmt(visitor)
   297  		if ch {
   298  			changed = true
   299  			if newStmt == s {
   300  				newStmt = s.CopyNode()
   301  			}
   302  			newStmt.(*If).ThenBody[i] = ns
   303  		}
   304  	}
   305  
   306  	for i, elseIf := range s.ElseIfList {
   307  		ns, ch := elseIf.WalkStmt(visitor)
   308  		if ch {
   309  			changed = true
   310  			if newStmt == s {
   311  				newStmt = s.CopyNode()
   312  			}
   313  			newStmt.(*If).ElseIfList[i] = *ns.(*ElseIf)
   314  		}
   315  	}
   316  
   317  	for i, elseStmt := range s.ElseBody {
   318  		ns, ch := elseStmt.WalkStmt(visitor)
   319  		if ch {
   320  			changed = true
   321  			if newStmt == s {
   322  				newStmt = s.CopyNode()
   323  			}
   324  			newStmt.(*If).ElseBody[i] = ns
   325  		}
   326  	}
   327  
   328  	return newStmt, changed
   329  }
   330  
   331  type ElseIf struct {
   332  	StatementImpl
   333  	Condition Expr
   334  	Stmts     []Statement
   335  }
   336  
   337  func (s *ElseIf) CopyNode() *ElseIf {
   338  	copyNode := *s
   339  	copyNode.Stmts = append([]Statement(nil), copyNode.Stmts...)
   340  	return &copyNode
   341  }
   342  
   343  func (s *ElseIf) Format(ctx *tree.FmtCtx) {
   344  	ctx.WriteString("ELSIF ")
   345  	ctx.FormatNode(s.Condition)
   346  	ctx.WriteString(" THEN\n")
   347  	for _, stmt := range s.Stmts {
   348  		ctx.WriteString("\t")
   349  		ctx.FormatNode(stmt)
   350  	}
   351  }
   352  
   353  func (s *ElseIf) PlpgSQLStatementTag() string {
   354  	return "stmt_if_else_if"
   355  }
   356  
   357  func (s *ElseIf) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   358  	newStmt, changed = visitor.Visit(s)
   359  
   360  	for i, stmt := range s.Stmts {
   361  		ns, ch := stmt.WalkStmt(visitor)
   362  		if ch {
   363  			changed = true
   364  			if newStmt == s {
   365  				newStmt = s.CopyNode()
   366  			}
   367  			newStmt.(*ElseIf).Stmts[i] = ns
   368  		}
   369  	}
   370  	return newStmt, changed
   371  }
   372  
   373  // stmt_case
   374  type Case struct {
   375  	StatementImpl
   376  	// TODO(drewk): Change to Expr
   377  	TestExpr     string
   378  	Var          Variable
   379  	CaseWhenList []*CaseWhen
   380  	HaveElse     bool
   381  	ElseStmts    []Statement
   382  }
   383  
   384  func (s *Case) CopyNode() *Case {
   385  	copyNode := *s
   386  	copyNode.ElseStmts = append([]Statement(nil), copyNode.ElseStmts...)
   387  	copyNode.CaseWhenList = make([]*CaseWhen, len(s.CaseWhenList))
   388  	caseWhens := make([]CaseWhen, len(s.CaseWhenList))
   389  	for i, cw := range s.CaseWhenList {
   390  		caseWhens[i] = *cw
   391  		copyNode.CaseWhenList[i] = &caseWhens[i]
   392  	}
   393  	return &copyNode
   394  }
   395  
   396  // TODO(drewk): fix the whitespace/newline formatting for CASE (see the
   397  // stmt_case test file).
   398  func (s *Case) Format(ctx *tree.FmtCtx) {
   399  	ctx.WriteString("CASE")
   400  	if len(s.TestExpr) > 0 {
   401  		ctx.WriteString(fmt.Sprintf(" %s", s.TestExpr))
   402  	}
   403  	ctx.WriteString("\n")
   404  	for _, when := range s.CaseWhenList {
   405  		ctx.FormatNode(when)
   406  	}
   407  	if s.HaveElse {
   408  		ctx.WriteString("ELSE\n")
   409  		for _, stmt := range s.ElseStmts {
   410  			ctx.WriteString("  ")
   411  			ctx.FormatNode(stmt)
   412  		}
   413  	}
   414  	ctx.WriteString("END CASE\n")
   415  }
   416  
   417  func (s *Case) PlpgSQLStatementTag() string {
   418  	return "stmt_case"
   419  }
   420  
   421  func (s *Case) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   422  	newStmt, changed = visitor.Visit(s)
   423  
   424  	for i, when := range s.CaseWhenList {
   425  		ns, ch := when.WalkStmt(visitor)
   426  		if ch {
   427  			changed = true
   428  			if newStmt == s {
   429  				newStmt = s.CopyNode()
   430  			}
   431  			newStmt.(*Case).CaseWhenList[i] = ns.(*CaseWhen)
   432  		}
   433  	}
   434  
   435  	if s.HaveElse {
   436  		for i, stmt := range s.ElseStmts {
   437  			ns, ch := stmt.WalkStmt(visitor)
   438  			if ch {
   439  				changed = true
   440  				if newStmt == s {
   441  					newStmt = s.CopyNode()
   442  				}
   443  				newStmt.(*Case).ElseStmts[i] = ns
   444  			}
   445  		}
   446  	}
   447  	return newStmt, changed
   448  }
   449  
   450  type CaseWhen struct {
   451  	StatementImpl
   452  	// TODO(drewk): Change to Expr
   453  	Expr  string
   454  	Stmts []Statement
   455  }
   456  
   457  func (s *CaseWhen) CopyNode() *CaseWhen {
   458  	copyNode := *s
   459  	copyNode.Stmts = append([]Statement(nil), copyNode.Stmts...)
   460  	return &copyNode
   461  }
   462  
   463  func (s *CaseWhen) Format(ctx *tree.FmtCtx) {
   464  	ctx.WriteString(fmt.Sprintf("WHEN %s THEN\n", s.Expr))
   465  	for i, stmt := range s.Stmts {
   466  		ctx.WriteString("  ")
   467  		ctx.FormatNode(stmt)
   468  		if i != len(s.Stmts)-1 {
   469  			ctx.WriteString("\n")
   470  		}
   471  	}
   472  }
   473  
   474  func (s *CaseWhen) PlpgSQLStatementTag() string {
   475  	return "stmt_when"
   476  }
   477  
   478  func (s *CaseWhen) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   479  	newStmt, changed = visitor.Visit(s)
   480  
   481  	for i, stmt := range s.Stmts {
   482  		ns, ch := stmt.WalkStmt(visitor)
   483  		if ch {
   484  			changed = true
   485  			if newStmt == s {
   486  				newStmt = s.CopyNode()
   487  			}
   488  			newStmt.(*CaseWhen).Stmts[i] = ns
   489  		}
   490  	}
   491  	return newStmt, changed
   492  }
   493  
   494  // stmt_loop
   495  type Loop struct {
   496  	StatementImpl
   497  	Label string
   498  	Body  []Statement
   499  }
   500  
   501  func (s *Loop) CopyNode() *Loop {
   502  	copyNode := *s
   503  	copyNode.Body = append([]Statement(nil), copyNode.Body...)
   504  	return &copyNode
   505  }
   506  
   507  func (s *Loop) PlpgSQLStatementTag() string {
   508  	return "stmt_simple_loop"
   509  }
   510  
   511  func (s *Loop) Format(ctx *tree.FmtCtx) {
   512  	if s.Label != "" {
   513  		ctx.WriteString("<<")
   514  		ctx.FormatNameP(&s.Label)
   515  		ctx.WriteString(">>\n")
   516  	}
   517  	ctx.WriteString("LOOP\n")
   518  	for _, stmt := range s.Body {
   519  		ctx.FormatNode(stmt)
   520  	}
   521  	ctx.WriteString("END LOOP")
   522  	if s.Label != "" {
   523  		ctx.WriteString(" ")
   524  		ctx.FormatNameP(&s.Label)
   525  	}
   526  	ctx.WriteString(";\n")
   527  }
   528  
   529  func (s *Loop) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   530  	newStmt, changed = visitor.Visit(s)
   531  	for i, stmt := range s.Body {
   532  		ns, ch := stmt.WalkStmt(visitor)
   533  		if ch {
   534  			changed = true
   535  			if newStmt == s {
   536  				newStmt = s.CopyNode()
   537  			}
   538  			newStmt.(*Loop).Body[i] = ns
   539  		}
   540  	}
   541  	return newStmt, changed
   542  }
   543  
   544  // stmt_while
   545  type While struct {
   546  	StatementImpl
   547  	Label     string
   548  	Condition Expr
   549  	Body      []Statement
   550  }
   551  
   552  func (s *While) CopyNode() *While {
   553  	copyNode := *s
   554  	copyNode.Body = append([]Statement(nil), copyNode.Body...)
   555  	return &copyNode
   556  }
   557  
   558  func (s *While) Format(ctx *tree.FmtCtx) {
   559  	if s.Label != "" {
   560  		ctx.WriteString("<<")
   561  		ctx.FormatNameP(&s.Label)
   562  		ctx.WriteString(">>\n")
   563  	}
   564  	ctx.WriteString("WHILE ")
   565  	ctx.FormatNode(s.Condition)
   566  	ctx.WriteString(" LOOP\n")
   567  	for _, stmt := range s.Body {
   568  		ctx.FormatNode(stmt)
   569  	}
   570  	ctx.WriteString("END LOOP")
   571  	if s.Label != "" {
   572  		ctx.WriteString(" ")
   573  		ctx.FormatNameP(&s.Label)
   574  	}
   575  	ctx.WriteString(";\n")
   576  }
   577  
   578  func (s *While) PlpgSQLStatementTag() string {
   579  	return "stmt_while"
   580  }
   581  
   582  func (s *While) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   583  	newStmt, changed = visitor.Visit(s)
   584  	for i, stmt := range s.Body {
   585  		ns, ch := stmt.WalkStmt(visitor)
   586  		if ch {
   587  			changed = true
   588  			if newStmt == s {
   589  				newStmt = s.CopyNode()
   590  			}
   591  			newStmt.(*While).Body[i] = ns
   592  		}
   593  	}
   594  	return newStmt, changed
   595  }
   596  
   597  // stmt_for
   598  type ForInt struct {
   599  	StatementImpl
   600  	Label   string
   601  	Var     Variable
   602  	Lower   Expr
   603  	Upper   Expr
   604  	Step    Expr
   605  	Reverse int
   606  	Body    []Statement
   607  }
   608  
   609  func (s *ForInt) Format(ctx *tree.FmtCtx) {
   610  }
   611  
   612  func (s *ForInt) PlpgSQLStatementTag() string {
   613  	return "stmt_for_int_loop"
   614  }
   615  
   616  func (s *ForInt) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   617  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   618  }
   619  
   620  type ForQuery struct {
   621  	StatementImpl
   622  	Label string
   623  	Var   Variable
   624  	Body  []Statement
   625  }
   626  
   627  func (s *ForQuery) Format(ctx *tree.FmtCtx) {
   628  }
   629  
   630  func (s *ForQuery) PlpgSQLStatementTag() string {
   631  	return "stmt_for_query_loop"
   632  }
   633  
   634  func (s *ForQuery) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   635  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   636  }
   637  
   638  type ForSelect struct {
   639  	ForQuery
   640  	Query Expr
   641  }
   642  
   643  func (s *ForSelect) Format(ctx *tree.FmtCtx) {
   644  }
   645  
   646  func (s *ForSelect) PlpgSQLStatementTag() string {
   647  	return "stmt_query_select_loop"
   648  }
   649  
   650  func (s *ForSelect) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   651  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   652  }
   653  
   654  type ForCursor struct {
   655  	ForQuery
   656  	CurVar   int // TODO(drewk): is this CursorVariable?
   657  	ArgQuery Expr
   658  }
   659  
   660  func (s *ForCursor) Format(ctx *tree.FmtCtx) {
   661  }
   662  
   663  func (s *ForCursor) PlpgSQLStatementTag() string {
   664  	return "stmt_for_query_cursor_loop"
   665  }
   666  
   667  func (s *ForCursor) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   668  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   669  }
   670  
   671  type ForDynamic struct {
   672  	ForQuery
   673  	Query  Expr
   674  	Params []Expr
   675  }
   676  
   677  func (s *ForDynamic) Format(ctx *tree.FmtCtx) {
   678  }
   679  
   680  func (s *ForDynamic) PlpgSQLStatementTag() string {
   681  	return "stmt_for_dyn_loop"
   682  }
   683  
   684  func (s *ForDynamic) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   685  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   686  }
   687  
   688  // stmt_foreach_a
   689  type ForEachArray struct {
   690  	StatementImpl
   691  	Label string
   692  	Var   *Variable
   693  	Slice int // TODO(drewk): not sure what this is
   694  	Expr  Expr
   695  	Body  []Statement
   696  }
   697  
   698  func (s *ForEachArray) Format(ctx *tree.FmtCtx) {
   699  }
   700  
   701  func (s *ForEachArray) PlpgSQLStatementTag() string {
   702  	return "stmt_for_each_a"
   703  }
   704  
   705  func (s *ForEachArray) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   706  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   707  }
   708  
   709  // stmt_exit
   710  type Exit struct {
   711  	StatementImpl
   712  	Label     string
   713  	Condition Expr
   714  }
   715  
   716  func (s *Exit) CopyNode() *Exit {
   717  	copyNode := *s
   718  	return &copyNode
   719  }
   720  
   721  func (s *Exit) Format(ctx *tree.FmtCtx) {
   722  	ctx.WriteString("EXIT")
   723  	if s.Label != "" {
   724  		ctx.WriteString(" ")
   725  		ctx.FormatNameP(&s.Label)
   726  	}
   727  	if s.Condition != nil {
   728  		ctx.WriteString(" WHEN ")
   729  		ctx.FormatNode(s.Condition)
   730  	}
   731  	ctx.WriteString(";\n")
   732  
   733  }
   734  
   735  func (s *Exit) PlpgSQLStatementTag() string {
   736  	return "stmt_exit"
   737  }
   738  
   739  func (s *Exit) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   740  	newStmt, changed = visitor.Visit(s)
   741  	return newStmt, changed
   742  }
   743  
   744  // stmt_continue
   745  type Continue struct {
   746  	StatementImpl
   747  	Label     string
   748  	Condition Expr
   749  }
   750  
   751  func (s *Continue) CopyNode() *Continue {
   752  	copyNode := *s
   753  	return &copyNode
   754  }
   755  
   756  func (s *Continue) Format(ctx *tree.FmtCtx) {
   757  	ctx.WriteString("CONTINUE")
   758  	if s.Label != "" {
   759  		ctx.WriteString(" ")
   760  		ctx.FormatNameP(&s.Label)
   761  	}
   762  	if s.Condition != nil {
   763  		ctx.WriteString(" WHEN ")
   764  		ctx.FormatNode(s.Condition)
   765  	}
   766  	ctx.WriteString(";\n")
   767  }
   768  
   769  func (s *Continue) PlpgSQLStatementTag() string {
   770  	return "stmt_continue"
   771  }
   772  
   773  func (s *Continue) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   774  	newStmt, changed = visitor.Visit(s)
   775  	return newStmt, changed
   776  }
   777  
   778  // stmt_return
   779  type Return struct {
   780  	StatementImpl
   781  	Expr   Expr
   782  	RetVar Variable
   783  }
   784  
   785  func (s *Return) CopyNode() *Return {
   786  	copyNode := *s
   787  	return &copyNode
   788  }
   789  
   790  func (s *Return) Format(ctx *tree.FmtCtx) {
   791  	ctx.WriteString("RETURN ")
   792  	if s.Expr == nil {
   793  		ctx.FormatNode(&s.RetVar)
   794  	} else {
   795  		ctx.FormatNode(s.Expr)
   796  	}
   797  	ctx.WriteString(";\n")
   798  }
   799  
   800  func (s *Return) PlpgSQLStatementTag() string {
   801  	return "stmt_return"
   802  }
   803  
   804  func (s *Return) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   805  	newStmt, changed = visitor.Visit(s)
   806  	return newStmt, changed
   807  }
   808  
   809  type ReturnNext struct {
   810  	StatementImpl
   811  	Expr   Expr
   812  	RetVar Variable
   813  }
   814  
   815  func (s *ReturnNext) Format(ctx *tree.FmtCtx) {
   816  }
   817  
   818  func (s *ReturnNext) PlpgSQLStatementTag() string {
   819  	return "stmt_return_next"
   820  }
   821  
   822  func (s *ReturnNext) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   823  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   824  }
   825  
   826  type ReturnQuery struct {
   827  	StatementImpl
   828  	Query        Expr
   829  	DynamicQuery Expr
   830  	Params       []Expr
   831  }
   832  
   833  func (s *ReturnQuery) Format(ctx *tree.FmtCtx) {
   834  }
   835  
   836  func (s *ReturnQuery) PlpgSQLStatementTag() string {
   837  	return "stmt_return_query"
   838  }
   839  
   840  func (s *ReturnQuery) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   841  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
   842  }
   843  
   844  // stmt_raise
   845  type Raise struct {
   846  	StatementImpl
   847  	LogLevel string
   848  	Code     string
   849  	CodeName string
   850  	Message  string
   851  	Params   []Expr
   852  	Options  []RaiseOption
   853  }
   854  
   855  func (s *Raise) CopyNode() *Raise {
   856  	copyNode := *s
   857  	copyNode.Params = append([]Expr(nil), s.Params...)
   858  	copyNode.Options = append([]RaiseOption(nil), s.Options...)
   859  	return &copyNode
   860  }
   861  
   862  func (s *Raise) Format(ctx *tree.FmtCtx) {
   863  	ctx.WriteString("RAISE")
   864  	if s.LogLevel != "" {
   865  		ctx.WriteString(" ")
   866  		ctx.WriteString(s.LogLevel)
   867  	}
   868  	if s.Code != "" {
   869  		ctx.WriteString(" SQLSTATE ")
   870  		formatStringQuotes(ctx, s.Code)
   871  	}
   872  	if s.CodeName != "" {
   873  		ctx.WriteString(" ")
   874  		formatString(ctx, s.CodeName)
   875  	}
   876  	if s.Message != "" {
   877  		ctx.WriteString(" ")
   878  		formatStringQuotes(ctx, s.Message)
   879  		for i := range s.Params {
   880  			ctx.WriteString(", ")
   881  			ctx.FormatNode(s.Params[i])
   882  		}
   883  	}
   884  	for i := range s.Options {
   885  		if i == 0 {
   886  			ctx.WriteString("\nUSING ")
   887  		} else {
   888  			ctx.WriteString(",\n")
   889  		}
   890  		ctx.FormatNode(&s.Options[i])
   891  	}
   892  	ctx.WriteString(";\n")
   893  }
   894  
   895  type RaiseOption struct {
   896  	OptType string
   897  	Expr    Expr
   898  }
   899  
   900  func (s *RaiseOption) Format(ctx *tree.FmtCtx) {
   901  	ctx.WriteString(fmt.Sprintf("%s = ", strings.ToUpper(s.OptType)))
   902  	ctx.FormatNode(s.Expr)
   903  }
   904  
   905  func (s *Raise) PlpgSQLStatementTag() string {
   906  	return "stmt_raise"
   907  }
   908  
   909  func (s *Raise) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   910  	newStmt, changed = visitor.Visit(s)
   911  	return newStmt, changed
   912  }
   913  
   914  // stmt_assert
   915  type Assert struct {
   916  	StatementImpl
   917  	Condition Expr
   918  	Message   Expr
   919  }
   920  
   921  func (s *Assert) CopyNode() *Assert {
   922  	copyNode := *s
   923  	return &copyNode
   924  }
   925  
   926  func (s *Assert) Format(ctx *tree.FmtCtx) {
   927  	// TODO(drewk): Pretty print the assert condition and message
   928  	ctx.WriteString("ASSERT\n")
   929  }
   930  
   931  func (s *Assert) PlpgSQLStatementTag() string {
   932  	return "stmt_assert"
   933  }
   934  
   935  func (s *Assert) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   936  	newStmt, changed = visitor.Visit(s)
   937  	return newStmt, changed
   938  }
   939  
   940  // stmt_execsql
   941  type Execute struct {
   942  	StatementImpl
   943  	SqlStmt tree.Statement
   944  	Strict  bool // INTO STRICT flag
   945  	Target  []Variable
   946  }
   947  
   948  func (s *Execute) CopyNode() *Execute {
   949  	copyNode := *s
   950  	copyNode.Target = append([]Variable(nil), copyNode.Target...)
   951  	return &copyNode
   952  }
   953  
   954  func (s *Execute) Format(ctx *tree.FmtCtx) {
   955  	ctx.FormatNode(s.SqlStmt)
   956  	if s.Target != nil {
   957  		ctx.WriteString(" INTO ")
   958  		if s.Strict {
   959  			ctx.WriteString("STRICT ")
   960  		}
   961  		for i := range s.Target {
   962  			if i > 0 {
   963  				ctx.WriteString(", ")
   964  			}
   965  			ctx.FormatNode(&s.Target[i])
   966  		}
   967  	}
   968  	ctx.WriteString(";\n")
   969  }
   970  
   971  func (s *Execute) PlpgSQLStatementTag() string {
   972  	return "stmt_exec_sql"
   973  }
   974  
   975  func (s *Execute) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
   976  	newStmt, changed = visitor.Visit(s)
   977  	return newStmt, changed
   978  }
   979  
   980  // stmt_dynexecute
   981  // TODO(chengxiong): query should be a better expression type.
   982  type DynamicExecute struct {
   983  	StatementImpl
   984  	Query  string
   985  	Into   bool
   986  	Strict bool
   987  	Target Variable
   988  	Params []Expr
   989  }
   990  
   991  func (s *DynamicExecute) CopyNode() *DynamicExecute {
   992  	copyNode := *s
   993  	copyNode.Params = append([]Expr(nil), s.Params...)
   994  	return &copyNode
   995  }
   996  
   997  func (s *DynamicExecute) Format(ctx *tree.FmtCtx) {
   998  	// TODO(drewk): Pretty print the original command
   999  	ctx.WriteString("EXECUTE a dynamic command")
  1000  	if s.Into {
  1001  		ctx.WriteString(" WITH INTO")
  1002  		if s.Strict {
  1003  			ctx.WriteString(" STRICT")
  1004  		}
  1005  	}
  1006  	if s.Params != nil {
  1007  		ctx.WriteString(" WITH USING")
  1008  	}
  1009  	ctx.WriteString("\n")
  1010  }
  1011  
  1012  func (s *DynamicExecute) PlpgSQLStatementTag() string {
  1013  	return "stmt_dyn_exec"
  1014  }
  1015  
  1016  func (s *DynamicExecute) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1017  	newStmt, changed = visitor.Visit(s)
  1018  	return newStmt, changed
  1019  }
  1020  
  1021  // stmt_perform
  1022  type Perform struct {
  1023  	StatementImpl
  1024  	Expr Expr
  1025  }
  1026  
  1027  func (s *Perform) Format(ctx *tree.FmtCtx) {
  1028  }
  1029  
  1030  func (s *Perform) PlpgSQLStatementTag() string {
  1031  	return "stmt_perform"
  1032  }
  1033  
  1034  func (s *Perform) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1035  	panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern"))
  1036  }
  1037  
  1038  // stmt_call
  1039  type Call struct {
  1040  	StatementImpl
  1041  	Expr   Expr
  1042  	IsCall bool
  1043  	Target Variable
  1044  }
  1045  
  1046  func (s *Call) CopyNode() *Call {
  1047  	copyNode := *s
  1048  	return &copyNode
  1049  }
  1050  
  1051  func (s *Call) Format(ctx *tree.FmtCtx) {
  1052  	// TODO(drewk): Correct the Call field and print the Expr and Target.
  1053  	if s.IsCall {
  1054  		ctx.WriteString("CALL a function/procedure\n")
  1055  	} else {
  1056  		ctx.WriteString("DO a code block\n")
  1057  	}
  1058  }
  1059  
  1060  func (s *Call) PlpgSQLStatementTag() string {
  1061  	return "stmt_call"
  1062  }
  1063  
  1064  func (s *Call) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1065  	newStmt, changed = visitor.Visit(s)
  1066  	return newStmt, changed
  1067  }
  1068  
  1069  // stmt_getdiag
  1070  type GetDiagnostics struct {
  1071  	StatementImpl
  1072  	IsStacked bool
  1073  	DiagItems GetDiagnosticsItemList // TODO(drewk): what is this?
  1074  }
  1075  
  1076  func (s *GetDiagnostics) Format(ctx *tree.FmtCtx) {
  1077  	if s.IsStacked {
  1078  		ctx.WriteString("GET STACKED DIAGNOSTICS ")
  1079  	} else {
  1080  		ctx.WriteString("GET DIAGNOSTICS ")
  1081  	}
  1082  	for idx, i := range s.DiagItems {
  1083  		ctx.FormatNode(i)
  1084  		if idx != len(s.DiagItems)-1 {
  1085  			ctx.WriteString(" ")
  1086  		}
  1087  	}
  1088  	ctx.WriteString("\n")
  1089  }
  1090  
  1091  type GetDiagnosticsItem struct {
  1092  	Kind GetDiagnosticsKind
  1093  	// TODO(jane): TargetName is temporary -- should be removed and use Target.
  1094  	TargetName string
  1095  	Target     int // where to assign it?
  1096  }
  1097  
  1098  func (s *GetDiagnosticsItem) Format(ctx *tree.FmtCtx) {
  1099  	ctx.WriteString(fmt.Sprintf("%s := %s", s.TargetName, s.Kind.String()))
  1100  }
  1101  
  1102  type GetDiagnosticsItemList []*GetDiagnosticsItem
  1103  
  1104  func (s *GetDiagnostics) PlpgSQLStatementTag() string {
  1105  	return "stmt_get_diag"
  1106  }
  1107  
  1108  func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1109  	newStmt, changed = visitor.Visit(s)
  1110  	return newStmt, changed
  1111  }
  1112  
  1113  // stmt_open
  1114  type Open struct {
  1115  	StatementImpl
  1116  	CurVar Variable
  1117  	Scroll tree.CursorScrollOption
  1118  	Query  tree.Statement
  1119  }
  1120  
  1121  func (s *Open) CopyNode() *Open {
  1122  	copyNode := *s
  1123  	return &copyNode
  1124  }
  1125  
  1126  func (s *Open) Format(ctx *tree.FmtCtx) {
  1127  	ctx.WriteString("OPEN ")
  1128  	ctx.FormatNode(&s.CurVar)
  1129  	switch s.Scroll {
  1130  	case tree.Scroll:
  1131  		ctx.WriteString(" SCROLL")
  1132  	case tree.NoScroll:
  1133  		ctx.WriteString(" NO SCROLL")
  1134  	}
  1135  	if s.Query != nil {
  1136  		ctx.WriteString(" FOR ")
  1137  		ctx.FormatNode(s.Query)
  1138  	}
  1139  	ctx.WriteString(";\n")
  1140  }
  1141  
  1142  func (s *Open) PlpgSQLStatementTag() string {
  1143  	return "stmt_open"
  1144  }
  1145  
  1146  func (s *Open) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1147  	newStmt, changed = visitor.Visit(s)
  1148  	return newStmt, changed
  1149  }
  1150  
  1151  // stmt_fetch
  1152  // stmt_move (where IsMove = true)
  1153  type Fetch struct {
  1154  	StatementImpl
  1155  	Cursor tree.CursorStmt
  1156  	Target []Variable
  1157  	IsMove bool
  1158  }
  1159  
  1160  func (s *Fetch) Format(ctx *tree.FmtCtx) {
  1161  	if s.IsMove {
  1162  		ctx.WriteString("MOVE ")
  1163  	} else {
  1164  		ctx.WriteString("FETCH ")
  1165  	}
  1166  	if dir := s.Cursor.FetchType.String(); dir != "" {
  1167  		ctx.WriteString(dir)
  1168  		ctx.WriteString(" ")
  1169  	}
  1170  	if s.Cursor.FetchType.HasCount() {
  1171  		ctx.WriteString(strconv.Itoa(int(s.Cursor.Count)))
  1172  		ctx.WriteString(" ")
  1173  	}
  1174  	ctx.WriteString("FROM ")
  1175  	ctx.FormatName(string(s.Cursor.Name))
  1176  	if s.Target != nil {
  1177  		ctx.WriteString(" INTO ")
  1178  		for i := range s.Target {
  1179  			if i > 0 {
  1180  				ctx.WriteString(", ")
  1181  			}
  1182  			ctx.FormatNode(&s.Target[i])
  1183  		}
  1184  	}
  1185  	ctx.WriteString(";\n")
  1186  }
  1187  
  1188  func (s *Fetch) PlpgSQLStatementTag() string {
  1189  	if s.IsMove {
  1190  		return "stmt_move"
  1191  	}
  1192  	return "stmt_fetch"
  1193  }
  1194  
  1195  func (s *Fetch) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1196  	newStmt, changed = visitor.Visit(s)
  1197  	return newStmt, changed
  1198  }
  1199  
  1200  // stmt_close
  1201  type Close struct {
  1202  	StatementImpl
  1203  	CurVar Variable
  1204  }
  1205  
  1206  func (s *Close) Format(ctx *tree.FmtCtx) {
  1207  	ctx.WriteString("CLOSE ")
  1208  	ctx.FormatNode(&s.CurVar)
  1209  	ctx.WriteString(";\n")
  1210  }
  1211  
  1212  func (s *Close) PlpgSQLStatementTag() string {
  1213  	return "stmt_close"
  1214  }
  1215  
  1216  func (s *Close) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1217  	newStmt, changed = visitor.Visit(s)
  1218  	return newStmt, changed
  1219  }
  1220  
  1221  // stmt_commit
  1222  type Commit struct {
  1223  	StatementImpl
  1224  	Chain bool
  1225  }
  1226  
  1227  func (s *Commit) Format(ctx *tree.FmtCtx) {
  1228  }
  1229  
  1230  func (s *Commit) PlpgSQLStatementTag() string {
  1231  	return "stmt_commit"
  1232  }
  1233  
  1234  func (s *Commit) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1235  	newStmt, changed = visitor.Visit(s)
  1236  	return newStmt, changed
  1237  }
  1238  
  1239  // stmt_rollback
  1240  type Rollback struct {
  1241  	StatementImpl
  1242  	Chain bool
  1243  }
  1244  
  1245  func (s *Rollback) Format(ctx *tree.FmtCtx) {
  1246  }
  1247  
  1248  func (s *Rollback) PlpgSQLStatementTag() string {
  1249  	return "stmt_rollback"
  1250  }
  1251  
  1252  func (s *Rollback) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1253  	newStmt, changed = visitor.Visit(s)
  1254  	return newStmt, changed
  1255  }
  1256  
  1257  // stmt_null
  1258  type Null struct {
  1259  	StatementImpl
  1260  }
  1261  
  1262  func (s *Null) Format(ctx *tree.FmtCtx) {
  1263  	ctx.WriteString("NULL;\n")
  1264  }
  1265  
  1266  func (s *Null) PlpgSQLStatementTag() string {
  1267  	return "stmt_null"
  1268  }
  1269  
  1270  func (s *Null) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) {
  1271  	newStmt, changed = visitor.Visit(s)
  1272  	return newStmt, changed
  1273  }
  1274  
  1275  // formatString is a helper function that prints "_" if FmtHideConstants is set,
  1276  // and otherwise prints the given string.
  1277  func formatString(ctx *tree.FmtCtx, str string) {
  1278  	if ctx.HasFlags(tree.FmtHideConstants) {
  1279  		ctx.WriteString("_")
  1280  	} else {
  1281  		ctx.WriteString(str)
  1282  	}
  1283  }
  1284  
  1285  // formatStringQuotes is similar to formatString, but surrounds the output with
  1286  // single quotes.
  1287  func formatStringQuotes(ctx *tree.FmtCtx, str string) {
  1288  	ctx.WriteString("'")
  1289  	formatString(ctx, str)
  1290  	ctx.WriteString("'")
  1291  }