
     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package mysql
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"strings"
    22  	""
    23  )
    25  type verifyFunc func(interface{}) error
    27  // allowed session variables
    28  const (
    29  	SQLModeStr     = "sql_mode"
    30  	SQLSafeUpdates = "sql_safe_updates"
    31  	TimeZone       = "time_zone"
    32  )
    34  // not allowed session variables
    35  const (
    36  	MaxAllowedPacket = "max_allowed_packet"
    37  )
    39  var variableVerifyFuncMap = map[string]verifyFunc{
    40  	SQLModeStr:     verifySQLMode,
    41  	SQLSafeUpdates: verifyOnOffInteger,
    42  	TimeZone:       verifyTimeZone,
    43  }
// SessionVariables holds the supported session variables (those registered
// in variableVerifyFuncMap), keyed by name normalized with formatVariableName.
type SessionVariables struct {
	// variables holds the variables currently set.
	variables map[string]*Variable
	// unused collects variables that SetEqualsWith removed from variables;
	// GetUnusedAndClear returns them and resets this map.
	unused    map[string]*Variable
}
    51  // NewSessionVariables constructor of SessionVariables
    52  func NewSessionVariables() *SessionVariables {
    53  	return &SessionVariables{
    54  		variables: make(map[string]*Variable),
    55  		unused:    make(map[string]*Variable),
    56  	}
    57  }
    59  // Equals check if equal of SessionVariables
    60  func (s *SessionVariables) Equals(dst *SessionVariables) bool {
    61  	if len(s.variables) != len(dst.variables) {
    62  		return false
    63  	}
    65  	for _, v := range s.variables {
    66  		if dstV, ok := dst.variables[v.Name()]; !ok {
    67  			return false
    68  		} else if dstV != v {
    69  			return false
    70  		}
    71  	}
    72  	return true
    73  }
    75  // SetEqualsWith set the SessionVariables equals with the dst, and variables not contained in dst are moved to unused.
    76  func (s *SessionVariables) SetEqualsWith(dst *SessionVariables) ( /*changed*/ bool, error) {
    77  	if len(s.variables) == 0 && len(dst.variables) != 0 {
    78  		for _, v := range dst.variables {
    79  			if err := s.Set(v.Name(), v.Get()); err != nil {
    80  				return false, err
    81  			}
    82  		}
    83  		return true, nil
    84  	}
    86  	if len(s.variables) != 0 && len(dst.variables) == 0 {
    87  		for _, v := range s.variables {
    88  			s.unused[v.Name()] = v
    89  			delete(s.variables, v.Name())
    90  		}
    91  		return true, nil
    92  	}
    94  	changed := false
    95  	for variableName := range variableVerifyFuncMap {
    96  		srcVar, srcOK := s.variables[variableName]
    97  		dstVar, dstOK := dst.variables[variableName]
    98  		if srcOK && dstOK {
    99  			if srcVar.Get() != dstVar.Get() {
   100  				changed = true
   101  				srcVar.Set(dstVar.Get())
   102  			}
   103  		} else if srcOK && !dstOK {
   104  			changed = true
   105  			s.unused[variableName] = srcVar
   106  			delete(s.variables, variableName)
   107  		} else if !srcOK && dstOK {
   108  			changed = true
   109  			s.Set(variableName, dstVar.Get())
   110  		}
   111  	}
   113  	return changed, nil
   114  }
   116  // Delete delete variables with specific key
   117  func (s *SessionVariables) Delete(key string) {
   118  	delete(s.variables, formatVariableName(key))
   119  }
   121  // Set store variable in session
   122  func (s *SessionVariables) Set(key string, value interface{}) error {
   123  	formatKey := formatVariableName(key)
   124  	verifyFunc, ok := variableVerifyFuncMap[formatKey]
   125  	if !ok {
   126  		return fmt.Errorf("variable not support")
   127  	}
   129  	if variable, ok := s.variables[formatKey]; ok {
   130  		return variable.Set(value)
   131  	}
   133  	variable, err := NewVariable(formatKey, value, verifyFunc)
   134  	if err != nil {
   135  		return err
   136  	}
   137  	s.variables[formatKey] = variable
   138  	return nil
   139  }
   141  // Get return variable with specific key
   142  func (s *SessionVariables) Get(key string) (interface{}, bool) {
   143  	v, ok := s.variables[key]
   144  	return v, ok
   145  }
   147  // GetAll return all variables in session
   148  func (s *SessionVariables) GetAll() map[string]*Variable {
   149  	return s.variables
   150  }
   152  // GetUnusedAndClear unused variables
   153  func (s *SessionVariables) GetUnusedAndClear() map[string]*Variable {
   154  	unused := s.unused
   155  	s.unused = make(map[string]*Variable)
   156  	return unused
   157  }
   159  func formatVariableName(name string) string {
   160  	name = strings.Trim(name, "'`\"")
   161  	name = strings.ToLower(name)
   162  	return name
   163  }
