github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/signal.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"sort"
    20  	"strings"
    21  
    22  	"github.com/dolthub/vitess/go/mysql"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  )
    26  
    27  // SignalConditionItemName represents the item name for the set conditions of a SIGNAL statement.
    28  type SignalConditionItemName string
    29  
    30  const (
    31  	SignalConditionItemName_Unknown           SignalConditionItemName = ""
    32  	SignalConditionItemName_ClassOrigin       SignalConditionItemName = "class_origin"
    33  	SignalConditionItemName_SubclassOrigin    SignalConditionItemName = "subclass_origin"
    34  	SignalConditionItemName_MessageText       SignalConditionItemName = "message_text"
    35  	SignalConditionItemName_MysqlErrno        SignalConditionItemName = "mysql_errno"
    36  	SignalConditionItemName_ConstraintCatalog SignalConditionItemName = "constraint_catalog"
    37  	SignalConditionItemName_ConstraintSchema  SignalConditionItemName = "constraint_schema"
    38  	SignalConditionItemName_ConstraintName    SignalConditionItemName = "constraint_name"
    39  	SignalConditionItemName_CatalogName       SignalConditionItemName = "catalog_name"
    40  	SignalConditionItemName_SchemaName        SignalConditionItemName = "schema_name"
    41  	SignalConditionItemName_TableName         SignalConditionItemName = "table_name"
    42  	SignalConditionItemName_ColumnName        SignalConditionItemName = "column_name"
    43  	SignalConditionItemName_CursorName        SignalConditionItemName = "cursor_name"
    44  )
    45  
    46  var SignalItems = []SignalConditionItemName{
    47  	SignalConditionItemName_ClassOrigin,
    48  	SignalConditionItemName_SubclassOrigin,
    49  	SignalConditionItemName_MessageText,
    50  	SignalConditionItemName_MysqlErrno,
    51  	SignalConditionItemName_ConstraintCatalog,
    52  	SignalConditionItemName_ConstraintSchema,
    53  	SignalConditionItemName_ConstraintName,
    54  	SignalConditionItemName_CatalogName,
    55  	SignalConditionItemName_SchemaName,
    56  	SignalConditionItemName_TableName,
    57  	SignalConditionItemName_ColumnName,
    58  	SignalConditionItemName_CursorName,
    59  }
    60  
// SignalInfo represents a piece of information for a SIGNAL statement: the value assigned to a single condition
// information item in the statement's SET clause.
type SignalInfo struct {
	// ConditionItemName identifies which condition information item this entry sets.
	ConditionItemName SignalConditionItemName
	// IntValue is the literal integer value. It is what String/DebugString print for MYSQL_ERRNO (when ExprVal is
	// nil), and Signal.RowIter uses it as the error number.
	IntValue          int64
	// StrValue is the literal string value. It is printed for every item other than MYSQL_ERRNO (when ExprVal is
	// nil), and Signal.RowIter uses it as the message text when ExprVal is nil.
	StrValue          string
	// ExprVal is an optional expression value. When non-nil it takes precedence in String/DebugString, and it is
	// exposed to the analyzer through Signal.Expressions / Signal.WithExpressions.
	ExprVal           sql.Expression
}
    68  
// Signal represents the SIGNAL statement with a set SQLSTATE.
type Signal struct {
	// SqlStateValue is the SQLSTATE to raise, expected to be a string with length 5. NOTE(review): the length is not
	// validated in this file, and NewSignalName leaves it empty; NewSignal and RowIter inspect its first two
	// characters (the SQLSTATE class).
	SqlStateValue string
	// Info holds the SET clause items keyed by item name; each value's ConditionItemName presumably matches its key
	// (NewSignal and WithExpressions both maintain this).
	Info          map[SignalConditionItemName]SignalInfo
}
    74  
