github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/set_var.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql
    12  
import (
	"context"
	"math"
	"strings"
	"time"

	"github.com/cockroachdb/apd"
	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/types"
	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
	"github.com/cockroachdb/errors"
)
    26  
    27  // setVarNode represents a SET SESSION statement.
    28  type setVarNode struct {
    29  	name string
    30  	v    sessionVar
    31  	// typedValues == nil means RESET.
    32  	typedValues []tree.TypedExpr
    33  }
    34  
    35  // SetVar sets session variables.
    36  // Privileges: None.
    37  //   Notes: postgres/mysql do not require privileges for session variables (some exceptions).
    38  func (p *planner) SetVar(ctx context.Context, n *tree.SetVar) (planNode, error) {
    39  	if n.Name == "" {
    40  		// A client has sent the reserved internal syntax SET ROW ...,
    41  		// or the user entered `SET "" = foo`. Reject it.
    42  		return nil, pgerror.Newf(pgcode.Syntax,
    43  			"invalid variable name: %q", n.Name)
    44  	}
    45  
    46  	name := strings.ToLower(n.Name)
    47  	_, v, err := getSessionVar(name, false /* missingOk */)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	var typedValues []tree.TypedExpr
    53  	if len(n.Values) > 0 {
    54  		isReset := false
    55  		if len(n.Values) == 1 {
    56  			if _, ok := n.Values[0].(tree.DefaultVal); ok {
    57  				// "SET var = DEFAULT" means RESET.
    58  				// In that case, we want typedValues to remain nil, so that
    59  				// the Start() logic recognizes the RESET too.
    60  				isReset = true
    61  			}
    62  		}
    63  
    64  		if !isReset {
    65  			typedValues = make([]tree.TypedExpr, len(n.Values))
    66  			for i, expr := range n.Values {
    67  				expr = unresolvedNameToStrVal(expr)
    68  
    69  				var dummyHelper tree.IndexedVarHelper
    70  				typedValue, err := p.analyzeExpr(
    71  					ctx, expr, nil, dummyHelper, types.String, false, "SET SESSION "+name)
    72  				if err != nil {
    73  					return nil, wrapSetVarError(name, expr.String(), "%v", err)
    74  				}
    75  				typedValues[i] = typedValue
    76  			}
    77  		}
    78  	}
    79  
    80  	if v.Set == nil && v.RuntimeSet == nil {
    81  		return nil, newCannotChangeParameterError(name)
    82  	}
    83  
    84  	if typedValues == nil {
    85  		// Statement is RESET. Do we have a default available?
    86  		// We do not use getDefaultString here because we need to delay
    87  		// the computation of the default to the execute phase.
    88  		if _, ok := p.sessionDataMutator.defaults[name]; !ok && v.GlobalDefault == nil {
    89  			return nil, newCannotChangeParameterError(name)
    90  		}
    91  	}
    92  
    93  	return &setVarNode{name: name, v: v, typedValues: typedValues}, nil
    94  }
    95  
    96  // Special rule for SET: because SET doesn't apply in the context
    97  // of a table, SET ... = IDENT really means SET ... = 'IDENT'.
    98  func unresolvedNameToStrVal(expr tree.Expr) tree.Expr {
    99  	if s, ok := expr.(*tree.UnresolvedName); ok {
   100  		return tree.NewStrVal(tree.AsStringWithFlags(s, tree.FmtBareIdentifiers))
   101  	}
   102  	return expr
   103  }
   104  
   105  func (n *setVarNode) startExec(params runParams) error {
   106  	var strVal string
   107  	if n.typedValues != nil {
   108  		for i, v := range n.typedValues {
   109  			d, err := v.Eval(params.EvalContext())
   110  			if err != nil {
   111  				return err
   112  			}
   113  			n.typedValues[i] = d
   114  		}
   115  		var err error
   116  		if n.v.GetStringVal != nil {
   117  			strVal, err = n.v.GetStringVal(params.ctx, params.extendedEvalCtx, n.typedValues)
   118  		} else {
   119  			// No string converter defined, use the default one.
   120  			strVal, err = getStringVal(params.EvalContext(), n.name, n.typedValues)
   121  		}
   122  		if err != nil {
   123  			return err
   124  		}
   125  	} else {
   126  		// Statement is RESET and we already know we have a default. Find it.
   127  		_, strVal = getSessionVarDefaultString(n.name, n.v, params.p.sessionDataMutator)
   128  	}
   129  
   130  	if n.v.RuntimeSet != nil {
   131  		return n.v.RuntimeSet(params.ctx, params.extendedEvalCtx, strVal)
   132  	}
   133  	return n.v.Set(params.ctx, params.p.sessionDataMutator, strVal)
   134  }
   135  
   136  // getSessionVarDefaultString retrieves a string suitable to pass to a
   137  // session var's Set() method. First return value is false if there is
   138  // no default.
   139  func getSessionVarDefaultString(
   140  	varName string, v sessionVar, m *sessionDataMutator,
   141  ) (bool, string) {
   142  	if defVal, ok := m.defaults[varName]; ok {
   143  		return true, defVal
   144  	}
   145  	if v.GlobalDefault != nil {
   146  		return true, v.GlobalDefault(&m.settings.SV)
   147  	}
   148  	return false, ""
   149  }
   150  
   151  func (n *setVarNode) Next(_ runParams) (bool, error) { return false, nil }
   152  func (n *setVarNode) Values() tree.Datums            { return nil }
   153  func (n *setVarNode) Close(_ context.Context)        {}
   154  
   155  func datumAsString(evalCtx *tree.EvalContext, name string, value tree.TypedExpr) (string, error) {
   156  	val, err := value.Eval(evalCtx)
   157  	if err != nil {
   158  		return "", err
   159  	}
   160  	s, ok := tree.AsDString(val)
   161  	if !ok {
   162  		err = pgerror.Newf(pgcode.InvalidParameterValue,
   163  			"parameter %q requires a string value", name)
   164  		err = errors.WithDetailf(err,
   165  			"%s is a %s", value, errors.Safe(val.ResolvedType()))
   166  		return "", err
   167  	}
   168  	return string(s), nil
   169  }
   170  
   171  func getStringVal(evalCtx *tree.EvalContext, name string, values []tree.TypedExpr) (string, error) {
   172  	if len(values) != 1 {
   173  		return "", newSingleArgVarError(name)
   174  	}
   175  	return datumAsString(evalCtx, name, values[0])
   176  }
   177  
   178  func datumAsInt(evalCtx *tree.EvalContext, name string, value tree.TypedExpr) (int64, error) {
   179  	val, err := value.Eval(evalCtx)
   180  	if err != nil {
   181  		return 0, err
   182  	}
   183  	iv, ok := tree.AsDInt(val)
   184  	if !ok {
   185  		err = pgerror.Newf(pgcode.InvalidParameterValue,
   186  			"parameter %q requires an integer value", name)
   187  		err = errors.WithDetailf(err,
   188  			"%s is a %s", value, errors.Safe(val.ResolvedType()))
   189  		return 0, err
   190  	}
   191  	return int64(iv), nil
   192  }
   193  
   194  func getIntVal(evalCtx *tree.EvalContext, name string, values []tree.TypedExpr) (int64, error) {
   195  	if len(values) != 1 {
   196  		return 0, newSingleArgVarError(name)
   197  	}
   198  	return datumAsInt(evalCtx, name, values[0])
   199  }
   200  
   201  func timeZoneVarGetStringVal(
   202  	_ context.Context, evalCtx *extendedEvalContext, values []tree.TypedExpr,
   203  ) (string, error) {
   204  	if len(values) != 1 {
   205  		return "", newSingleArgVarError("timezone")
   206  	}
   207  	d, err := values[0].Eval(&evalCtx.EvalContext)
   208  	if err != nil {
   209  		return "", err
   210  	}
   211  
   212  	var loc *time.Location
   213  	var offset int64
   214  	switch v := tree.UnwrapDatum(&evalCtx.EvalContext, d).(type) {
   215  	case *tree.DString:
   216  		location := string(*v)
   217  		loc, err = timeutil.TimeZoneStringToLocation(
   218  			location,
   219  			timeutil.TimeZoneStringToLocationISO8601Standard,
   220  		)
   221  		if err != nil {
   222  			return "", wrapSetVarError("timezone", values[0].String(),
   223  				"cannot find time zone %q: %v", location, err)
   224  		}
   225  
   226  	case *tree.DInterval:
   227  		offset, _, _, err = v.Duration.Encode()
   228  		if err != nil {
   229  			return "", wrapSetVarError("timezone", values[0].String(), "%v", err)
   230  		}
   231  		offset /= int64(time.Second)
   232  
   233  	case *tree.DInt:
   234  		offset = int64(*v) * 60 * 60
   235  
   236  	case *tree.DFloat:
   237  		offset = int64(float64(*v) * 60.0 * 60.0)
   238  
   239  	case *tree.DDecimal:
   240  		sixty := apd.New(60, 0)
   241  		ed := apd.MakeErrDecimal(tree.ExactCtx)
   242  		ed.Mul(sixty, sixty, sixty)
   243  		ed.Mul(sixty, sixty, &v.Decimal)
   244  		offset = ed.Int64(sixty)
   245  		if ed.Err() != nil {
   246  			return "", wrapSetVarError("timezone", values[0].String(),
   247  				"time zone value %s would overflow an int64", sixty)
   248  		}
   249  
   250  	default:
   251  		return "", newVarValueError("timezone", values[0].String())
   252  	}
   253  	if loc == nil {
   254  		loc = timeutil.FixedOffsetTimeZoneToLocation(int(offset), d.String())
   255  	}
   256  
   257  	return loc.String(), nil
   258  }
   259  
   260  func timeZoneVarSet(_ context.Context, m *sessionDataMutator, s string) error {
   261  	loc, err := timeutil.TimeZoneStringToLocation(
   262  		s,
   263  		timeutil.TimeZoneStringToLocationISO8601Standard,
   264  	)
   265  	if err != nil {
   266  		return wrapSetVarError("TimeZone", s, "%v", err)
   267  	}
   268  
   269  	m.SetLocation(loc)
   270  	return nil
   271  }
   272  
   273  func stmtTimeoutVarGetStringVal(
   274  	ctx context.Context, evalCtx *extendedEvalContext, values []tree.TypedExpr,
   275  ) (string, error) {
   276  	if len(values) != 1 {
   277  		return "", newSingleArgVarError("statement_timeout")
   278  	}
   279  	d, err := values[0].Eval(&evalCtx.EvalContext)
   280  	if err != nil {
   281  		return "", err
   282  	}
   283  
   284  	var timeout time.Duration
   285  	switch v := tree.UnwrapDatum(&evalCtx.EvalContext, d).(type) {
   286  	case *tree.DString:
   287  		return string(*v), nil
   288  	case *tree.DInterval:
   289  		timeout, err = intervalToDuration(v)
   290  		if err != nil {
   291  			return "", wrapSetVarError("statement_timeout", values[0].String(), "%v", err)
   292  		}
   293  	case *tree.DInt:
   294  		timeout = time.Duration(*v) * time.Millisecond
   295  	}
   296  	return timeout.String(), nil
   297  }
   298  
   299  func stmtTimeoutVarSet(ctx context.Context, m *sessionDataMutator, s string) error {
   300  	interval, err := tree.ParseDIntervalWithTypeMetadata(s, types.IntervalTypeMetadata{
   301  		DurationField: types.IntervalDurationField{
   302  			DurationType: types.IntervalDurationType_MILLISECOND,
   303  		},
   304  	})
   305  	if err != nil {
   306  		return wrapSetVarError("statement_timeout", s, "%v", err)
   307  	}
   308  	timeout, err := intervalToDuration(interval)
   309  	if err != nil {
   310  		return wrapSetVarError("statement_timeout", s, "%v", err)
   311  	}
   312  
   313  	if timeout < 0 {
   314  		return wrapSetVarError("statement_timeout", s,
   315  			"statement_timeout cannot have a negative duration")
   316  	}
   317  	m.SetStmtTimeout(timeout)
   318  	return nil
   319  }
   320  
   321  func intervalToDuration(interval *tree.DInterval) (time.Duration, error) {
   322  	nanos, _, _, err := interval.Encode()
   323  	if err != nil {
   324  		return 0, err
   325  	}
   326  	return time.Duration(nanos), nil
   327  }
   328  
   329  func newSingleArgVarError(varName string) error {
   330  	return pgerror.Newf(pgcode.InvalidParameterValue,
   331  		"SET %s takes only one argument", varName)
   332  }
   333  
   334  func wrapSetVarError(varName, actualValue string, fmt string, args ...interface{}) error {
   335  	err := pgerror.Newf(pgcode.InvalidParameterValue,
   336  		"invalid value for parameter %q: %q", varName, actualValue)
   337  	return errors.WithDetailf(err, fmt, args...)
   338  }
   339  
   340  func newVarValueError(varName, actualVal string, allowedVals ...string) (err error) {
   341  	err = pgerror.Newf(pgcode.InvalidParameterValue,
   342  		"invalid value for parameter %q: %q", varName, actualVal)
   343  	if len(allowedVals) > 0 {
   344  		err = errors.WithHintf(err, "Available values: %s", strings.Join(allowedVals, ","))
   345  	}
   346  	return err
   347  }
   348  
   349  func newCannotChangeParameterError(varName string) error {
   350  	return pgerror.Newf(pgcode.CantChangeRuntimeParam,
   351  		"parameter %q cannot be changed", varName)
   352  }