// Variable is one session variable: its normalized name, its current value,
// and the function used to validate new values.
type Variable struct {
	name   string      // normalized by formatVariableName in NewVariable
	value  interface{} // last value that passed verify
	verify verifyFunc  // checks values in NewVariable and Set
}
   172  // NewVariable constructor of Variable
   173  func NewVariable(name string, value interface{}, verify verifyFunc) (*Variable, error) {
   174  	v := &Variable{
   175  		name:   formatVariableName(name),
   176  		value:  value,
   177  		verify: verify,
   178  	}
   179  	if err := v.verify(value); err != nil {
   180  		return nil, err
   181  	}
   182  	return v, nil
   183  }
   185  // Set store data
   186  func (v *Variable) Set(value interface{}) error {
   187  	if err := v.verify(value); err != nil {
   188  		return err
   189  	}
   190  	v.value = value
   191  	return nil
   192  }
   194  // Name name of variable
   195  func (v *Variable) Name() string {
   196  	return
   197  }
   199  // Get return value in Variable
   200  func (v *Variable) Get() interface{} {
   201  	return v.value
   202  }
   204  func verifySQLMode(v interface{}) error {
   205  	value, ok := v.(string)
   206  	if !ok {
   207  		return fmt.Errorf("invalid type of sql mode")
   208  	}
   209  	if value == "" {
   210  		return nil
   211  	}
   213  	value = strings.Trim(value, "'`\"")
   214  	value = strings.ToUpper(value)
   215  	values := strings.Split(value, ",")
   216  	for _, sqlMode := range values {
   217  		if _, ok := SQLModeSet[sqlMode]; !ok {
   218  			return errors.ErrInvalidSQLMode
   219  		}
   220  	}
   221  	return nil
   222  }
   224  // SQLModeSet
   225  var SQLModeSet = map[string]bool{
   226  	// Full List of SQL Modes
   227  	"ALLOW_INVALID_DATES":        true,
   228  	"ANSI_QUOTES":                true,
   229  	"ERROR_FOR_DIVISION_BY_ZERO": true,
   230  	"HIGH_NOT_PRECEDENCE":        true,
   231  	"IGNORE_SPACE":               true,
   232  	"NO_AUTO_CREATE_USER":        true,
   233  	"NO_AUTO_VALUE_ON_ZERO":      true,
   234  	"NO_BACKSLASH_ESCAPES":       true,
   235  	"NO_DIR_IN_CREATE":           true,
   236  	"NO_ENGINE_SUBSTITUTION":     true,
   237  	"NO_FIELD_OPTIONS":           true,
   238  	"NO_KEY_OPTIONS":             true,
   239  	"NO_TABLE_OPTIONS":           true,
   240  	"NO_UNSIGNED_SUBTRACTION":    true,
   241  	"NO_ZERO_DATE":               true,
   242  	"NO_ZERO_IN_DATE":            true,
   243  	"ONLY_FULL_GROUP_BY":         true,
   244  	"PAD_CHAR_TO_FULL_LENGTH":    true,
   245  	"PIPES_AS_CONCAT":            true,
   246  	"REAL_AS_FLOAT":              true,
   247  	"STRICT_ALL_TABLES":          true,
   248  	"STRICT_TRANS_TABLES":        true,
   250  	// Combination SQL Modes
   251  	"ANSI":        true,
   252  	"DB2":         true,
   253  	"MAXDB":       true,
   254  	"MSSQL":       true,
   255  	"MYSQL323":    true,
   256  	"MYSQL40":     true,
   257  	"ORACLE":      true,
   258  	"POSTGRESQL":  true,
   259  	"TRADITIONAL": true,
   260  }
   262  func verifyOnOffInteger(v interface{}) error {
   263  	val, ok := v.(int64)
   264  	if !ok {
   265  		return fmt.Errorf("value is not int64")
   266  	}
   267  	if val != 0 && val != 1 {
   268  		return fmt.Errorf("value is not 0 or 1")
   269  	}
   270  	return nil
   271  }
   273  func verifyTimeZone(v interface{}) error {
   274  	value, ok := v.(string)
   275  	if !ok {
   276  		return fmt.Errorf("invalid type of time_zone")
   277  	}
   278  	values := strings.Split(value, ":")
   279  	if len(values) != 2 {
   280  		return fmt.Errorf("invalid format of time_zone")
   281  	}
   282  	if values[0][0] != '+' && values[0][0] != '-' {
   283  		return fmt.Errorf("invalid format of time_zone")
   284  	}
   285  	hour, err := strconv.Atoi(values[0])
   286  	if err != nil {
   287  		return fmt.Errorf("invalid hour of time_zone")
   288  	}
   289  	minute, err := strconv.Atoi(values[1])
   290  	if err != nil {
   291  		return fmt.Errorf("invalid minute of time_zone")
   292  	}
   293  	var directMinute int
   294  	if hour < 0 {
   295  		directMinute = hour*60 - minute
   296  	} else {
   297  		directMinute = hour*60 + minute
   298  	}
   299  	if directMinute < -779 || directMinute > 780 {
   300  		return fmt.Errorf("exceed limit of time_zone")
   301  	}
   303  	return nil
   304  }