github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/discard.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql
    12  
    13  import (
    14  	"context"
    15  
    16  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    17  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    19  	"github.com/cockroachdb/errors"
    20  )
    21  
    22  // Discard implements the DISCARD statement.
    23  // See https://www.postgresql.org/docs/9.6/static/sql-discard.html for details.
    24  func (p *planner) Discard(ctx context.Context, s *tree.Discard) (planNode, error) {
    25  	switch s.Mode {
    26  	case tree.DiscardModeAll:
    27  		if !p.autoCommit {
    28  			return nil, pgerror.New(pgcode.ActiveSQLTransaction,
    29  				"DISCARD ALL cannot run inside a transaction block")
    30  		}
    31  
    32  		// RESET ALL
    33  		if err := resetSessionVars(ctx, p.sessionDataMutator); err != nil {
    34  			return nil, err
    35  		}
    36  
    37  		// DEALLOCATE ALL
    38  		p.preparedStatements.DeleteAll(ctx)
    39  	default:
    40  		return nil, errors.AssertionFailedf("unknown mode for DISCARD: %d", s.Mode)
    41  	}
    42  	return newZeroNode(nil /* columns */), nil
    43  }
    44  
    45  func resetSessionVars(ctx context.Context, m *sessionDataMutator) error {
    46  	for _, varName := range varNames {
    47  		v := varGen[varName]
    48  		if v.Set != nil {
    49  			hasDefault, defVal := getSessionVarDefaultString(varName, v, m)
    50  			if hasDefault {
    51  				if err := v.Set(ctx, m, defVal); err != nil {
    52  					return err
    53  				}
    54  			}
    55  		}
    56  	}
    57  	return nil
    58  }