github.com/XiaoMi/Gaea@v1.2.5/mysql/variables.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 22 "github.com/XiaoMi/Gaea/core/errors" 23 ) 24 25 type verifyFunc func(interface{}) error 26 27 // allowed session variables 28 const ( 29 SQLModeStr = "sql_mode" 30 SQLSafeUpdates = "sql_safe_updates" 31 TimeZone = "time_zone" 32 ) 33 34 // not allowed session variables 35 const ( 36 MaxAllowedPacket = "max_allowed_packet" 37 ) 38 39 var variableVerifyFuncMap = map[string]verifyFunc{ 40 SQLModeStr: verifySQLMode, 41 SQLSafeUpdates: verifyOnOffInteger, 42 TimeZone: verifyTimeZone, 43 } 44 45 // SessionVariables variables in session 46 type SessionVariables struct { 47 variables map[string]*Variable 48 unused map[string]*Variable 49 } 50 51 // NewSessionVariables constructor of SessionVariables 52 func NewSessionVariables() *SessionVariables { 53 return &SessionVariables{ 54 variables: make(map[string]*Variable), 55 unused: make(map[string]*Variable), 56 } 57 } 58 59 // Equals check if equal of SessionVariables 60 func (s *SessionVariables) Equals(dst *SessionVariables) bool { 61 if len(s.variables) != len(dst.variables) { 62 return false 63 } 64 65 for _, v := range s.variables { 66 if dstV, ok := dst.variables[v.Name()]; !ok { 67 return false 68 } else if dstV != v { 69 return false 70 } 71 } 72 return true 73 } 74 75 // SetEqualsWith set the SessionVariables equals with the dst, and variables not contained in dst are moved to unused. 76 func (s *SessionVariables) SetEqualsWith(dst *SessionVariables) ( /*changed*/ bool, error) { 77 if len(s.variables) == 0 && len(dst.variables) != 0 { 78 for _, v := range dst.variables { 79 if err := s.Set(v.Name(), v.Get()); err != nil { 80 return false, err 81 } 82 } 83 return true, nil 84 } 85 86 if len(s.variables) != 0 && len(dst.variables) == 0 { 87 for _, v := range s.variables { 88 s.unused[v.Name()] = v 89 delete(s.variables, v.Name()) 90 } 91 return true, nil 92 } 93 94 changed := false 95 for variableName := range variableVerifyFuncMap { 96 srcVar, srcOK := s.variables[variableName] 97 dstVar, dstOK := dst.variables[variableName] 98 if srcOK && dstOK { 99 if srcVar.Get() != dstVar.Get() { 100 changed = true 101 srcVar.Set(dstVar.Get()) 102 } 103 } else if srcOK && !dstOK { 104 changed = true 105 s.unused[variableName] = srcVar 106 delete(s.variables, variableName) 107 } else if !srcOK && dstOK { 108 changed = true 109 s.Set(variableName, dstVar.Get()) 110 } 111 } 112 113 return changed, nil 114 } 115 116 // Delete delete variables with specific key 117 func (s *SessionVariables) Delete(key string) { 118 delete(s.variables, formatVariableName(key)) 119 } 120 121 // Set store variable in session 122 func (s *SessionVariables) Set(key string, value interface{}) error { 123 formatKey := formatVariableName(key) 124 verifyFunc, ok := variableVerifyFuncMap[formatKey] 125 if !ok { 126 return fmt.Errorf("variable not support") 127 } 128 129 if variable, ok := s.variables[formatKey]; ok { 130 return variable.Set(value) 131 } 132 133 variable, err := NewVariable(formatKey, value, verifyFunc) 134 if err != nil { 135 return err 136 } 137 s.variables[formatKey] = variable 138 return nil 139 } 140 141 // Get return variable with specific key 142 func (s *SessionVariables) Get(key string) (interface{}, bool) { 143 v, ok := s.variables[key] 144 return v, ok 145 } 146 147 // GetAll return all variables in session 148 func (s *SessionVariables) GetAll() map[string]*Variable { 149 return s.variables 150 } 151 152 // GetUnusedAndClear unused variables 153 func (s *SessionVariables) GetUnusedAndClear() map[string]*Variable { 154 unused := s.unused 155 s.unused = make(map[string]*Variable) 156 return unused 157 } 158 159 func formatVariableName(name string) string { 160 name = strings.Trim(name, "'`\"") 161 name = strings.ToLower(name) 162 return name 163 } 164 165 // Variable variable definition in session 166 type Variable struct { 167 name string 168 value interface{} 169 verify verifyFunc 170 } 171 172 // NewVariable constructor of Variable 173 func NewVariable(name string, value interface{}, verify verifyFunc) (*Variable, error) { 174 v := &Variable{ 175 name: formatVariableName(name), 176 value: value, 177 verify: verify, 178 } 179 if err := v.verify(value); err != nil { 180 return nil, err 181 } 182 return v, nil 183 } 184 185 // Set store data 186 func (v *Variable) Set(value interface{}) error { 187 if err := v.verify(value); err != nil { 188 return err 189 } 190 v.value = value 191 return nil 192 } 193 194 // Name name of variable 195 func (v *Variable) Name() string { 196 return v.name 197 } 198 199 // Get return value in Variable 200 func (v *Variable) Get() interface{} { 201 return v.value 202 } 203 204 func verifySQLMode(v interface{}) error { 205 value, ok := v.(string) 206 if !ok { 207 return fmt.Errorf("invalid type of sql mode") 208 } 209 if value == "" { 210 return nil 211 } 212 213 value = strings.Trim(value, "'`\"") 214 value = strings.ToUpper(value) 215 values := strings.Split(value, ",") 216 for _, sqlMode := range values { 217 if _, ok := SQLModeSet[sqlMode]; !ok { 218 return errors.ErrInvalidSQLMode 219 } 220 } 221 return nil 222 } 223 224 // SQLModeSet https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html 225 var SQLModeSet = map[string]bool{ 226 // Full List of SQL Modes 227 "ALLOW_INVALID_DATES": true, 228 "ANSI_QUOTES": true, 229 "ERROR_FOR_DIVISION_BY_ZERO": true, 230 "HIGH_NOT_PRECEDENCE": true, 231 "IGNORE_SPACE": true, 232 "NO_AUTO_CREATE_USER": true, 233 "NO_AUTO_VALUE_ON_ZERO": true, 234 "NO_BACKSLASH_ESCAPES": true, 235 "NO_DIR_IN_CREATE": true, 236 "NO_ENGINE_SUBSTITUTION": true, 237 "NO_FIELD_OPTIONS": true, 238 "NO_KEY_OPTIONS": true, 239 "NO_TABLE_OPTIONS": true, 240 "NO_UNSIGNED_SUBTRACTION": true, 241 "NO_ZERO_DATE": true, 242 "NO_ZERO_IN_DATE": true, 243 "ONLY_FULL_GROUP_BY": true, 244 "PAD_CHAR_TO_FULL_LENGTH": true, 245 "PIPES_AS_CONCAT": true, 246 "REAL_AS_FLOAT": true, 247 "STRICT_ALL_TABLES": true, 248 "STRICT_TRANS_TABLES": true, 249 250 // Combination SQL Modes 251 "ANSI": true, 252 "DB2": true, 253 "MAXDB": true, 254 "MSSQL": true, 255 "MYSQL323": true, 256 "MYSQL40": true, 257 "ORACLE": true, 258 "POSTGRESQL": true, 259 "TRADITIONAL": true, 260 } 261 262 func verifyOnOffInteger(v interface{}) error { 263 val, ok := v.(int64) 264 if !ok { 265 return fmt.Errorf("value is not int64") 266 } 267 if val != 0 && val != 1 { 268 return fmt.Errorf("value is not 0 or 1") 269 } 270 return nil 271 } 272 273 func verifyTimeZone(v interface{}) error { 274 value, ok := v.(string) 275 if !ok { 276 return fmt.Errorf("invalid type of time_zone") 277 } 278 values := strings.Split(value, ":") 279 if len(values) != 2 { 280 return fmt.Errorf("invalid format of time_zone") 281 } 282 if values[0][0] != '+' && values[0][0] != '-' { 283 return fmt.Errorf("invalid format of time_zone") 284 } 285 hour, err := strconv.Atoi(values[0]) 286 if err != nil { 287 return fmt.Errorf("invalid hour of time_zone") 288 } 289 minute, err := strconv.Atoi(values[1]) 290 if err != nil { 291 return fmt.Errorf("invalid minute of time_zone") 292 } 293 var directMinute int 294 if hour < 0 { 295 directMinute = hour*60 - minute 296 } else { 297 directMinute = hour*60 + minute 298 } 299 if directMinute < -779 || directMinute > 780 { 300 return fmt.Errorf("exceed limit of time_zone") 301 } 302 303 return nil 304 }