// SignalName represents the SIGNAL statement with a condition name (SIGNAL <condition_name> ...) instead of an
// explicit SQLSTATE.
type SignalName struct {
	// Signal carries the SET clause items; when built by NewSignalName its SqlStateValue is empty.
	Signal *Signal
	// Name is the condition name as given in the statement. NOTE(review): presumably resolved to a SQLSTATE outside
	// this file, since RowIter on this node always fails as "unresolved" — confirm.
	Name   string
}
    80  
    81  var _ sql.Node = (*Signal)(nil)
    82  var _ sql.Node = (*SignalName)(nil)
    83  var _ sql.Expressioner = (*Signal)(nil)
    84  var _ sql.CollationCoercible = (*Signal)(nil)
    85  var _ sql.CollationCoercible = (*SignalName)(nil)
    86  
    87  // NewSignal returns a *Signal node.
    88  func NewSignal(sqlstate string, info map[SignalConditionItemName]SignalInfo) *Signal {
    89  	// https://dev.mysql.com/doc/refman/8.0/en/signal.html#signal-condition-information-items
    90  	// https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
    91  	firstTwo := sqlstate[0:2]
    92  	if _, ok := info[SignalConditionItemName_MessageText]; !ok {
    93  		si := SignalInfo{
    94  			ConditionItemName: SignalConditionItemName_MessageText,
    95  		}
    96  		switch firstTwo {
    97  		case "01":
    98  			si.StrValue = "Unhandled user-defined warning condition"
    99  		case "02":
   100  			si.StrValue = "Unhandled user-defined not found condition"
   101  		default:
   102  			si.StrValue = "Unhandled user-defined exception condition"
   103  		}
   104  		info[SignalConditionItemName_MessageText] = si
   105  	}
   106  	if _, ok := info[SignalConditionItemName_MysqlErrno]; !ok {
   107  		si := SignalInfo{
   108  			ConditionItemName: SignalConditionItemName_MysqlErrno,
   109  		}
   110  		switch firstTwo {
   111  		case "01":
   112  			si.IntValue = 1642
   113  		case "02":
   114  			si.IntValue = 1643
   115  		default:
   116  			si.IntValue = 1644
   117  		}
   118  		info[SignalConditionItemName_MysqlErrno] = si
   119  	}
   120  	return &Signal{
   121  		SqlStateValue: sqlstate,
   122  		Info:          info,
   123  	}
   124  }
   125  
   126  // NewSignalName returns a *SignalName node.
   127  func NewSignalName(name string, info map[SignalConditionItemName]SignalInfo) *SignalName {
   128  	return &SignalName{
   129  		Signal: &Signal{
   130  			Info: info,
   131  		},
   132  		Name: name,
   133  	}
   134  }
   135  
   136  // Resolved implements the sql.Node interface.
   137  func (s *Signal) Resolved() bool {
   138  	for _, e := range s.Expressions() {
   139  		if !e.Resolved() {
   140  			return false
   141  		}
   142  	}
   143  	return true
   144  }
   145  
   146  // String implements the sql.Node interface.
   147  func (s *Signal) String() string {
   148  	infoStr := ""
   149  	if len(s.Info) > 0 {
   150  		infoStr = " SET"
   151  		i := 0
   152  		for _, k := range SignalItems {
   153  			// enforce deterministic ordering
   154  			if info, ok := s.Info[k]; ok {
   155  				if i > 0 {
   156  					infoStr += ","
   157  				}
   158  				infoStr += " " + info.String()
   159  				i++
   160  			}
   161  		}
   162  	}
   163  	return fmt.Sprintf("SIGNAL SQLSTATE '%s'%s", s.SqlStateValue, infoStr)
   164  }
   165  
   166  func (s *Signal) IsReadOnly() bool {
   167  	return true
   168  }
   169  
   170  // DebugString implements the sql.DebugStringer interface.
   171  func (s *Signal) DebugString() string {
   172  	infoStr := ""
   173  	if len(s.Info) > 0 {
   174  		infoStr = " SET"
   175  		i := 0
   176  		for _, k := range SignalItems {
   177  			// enforce deterministic ordering
   178  			if info, ok := s.Info[k]; ok {
   179  				if i > 0 {
   180  					infoStr += ","
   181  				}
   182  				infoStr += " " + info.DebugString()
   183  				i++
   184  			}
   185  		}
   186  	}
   187  	return fmt.Sprintf("SIGNAL SQLSTATE '%s'%s", s.SqlStateValue, infoStr)
   188  }
   189  
   190  // Schema implements the sql.Node interface.
   191  func (s *Signal) Schema() sql.Schema {
   192  	return nil
   193  }
   194  
   195  // Children implements the sql.Node interface.
   196  func (s *Signal) Children() []sql.Node {
   197  	return nil
   198  }
   199  
   200  // WithChildren implements the sql.Node interface.
   201  func (s *Signal) WithChildren(children ...sql.Node) (sql.Node, error) {
   202  	return NillaryWithChildren(s, children...)
   203  }
   204  
   205  func (s *Signal) Expressions() []sql.Expression {
   206  	items := s.signalItemsWithExpressions()
   207  
   208  	var exprs []sql.Expression
   209  	for _, itemInfo := range items {
   210  		exprs = append(exprs, itemInfo.ExprVal)
   211  	}
   212  
   213  	return exprs
   214  }
   215  
   216  // signalItemsWithExpressions returns the subset of the Info map entries that have an expression value, sorted by
   217  // item name
   218  func (s *Signal) signalItemsWithExpressions() []SignalInfo {
   219  	var items []SignalInfo
   220  
   221  	for _, itemInfo := range s.Info {
   222  		if itemInfo.ExprVal != nil {
   223  			items = append(items, itemInfo)
   224  		}
   225  	}
   226  
   227  	// Very important to have a consistent sort order between here and the WithExpressions call
   228  	sort.Slice(items, func(i, j int) bool {
   229  		return items[i].ConditionItemName < items[j].ConditionItemName
   230  	})
   231  
   232  	return items
   233  }
   234  
   235  func (s Signal) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   236  	itemsWithExprs := s.signalItemsWithExpressions()
   237  	if len(itemsWithExprs) != len(exprs) {
   238  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(itemsWithExprs))
   239  	}
   240  
   241  	mapCopy := make(map[SignalConditionItemName]SignalInfo)
   242  	for k, v := range s.Info {
   243  		mapCopy[k] = v
   244  	}
   245  
   246  	for i := range exprs {
   247  		// transfer the expression to the new info map
   248  		newInfo := itemsWithExprs[i]
   249  		newInfo.ExprVal = exprs[i]
   250  		mapCopy[itemsWithExprs[i].ConditionItemName] = newInfo
   251  	}
   252  
   253  	s.Info = mapCopy
   254  	return &s, nil
   255  }
   256  
   257  // CheckPrivileges implements the interface sql.Node.
   258  func (s *Signal) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   259  	return true
   260  }
   261  
   262  // CollationCoercibility implements the interface sql.CollationCoercible.
   263  func (*Signal) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   264  	return sql.Collation_binary, 7
   265  }
   266  
   267  // RowIter implements the sql.Node interface.
   268  func (s *Signal) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
   269  	//TODO: implement CLASS_ORIGIN
   270  	//TODO: implement SUBCLASS_ORIGIN
   271  	//TODO: implement CONSTRAINT_CATALOG
   272  	//TODO: implement CONSTRAINT_SCHEMA
   273  	//TODO: implement CONSTRAINT_NAME
   274  	//TODO: implement CATALOG_NAME
   275  	//TODO: implement SCHEMA_NAME
   276  	//TODO: implement TABLE_NAME
   277  	//TODO: implement COLUMN_NAME
   278  	//TODO: implement CURSOR_NAME
   279  	if s.SqlStateValue[0:2] == "01" {
   280  		//TODO: implement warnings
   281  		return nil, fmt.Errorf("warnings not yet implemented")
   282  	} else {
   283  
   284  		messageItem := s.Info[SignalConditionItemName_MessageText]
   285  		strValue := messageItem.StrValue
   286  		if messageItem.ExprVal != nil {
   287  			exprResult, err := messageItem.ExprVal.Eval(ctx, nil)
   288  			if err != nil {
   289  				return nil, err
   290  			}
   291  			s, ok := exprResult.(string)
   292  			if !ok {
   293  				return nil, fmt.Errorf("message text expression did not evaluate to a string")
   294  			}
   295  			strValue = s
   296  		}
   297  
   298  		return nil, mysql.NewSQLError(
   299  			int(s.Info[SignalConditionItemName_MysqlErrno].IntValue),
   300  			s.SqlStateValue,
   301  			strValue,
   302  		)
   303  	}
   304  }
   305  
   306  // Resolved implements the sql.Node interface.
   307  func (s *SignalName) Resolved() bool {
   308  	return true
   309  }
   310  
   311  // String implements the sql.Node interface.
   312  func (s *SignalName) String() string {
   313  	infoStr := ""
   314  	if len(s.Signal.Info) > 0 {
   315  		infoStr = " SET"
   316  		i := 0
   317  		for _, info := range s.Signal.Info {
   318  			if i > 0 {
   319  				infoStr += ","
   320  			}
   321  			infoStr += " " + info.String()
   322  			i++
   323  		}
   324  	}
   325  	return fmt.Sprintf("SIGNAL %s%s", s.Name, infoStr)
   326  }
   327  
   328  // Schema implements the sql.Node interface.
   329  func (s *SignalName) Schema() sql.Schema {
   330  	return nil
   331  }
   332  
   333  func (s *SignalName) IsReadOnly() bool {
   334  	return true
   335  }
   336  
   337  // Children implements the sql.Node interface.
   338  func (s *SignalName) Children() []sql.Node {
   339  	return nil // SignalName is an alternate form of Signal rather than an encapsulating node, thus no children
   340  }
   341  
   342  // WithChildren implements the sql.Node interface.
   343  func (s *SignalName) WithChildren(children ...sql.Node) (sql.Node, error) {
   344  	return NillaryWithChildren(s, children...)
   345  }
   346  
   347  // CheckPrivileges implements the interface sql.Node.
   348  func (s *SignalName) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   349  	return true
   350  }
   351  
   352  // CollationCoercibility implements the interface sql.CollationCoercible.
   353  func (*SignalName) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   354  	return sql.Collation_binary, 7
   355  }
   356  
   357  // RowIter implements the sql.Node interface.
   358  func (s *SignalName) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
   359  	return nil, fmt.Errorf("may not iterate over unresolved node *SignalName")
   360  }
   361  
   362  func (s SignalInfo) IsReadOnly() bool {
   363  	return true
   364  }
   365  
   366  func (s SignalInfo) String() string {
   367  	itemName := strings.ToUpper(string(s.ConditionItemName))
   368  	if s.ExprVal != nil {
   369  		return fmt.Sprintf("%s = %s", itemName, s.ExprVal.String())
   370  	} else if s.ConditionItemName == SignalConditionItemName_MysqlErrno {
   371  		return fmt.Sprintf("%s = %d", itemName, s.IntValue)
   372  	}
   373  	return fmt.Sprintf("%s = %s", itemName, s.StrValue)
   374  }
   375  
   376  func (s SignalInfo) DebugString() string {
   377  	itemName := strings.ToUpper(string(s.ConditionItemName))
   378  	if s.ExprVal != nil {
   379  		return fmt.Sprintf("%s = %s", itemName, sql.DebugString(s.ExprVal))
   380  	} else if s.ConditionItemName == SignalConditionItemName_MysqlErrno {
   381  		return fmt.Sprintf("%s = %d", itemName, s.IntValue)
   382  	}
   383  	return fmt.Sprintf("%s = %s", itemName, s.StrValue)
   384  }