github.com/abolfazlbeh/zhycan@v0.0.0-20230819144214-24cf38237387/internal/db/sql_wrapper.go (about) 1 package db 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "github.com/abolfazlbeh/zhycan/internal/config" 7 "gorm.io/driver/mysql" 8 "gorm.io/driver/postgres" 9 "gorm.io/driver/sqlite" 10 "gorm.io/gorm" 11 "reflect" 12 "strings" 13 ) 14 15 // Mark: Definitions 16 17 // SqlWrapper struct 18 type SqlWrapper[T SqlConfigurable] struct { 19 name string 20 config T 21 databaseInstance *gorm.DB 22 } 23 24 // init - SqlWrapper Constructor - It initializes the wrapper 25 func (s *SqlWrapper[T]) init(name string) error { 26 s.name = name 27 28 // reading config 29 nameParts := strings.Split(s.name, "/") 30 31 if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Sqlite{}) { 32 filenameKey := fmt.Sprintf("%s.%s", nameParts[1], "db") 33 filenameStr, err := config.GetManager().Get(nameParts[0], filenameKey) 34 if err != nil { 35 return err 36 } 37 38 optionsKey := fmt.Sprintf("%s.%s", nameParts[1], "options") 39 optionsObj, err := config.GetManager().Get(nameParts[0], optionsKey) 40 if err != nil { 41 return err 42 } 43 44 optionsMap := make(map[string]string, len(optionsObj.(map[string]interface{}))) 45 for key, item := range optionsObj.(map[string]interface{}) { 46 optionsMap[key] = item.(string) 47 } 48 49 var internalConfig *Config 50 51 internalConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "config") 52 internalConfigObj, err := config.GetManager().Get(nameParts[0], internalConfigKey) 53 if err == nil { 54 // first marshal 55 configData, err := json.Marshal(internalConfigObj) 56 if err == nil { 57 _ = json.Unmarshal(configData, &internalConfig) 58 } 59 } 60 61 var internalLogger *LoggerConfig 62 63 internalLoggerKey := fmt.Sprintf("%s.%s", nameParts[1], "logger") 64 internalLoggerObj, err := config.GetManager().Get(nameParts[0], internalLoggerKey) 65 if err == nil { 66 // first marshal 67 configData, err := json.Marshal(internalLoggerObj) 68 if err == nil { 69 _ = json.Unmarshal(configData, &internalLogger) 70 } 71 } 72 73 s.config = reflect.ValueOf(Sqlite{ 74 FileName: filenameStr.(string), 75 Options: optionsMap, 76 Config: internalConfig, 77 LoggerConfig: internalLogger, 78 }).Interface().(T) 79 } else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Mysql{}) { 80 dbNameKey := fmt.Sprintf("%s.%s", nameParts[1], "db") 81 dbNameStr, err := config.GetManager().Get(nameParts[0], dbNameKey) 82 if err != nil { 83 return err 84 } 85 86 hostKey := fmt.Sprintf("%s.%s", nameParts[1], "host") 87 hostStr, err := config.GetManager().Get(nameParts[0], hostKey) 88 if err != nil { 89 return err 90 } 91 92 portKey := fmt.Sprintf("%s.%s", nameParts[1], "port") 93 portStr, err := config.GetManager().Get(nameParts[0], portKey) 94 if err != nil { 95 return err 96 } 97 98 protocolKey := fmt.Sprintf("%s.%s", nameParts[1], "protocol") 99 protocolStr, err := config.GetManager().Get(nameParts[0], protocolKey) 100 if err != nil { 101 return err 102 } 103 104 usernameKey := fmt.Sprintf("%s.%s", nameParts[1], "username") 105 usernameStr, err := config.GetManager().Get(nameParts[0], usernameKey) 106 if err != nil { 107 return err 108 } 109 110 passwordKey := fmt.Sprintf("%s.%s", nameParts[1], "password") 111 passwordStr, err := config.GetManager().Get(nameParts[0], passwordKey) 112 if err != nil { 113 return err 114 } 115 116 optionsKey := fmt.Sprintf("%s.%s", nameParts[1], "options") 117 optionsObj, err := config.GetManager().Get(nameParts[0], optionsKey) 118 if err != nil { 119 return err 120 } 121 122 optionsMap := make(map[string]string, len(optionsObj.(map[string]interface{}))) 123 for key, item := range optionsObj.(map[string]interface{}) { 124 optionsMap[key] = item.(string) 125 } 126 127 var internalConfig *Config 128 129 internalConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "config") 130 internalConfigObj, err := config.GetManager().Get(nameParts[0], internalConfigKey) 131 if err == nil { 132 // first marshal 133 configData, err := json.Marshal(internalConfigObj) 134 if err == nil { 135 _ = json.Unmarshal(configData, &internalConfig) 136 } 137 } 138 139 var internalLogger *LoggerConfig 140 141 internalLoggerKey := fmt.Sprintf("%s.%s", nameParts[1], "logger") 142 internalLoggerObj, err := config.GetManager().Get(nameParts[0], internalLoggerKey) 143 if err == nil { 144 // first marshal 145 configData, err := json.Marshal(internalLoggerObj) 146 if err == nil { 147 _ = json.Unmarshal(configData, &internalLogger) 148 } 149 } 150 151 var specificConfig *MysqlSpecificConfig 152 153 specificConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "specific_config") 154 specificConfigObj, err := config.GetManager().Get(nameParts[0], specificConfigKey) 155 if err == nil { 156 // first marshal 157 configData, err := json.Marshal(specificConfigObj) 158 if err == nil { 159 _ = json.Unmarshal(configData, &specificConfig) 160 } 161 } 162 163 s.config = reflect.ValueOf(Mysql{ 164 DatabaseName: dbNameStr.(string), 165 Username: usernameStr.(string), 166 Password: passwordStr.(string), 167 Host: hostStr.(string), 168 Port: portStr.(string), 169 Protocol: protocolStr.(string), 170 Options: optionsMap, 171 Config: internalConfig, 172 LoggerConfig: internalLogger, 173 SpecificConfig: specificConfig, 174 }).Interface().(T) 175 } else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Postgresql{}) { 176 dbNameKey := fmt.Sprintf("%s.%s", nameParts[1], "db") 177 dbNameStr, err := config.GetManager().Get(nameParts[0], dbNameKey) 178 if err != nil { 179 return err 180 } 181 182 hostKey := fmt.Sprintf("%s.%s", nameParts[1], "host") 183 hostStr, err := config.GetManager().Get(nameParts[0], hostKey) 184 if err != nil { 185 return err 186 } 187 188 portKey := fmt.Sprintf("%s.%s", nameParts[1], "port") 189 portStr, err := config.GetManager().Get(nameParts[0], portKey) 190 if err != nil { 191 return err 192 } 193 194 usernameKey := fmt.Sprintf("%s.%s", nameParts[1], "username") 195 usernameStr, err := config.GetManager().Get(nameParts[0], usernameKey) 196 if err != nil { 197 return err 198 } 199 200 passwordKey := fmt.Sprintf("%s.%s", nameParts[1], "password") 201 passwordStr, err := config.GetManager().Get(nameParts[0], passwordKey) 202 if err != nil { 203 return err 204 } 205 206 optionsKey := fmt.Sprintf("%s.%s", nameParts[1], "options") 207 optionsObj, err := config.GetManager().Get(nameParts[0], optionsKey) 208 if err != nil { 209 return err 210 } 211 212 optionsMap := make(map[string]string, len(optionsObj.(map[string]interface{}))) 213 for key, item := range optionsObj.(map[string]interface{}) { 214 optionsMap[key] = item.(string) 215 } 216 217 var internalConfig *Config 218 219 internalConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "config") 220 internalConfigObj, err := config.GetManager().Get(nameParts[0], internalConfigKey) 221 if err == nil { 222 // first marshal 223 configData, err := json.Marshal(internalConfigObj) 224 if err == nil { 225 _ = json.Unmarshal(configData, &internalConfig) 226 } 227 } 228 229 var internalLogger *LoggerConfig 230 231 internalLoggerKey := fmt.Sprintf("%s.%s", nameParts[1], "logger") 232 internalLoggerObj, err := config.GetManager().Get(nameParts[0], internalLoggerKey) 233 if err == nil { 234 // first marshal 235 configData, err := json.Marshal(internalLoggerObj) 236 if err == nil { 237 _ = json.Unmarshal(configData, &internalLogger) 238 } 239 } 240 241 var specificConfig *PostgresqlSpecificConfig 242 243 specificConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "specific_config") 244 specificConfigObj, err := config.GetManager().Get(nameParts[0], specificConfigKey) 245 if err == nil { 246 // first marshal 247 configData, err := json.Marshal(specificConfigObj) 248 if err == nil { 249 _ = json.Unmarshal(configData, &specificConfig) 250 } 251 } 252 253 s.config = reflect.ValueOf(Postgresql{ 254 DatabaseName: dbNameStr.(string), 255 Username: usernameStr.(string), 256 Password: passwordStr.(string), 257 Host: hostStr.(string), 258 Port: portStr.(string), 259 Options: optionsMap, 260 Config: internalConfig, 261 LoggerConfig: internalLogger, 262 SpecificConfig: specificConfig, 263 }).Interface().(T) 264 } 265 266 return nil 267 } 268 269 // MARK: Public functions 270 271 // GetDb - return associated internal Db 272 func (s *SqlWrapper[T]) GetDb() (*gorm.DB, error) { 273 if s.databaseInstance == nil { 274 if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Sqlite{}) { 275 optionsQSArr := make([]string, 0) 276 config := reflect.ValueOf(s.config).Interface().(Sqlite) 277 for key, val := range config.Options { 278 optionsQSArr = append(optionsQSArr, fmt.Sprintf("%s=%s", key, val)) 279 } 280 optionsQS := strings.Join(optionsQSArr, "&") 281 282 dsn := fmt.Sprintf("file:%s?%s", config.FileName, optionsQS) 283 internalConfig := &gorm.Config{} 284 if config.Config != nil { 285 internalConfig.DisableAutomaticPing = config.Config.DisableAutomaticPing 286 internalConfig.DisableForeignKeyConstraintWhenMigrating = config.Config.DisableForeignKeyConstraintWhenMigrating 287 internalConfig.DisableNestedTransaction = config.Config.DisableNestedTransaction 288 internalConfig.DryRun = config.Config.DryRun 289 internalConfig.PrepareStmt = config.Config.PrepareStmt 290 internalConfig.SkipDefaultTransaction = config.Config.SkipDefaultTransaction 291 internalConfig.IgnoreRelationshipsWhenMigrating = config.Config.IgnoreRelationshipsWhenMigrating 292 } 293 294 if config.LoggerConfig != nil { 295 internalConfig.Logger = NewDbLogger(*config.LoggerConfig) 296 } 297 298 db, err := gorm.Open(sqlite.Open(dsn), internalConfig) 299 if err != nil { 300 return nil, err 301 } 302 s.databaseInstance = db 303 } else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Mysql{}) { 304 optionsQSArr := make([]string, 0) 305 config := reflect.ValueOf(s.config).Interface().(Mysql) 306 for key, val := range config.Options { 307 optionsQSArr = append(optionsQSArr, fmt.Sprintf("%s=%s", key, val)) 308 } 309 optionsQS := strings.Join(optionsQSArr, "&") 310 311 dsn := fmt.Sprintf("%s:%s@%s(%s:%s)/%s?%s", config.Username, 312 config.Password, config.Protocol, config.Host, config.Port, 313 config.DatabaseName, optionsQS) 314 internalConfig := &gorm.Config{} 315 if config.Config != nil { 316 internalConfig.DisableAutomaticPing = config.Config.DisableAutomaticPing 317 internalConfig.DisableForeignKeyConstraintWhenMigrating = config.Config.DisableForeignKeyConstraintWhenMigrating 318 internalConfig.DisableNestedTransaction = config.Config.DisableNestedTransaction 319 internalConfig.DryRun = config.Config.DryRun 320 internalConfig.PrepareStmt = config.Config.PrepareStmt 321 internalConfig.SkipDefaultTransaction = config.Config.SkipDefaultTransaction 322 internalConfig.IgnoreRelationshipsWhenMigrating = config.Config.IgnoreRelationshipsWhenMigrating 323 } 324 325 if config.LoggerConfig != nil { 326 internalConfig.Logger = NewDbLogger(*config.LoggerConfig) 327 } 328 329 if config.SpecificConfig == nil { 330 db, err := gorm.Open(mysql.Open(dsn), internalConfig) 331 if err != nil { 332 return nil, err 333 } 334 s.databaseInstance = db 335 } else { 336 db, err := gorm.Open(mysql.New(mysql.Config{ 337 DSN: dsn, 338 SkipInitializeWithVersion: config.SpecificConfig.SkipInitializeWithVersion, 339 DefaultStringSize: config.SpecificConfig.DefaultStringSize, 340 DefaultDatetimePrecision: &config.SpecificConfig.DefaultDatetimePrecision, 341 DisableWithReturning: config.SpecificConfig.DisableWithReturning, 342 DisableDatetimePrecision: config.SpecificConfig.DisableDatetimePrecision, 343 DontSupportRenameIndex: !config.SpecificConfig.SupportRenameIndex, 344 DontSupportRenameColumn: !config.SpecificConfig.SupportRenameColumn, 345 DontSupportForShareClause: !config.SpecificConfig.SupportForShareClause, 346 DontSupportNullAsDefaultValue: !config.SpecificConfig.SupportNullAsDefaultValue, 347 DontSupportRenameColumnUnique: !config.SpecificConfig.SupportRenameColumnUnique, 348 }), internalConfig) 349 if err != nil { 350 return nil, err 351 } 352 s.databaseInstance = db 353 } 354 355 } else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Postgresql{}) { 356 optionsQSArr := make([]string, 0) 357 config := reflect.ValueOf(s.config).Interface().(Postgresql) 358 for key, val := range config.Options { 359 optionsQSArr = append(optionsQSArr, fmt.Sprintf("%s=%s", key, val)) 360 } 361 optionsQS := strings.Join(optionsQSArr, " ") 362 363 dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s %s", 364 config.Host, config.Username, config.Password, config.DatabaseName, 365 config.Port, optionsQS, 366 ) 367 internalConfig := &gorm.Config{} 368 if config.Config != nil { 369 internalConfig.DisableAutomaticPing = config.Config.DisableAutomaticPing 370 internalConfig.DisableForeignKeyConstraintWhenMigrating = config.Config.DisableForeignKeyConstraintWhenMigrating 371 internalConfig.DisableNestedTransaction = config.Config.DisableNestedTransaction 372 internalConfig.DryRun = config.Config.DryRun 373 internalConfig.PrepareStmt = config.Config.PrepareStmt 374 internalConfig.SkipDefaultTransaction = config.Config.SkipDefaultTransaction 375 internalConfig.IgnoreRelationshipsWhenMigrating = config.Config.IgnoreRelationshipsWhenMigrating 376 } 377 378 if config.LoggerConfig != nil { 379 internalConfig.Logger = NewDbLogger(*config.LoggerConfig) 380 } 381 382 if config.SpecificConfig == nil { 383 db, err := gorm.Open(postgres.Open(dsn), internalConfig) 384 if err != nil { 385 return nil, err 386 } 387 s.databaseInstance = db 388 } else { 389 db, err := gorm.Open(postgres.New(postgres.Config{ 390 DSN: dsn, 391 PreferSimpleProtocol: config.SpecificConfig.PreferSimpleProtocol, 392 WithoutReturning: config.SpecificConfig.WithoutReturning, 393 }), internalConfig) 394 if err != nil { 395 return nil, err 396 } 397 s.databaseInstance = db 398 } 399 } 400 } 401 return s.databaseInstance, nil 402 } 403 404 // Migrate - migrate models to the database 405 func (s *SqlWrapper[T]) Migrate(models ...interface{}) error { 406 err := s.databaseInstance.AutoMigrate(models...) 407 if err != nil { 408 return NewMigrateErr(err) 409 } 410 return nil 411 } 412 413 // AttachMigrationFunc - attach migration function to be called by end user 414 func (s *SqlWrapper[T]) AttachMigrationFunc(f func(migrator gorm.Migrator) error) error { 415 err := f(s.databaseInstance.Migrator()) 416 if err != nil { 417 return NewMigrateErr(err) 418 } 419 return nil 420 } 421 422 // NewSqlWrapper - create a new instance of SqlWrapper and returns it 423 func NewSqlWrapper[T SqlConfigurable](name string, dbType string) (*SqlWrapper[T], error) { 424 if strings.ToLower(dbType) == "sqlite" || 425 strings.ToLower(dbType) == "mysql" || 426 strings.ToLower(dbType) == "postgresql" { 427 wrapper := &SqlWrapper[T]{} 428 err := wrapper.init(name) 429 if err != nil { 430 return nil, NewCreateSqlWrapperErr(err) 431 } 432 433 return wrapper, nil 434 } 435 436 return nil, NewNotSupportedDbTypeErr(dbType) 437 }