github.com/abolfazlbeh/zhycan@v0.0.0-20230819144214-24cf38237387/internal/db/sql_wrapper.go (about)

     1  package db
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"github.com/abolfazlbeh/zhycan/internal/config"
     7  	"gorm.io/driver/mysql"
     8  	"gorm.io/driver/postgres"
     9  	"gorm.io/driver/sqlite"
    10  	"gorm.io/gorm"
    11  	"reflect"
    12  	"strings"
    13  )
    14  
// MARK: Definitions
    16  
// SqlWrapper holds the configuration of one SQL connection, whose backend is
// selected by the type parameter T, together with the gorm handle once it has
// been opened.
//
// The methods in this file populate and use the configuration only when T is
// Sqlite, Mysql or Postgresql.
type SqlWrapper[T SqlConfigurable] struct {
	// name is the identifier given to NewSqlWrapper; init splits it on "/"
	// into the configuration name and the key prefix of the settings.
	name             string
	// config holds the backend settings read by init.
	config           T
	// databaseInstance is the gorm handle; GetDb opens it lazily and it is
	// nil until then.
	databaseInstance *gorm.DB
}
    23  
    24  // init - SqlWrapper Constructor - It initializes the wrapper
    25  func (s *SqlWrapper[T]) init(name string) error {
    26  	s.name = name
    27  
    28  	// reading config
    29  	nameParts := strings.Split(s.name, "/")
    30  
    31  	if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Sqlite{}) {
    32  		filenameKey := fmt.Sprintf("%s.%s", nameParts[1], "db")
    33  		filenameStr, err := config.GetManager().Get(nameParts[0], filenameKey)
    34  		if err != nil {
    35  			return err
    36  		}
    37  
    38  		optionsKey := fmt.Sprintf("%s.%s", nameParts[1], "options")
    39  		optionsObj, err := config.GetManager().Get(nameParts[0], optionsKey)
    40  		if err != nil {
    41  			return err
    42  		}
    43  
    44  		optionsMap := make(map[string]string, len(optionsObj.(map[string]interface{})))
    45  		for key, item := range optionsObj.(map[string]interface{}) {
    46  			optionsMap[key] = item.(string)
    47  		}
    48  
    49  		var internalConfig *Config
    50  
    51  		internalConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "config")
    52  		internalConfigObj, err := config.GetManager().Get(nameParts[0], internalConfigKey)
    53  		if err == nil {
    54  			// first marshal
    55  			configData, err := json.Marshal(internalConfigObj)
    56  			if err == nil {
    57  				_ = json.Unmarshal(configData, &internalConfig)
    58  			}
    59  		}
    60  
    61  		var internalLogger *LoggerConfig
    62  
    63  		internalLoggerKey := fmt.Sprintf("%s.%s", nameParts[1], "logger")
    64  		internalLoggerObj, err := config.GetManager().Get(nameParts[0], internalLoggerKey)
    65  		if err == nil {
    66  			// first marshal
    67  			configData, err := json.Marshal(internalLoggerObj)
    68  			if err == nil {
    69  				_ = json.Unmarshal(configData, &internalLogger)
    70  			}
    71  		}
    72  
    73  		s.config = reflect.ValueOf(Sqlite{
    74  			FileName:     filenameStr.(string),
    75  			Options:      optionsMap,
    76  			Config:       internalConfig,
    77  			LoggerConfig: internalLogger,
    78  		}).Interface().(T)
    79  	} else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Mysql{}) {
    80  		dbNameKey := fmt.Sprintf("%s.%s", nameParts[1], "db")
    81  		dbNameStr, err := config.GetManager().Get(nameParts[0], dbNameKey)
    82  		if err != nil {
    83  			return err
    84  		}
    85  
    86  		hostKey := fmt.Sprintf("%s.%s", nameParts[1], "host")
    87  		hostStr, err := config.GetManager().Get(nameParts[0], hostKey)
    88  		if err != nil {
    89  			return err
    90  		}
    91  
    92  		portKey := fmt.Sprintf("%s.%s", nameParts[1], "port")
    93  		portStr, err := config.GetManager().Get(nameParts[0], portKey)
    94  		if err != nil {
    95  			return err
    96  		}
    97  
    98  		protocolKey := fmt.Sprintf("%s.%s", nameParts[1], "protocol")
    99  		protocolStr, err := config.GetManager().Get(nameParts[0], protocolKey)
   100  		if err != nil {
   101  			return err
   102  		}
   103  
   104  		usernameKey := fmt.Sprintf("%s.%s", nameParts[1], "username")
   105  		usernameStr, err := config.GetManager().Get(nameParts[0], usernameKey)
   106  		if err != nil {
   107  			return err
   108  		}
   109  
   110  		passwordKey := fmt.Sprintf("%s.%s", nameParts[1], "password")
   111  		passwordStr, err := config.GetManager().Get(nameParts[0], passwordKey)
   112  		if err != nil {
   113  			return err
   114  		}
   115  
   116  		optionsKey := fmt.Sprintf("%s.%s", nameParts[1], "options")
   117  		optionsObj, err := config.GetManager().Get(nameParts[0], optionsKey)
   118  		if err != nil {
   119  			return err
   120  		}
   121  
   122  		optionsMap := make(map[string]string, len(optionsObj.(map[string]interface{})))
   123  		for key, item := range optionsObj.(map[string]interface{}) {
   124  			optionsMap[key] = item.(string)
   125  		}
   126  
   127  		var internalConfig *Config
   128  
   129  		internalConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "config")
   130  		internalConfigObj, err := config.GetManager().Get(nameParts[0], internalConfigKey)
   131  		if err == nil {
   132  			// first marshal
   133  			configData, err := json.Marshal(internalConfigObj)
   134  			if err == nil {
   135  				_ = json.Unmarshal(configData, &internalConfig)
   136  			}
   137  		}
   138  
   139  		var internalLogger *LoggerConfig
   140  
   141  		internalLoggerKey := fmt.Sprintf("%s.%s", nameParts[1], "logger")
   142  		internalLoggerObj, err := config.GetManager().Get(nameParts[0], internalLoggerKey)
   143  		if err == nil {
   144  			// first marshal
   145  			configData, err := json.Marshal(internalLoggerObj)
   146  			if err == nil {
   147  				_ = json.Unmarshal(configData, &internalLogger)
   148  			}
   149  		}
   150  
   151  		var specificConfig *MysqlSpecificConfig
   152  
   153  		specificConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "specific_config")
   154  		specificConfigObj, err := config.GetManager().Get(nameParts[0], specificConfigKey)
   155  		if err == nil {
   156  			// first marshal
   157  			configData, err := json.Marshal(specificConfigObj)
   158  			if err == nil {
   159  				_ = json.Unmarshal(configData, &specificConfig)
   160  			}
   161  		}
   162  
   163  		s.config = reflect.ValueOf(Mysql{
   164  			DatabaseName:   dbNameStr.(string),
   165  			Username:       usernameStr.(string),
   166  			Password:       passwordStr.(string),
   167  			Host:           hostStr.(string),
   168  			Port:           portStr.(string),
   169  			Protocol:       protocolStr.(string),
   170  			Options:        optionsMap,
   171  			Config:         internalConfig,
   172  			LoggerConfig:   internalLogger,
   173  			SpecificConfig: specificConfig,
   174  		}).Interface().(T)
   175  	} else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Postgresql{}) {
   176  		dbNameKey := fmt.Sprintf("%s.%s", nameParts[1], "db")
   177  		dbNameStr, err := config.GetManager().Get(nameParts[0], dbNameKey)
   178  		if err != nil {
   179  			return err
   180  		}
   181  
   182  		hostKey := fmt.Sprintf("%s.%s", nameParts[1], "host")
   183  		hostStr, err := config.GetManager().Get(nameParts[0], hostKey)
   184  		if err != nil {
   185  			return err
   186  		}
   187  
   188  		portKey := fmt.Sprintf("%s.%s", nameParts[1], "port")
   189  		portStr, err := config.GetManager().Get(nameParts[0], portKey)
   190  		if err != nil {
   191  			return err
   192  		}
   193  
   194  		usernameKey := fmt.Sprintf("%s.%s", nameParts[1], "username")
   195  		usernameStr, err := config.GetManager().Get(nameParts[0], usernameKey)
   196  		if err != nil {
   197  			return err
   198  		}
   199  
   200  		passwordKey := fmt.Sprintf("%s.%s", nameParts[1], "password")
   201  		passwordStr, err := config.GetManager().Get(nameParts[0], passwordKey)
   202  		if err != nil {
   203  			return err
   204  		}
   205  
   206  		optionsKey := fmt.Sprintf("%s.%s", nameParts[1], "options")
   207  		optionsObj, err := config.GetManager().Get(nameParts[0], optionsKey)
   208  		if err != nil {
   209  			return err
   210  		}
   211  
   212  		optionsMap := make(map[string]string, len(optionsObj.(map[string]interface{})))
   213  		for key, item := range optionsObj.(map[string]interface{}) {
   214  			optionsMap[key] = item.(string)
   215  		}
   216  
   217  		var internalConfig *Config
   218  
   219  		internalConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "config")
   220  		internalConfigObj, err := config.GetManager().Get(nameParts[0], internalConfigKey)
   221  		if err == nil {
   222  			// first marshal
   223  			configData, err := json.Marshal(internalConfigObj)
   224  			if err == nil {
   225  				_ = json.Unmarshal(configData, &internalConfig)
   226  			}
   227  		}
   228  
   229  		var internalLogger *LoggerConfig
   230  
   231  		internalLoggerKey := fmt.Sprintf("%s.%s", nameParts[1], "logger")
   232  		internalLoggerObj, err := config.GetManager().Get(nameParts[0], internalLoggerKey)
   233  		if err == nil {
   234  			// first marshal
   235  			configData, err := json.Marshal(internalLoggerObj)
   236  			if err == nil {
   237  				_ = json.Unmarshal(configData, &internalLogger)
   238  			}
   239  		}
   240  
   241  		var specificConfig *PostgresqlSpecificConfig
   242  
   243  		specificConfigKey := fmt.Sprintf("%s.%s", nameParts[1], "specific_config")
   244  		specificConfigObj, err := config.GetManager().Get(nameParts[0], specificConfigKey)
   245  		if err == nil {
   246  			// first marshal
   247  			configData, err := json.Marshal(specificConfigObj)
   248  			if err == nil {
   249  				_ = json.Unmarshal(configData, &specificConfig)
   250  			}
   251  		}
   252  
   253  		s.config = reflect.ValueOf(Postgresql{
   254  			DatabaseName:   dbNameStr.(string),
   255  			Username:       usernameStr.(string),
   256  			Password:       passwordStr.(string),
   257  			Host:           hostStr.(string),
   258  			Port:           portStr.(string),
   259  			Options:        optionsMap,
   260  			Config:         internalConfig,
   261  			LoggerConfig:   internalLogger,
   262  			SpecificConfig: specificConfig,
   263  		}).Interface().(T)
   264  	}
   265  
   266  	return nil
   267  }
   268  
   269  // MARK: Public functions
   270  
   271  // GetDb - return associated internal Db
   272  func (s *SqlWrapper[T]) GetDb() (*gorm.DB, error) {
   273  	if s.databaseInstance == nil {
   274  		if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Sqlite{}) {
   275  			optionsQSArr := make([]string, 0)
   276  			config := reflect.ValueOf(s.config).Interface().(Sqlite)
   277  			for key, val := range config.Options {
   278  				optionsQSArr = append(optionsQSArr, fmt.Sprintf("%s=%s", key, val))
   279  			}
   280  			optionsQS := strings.Join(optionsQSArr, "&")
   281  
   282  			dsn := fmt.Sprintf("file:%s?%s", config.FileName, optionsQS)
   283  			internalConfig := &gorm.Config{}
   284  			if config.Config != nil {
   285  				internalConfig.DisableAutomaticPing = config.Config.DisableAutomaticPing
   286  				internalConfig.DisableForeignKeyConstraintWhenMigrating = config.Config.DisableForeignKeyConstraintWhenMigrating
   287  				internalConfig.DisableNestedTransaction = config.Config.DisableNestedTransaction
   288  				internalConfig.DryRun = config.Config.DryRun
   289  				internalConfig.PrepareStmt = config.Config.PrepareStmt
   290  				internalConfig.SkipDefaultTransaction = config.Config.SkipDefaultTransaction
   291  				internalConfig.IgnoreRelationshipsWhenMigrating = config.Config.IgnoreRelationshipsWhenMigrating
   292  			}
   293  
   294  			if config.LoggerConfig != nil {
   295  				internalConfig.Logger = NewDbLogger(*config.LoggerConfig)
   296  			}
   297  
   298  			db, err := gorm.Open(sqlite.Open(dsn), internalConfig)
   299  			if err != nil {
   300  				return nil, err
   301  			}
   302  			s.databaseInstance = db
   303  		} else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Mysql{}) {
   304  			optionsQSArr := make([]string, 0)
   305  			config := reflect.ValueOf(s.config).Interface().(Mysql)
   306  			for key, val := range config.Options {
   307  				optionsQSArr = append(optionsQSArr, fmt.Sprintf("%s=%s", key, val))
   308  			}
   309  			optionsQS := strings.Join(optionsQSArr, "&")
   310  
   311  			dsn := fmt.Sprintf("%s:%s@%s(%s:%s)/%s?%s", config.Username,
   312  				config.Password, config.Protocol, config.Host, config.Port,
   313  				config.DatabaseName, optionsQS)
   314  			internalConfig := &gorm.Config{}
   315  			if config.Config != nil {
   316  				internalConfig.DisableAutomaticPing = config.Config.DisableAutomaticPing
   317  				internalConfig.DisableForeignKeyConstraintWhenMigrating = config.Config.DisableForeignKeyConstraintWhenMigrating
   318  				internalConfig.DisableNestedTransaction = config.Config.DisableNestedTransaction
   319  				internalConfig.DryRun = config.Config.DryRun
   320  				internalConfig.PrepareStmt = config.Config.PrepareStmt
   321  				internalConfig.SkipDefaultTransaction = config.Config.SkipDefaultTransaction
   322  				internalConfig.IgnoreRelationshipsWhenMigrating = config.Config.IgnoreRelationshipsWhenMigrating
   323  			}
   324  
   325  			if config.LoggerConfig != nil {
   326  				internalConfig.Logger = NewDbLogger(*config.LoggerConfig)
   327  			}
   328  
   329  			if config.SpecificConfig == nil {
   330  				db, err := gorm.Open(mysql.Open(dsn), internalConfig)
   331  				if err != nil {
   332  					return nil, err
   333  				}
   334  				s.databaseInstance = db
   335  			} else {
   336  				db, err := gorm.Open(mysql.New(mysql.Config{
   337  					DSN:                           dsn,
   338  					SkipInitializeWithVersion:     config.SpecificConfig.SkipInitializeWithVersion,
   339  					DefaultStringSize:             config.SpecificConfig.DefaultStringSize,
   340  					DefaultDatetimePrecision:      &config.SpecificConfig.DefaultDatetimePrecision,
   341  					DisableWithReturning:          config.SpecificConfig.DisableWithReturning,
   342  					DisableDatetimePrecision:      config.SpecificConfig.DisableDatetimePrecision,
   343  					DontSupportRenameIndex:        !config.SpecificConfig.SupportRenameIndex,
   344  					DontSupportRenameColumn:       !config.SpecificConfig.SupportRenameColumn,
   345  					DontSupportForShareClause:     !config.SpecificConfig.SupportForShareClause,
   346  					DontSupportNullAsDefaultValue: !config.SpecificConfig.SupportNullAsDefaultValue,
   347  					DontSupportRenameColumnUnique: !config.SpecificConfig.SupportRenameColumnUnique,
   348  				}), internalConfig)
   349  				if err != nil {
   350  					return nil, err
   351  				}
   352  				s.databaseInstance = db
   353  			}
   354  
   355  		} else if reflect.ValueOf(s.config).Type() == reflect.TypeOf(Postgresql{}) {
   356  			optionsQSArr := make([]string, 0)
   357  			config := reflect.ValueOf(s.config).Interface().(Postgresql)
   358  			for key, val := range config.Options {
   359  				optionsQSArr = append(optionsQSArr, fmt.Sprintf("%s=%s", key, val))
   360  			}
   361  			optionsQS := strings.Join(optionsQSArr, " ")
   362  
   363  			dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s %s",
   364  				config.Host, config.Username, config.Password, config.DatabaseName,
   365  				config.Port, optionsQS,
   366  			)
   367  			internalConfig := &gorm.Config{}
   368  			if config.Config != nil {
   369  				internalConfig.DisableAutomaticPing = config.Config.DisableAutomaticPing
   370  				internalConfig.DisableForeignKeyConstraintWhenMigrating = config.Config.DisableForeignKeyConstraintWhenMigrating
   371  				internalConfig.DisableNestedTransaction = config.Config.DisableNestedTransaction
   372  				internalConfig.DryRun = config.Config.DryRun
   373  				internalConfig.PrepareStmt = config.Config.PrepareStmt
   374  				internalConfig.SkipDefaultTransaction = config.Config.SkipDefaultTransaction
   375  				internalConfig.IgnoreRelationshipsWhenMigrating = config.Config.IgnoreRelationshipsWhenMigrating
   376  			}
   377  
   378  			if config.LoggerConfig != nil {
   379  				internalConfig.Logger = NewDbLogger(*config.LoggerConfig)
   380  			}
   381  
   382  			if config.SpecificConfig == nil {
   383  				db, err := gorm.Open(postgres.Open(dsn), internalConfig)
   384  				if err != nil {
   385  					return nil, err
   386  				}
   387  				s.databaseInstance = db
   388  			} else {
   389  				db, err := gorm.Open(postgres.New(postgres.Config{
   390  					DSN:                  dsn,
   391  					PreferSimpleProtocol: config.SpecificConfig.PreferSimpleProtocol,
   392  					WithoutReturning:     config.SpecificConfig.WithoutReturning,
   393  				}), internalConfig)
   394  				if err != nil {
   395  					return nil, err
   396  				}
   397  				s.databaseInstance = db
   398  			}
   399  		}
   400  	}
   401  	return s.databaseInstance, nil
   402  }
   403  
   404  // Migrate - migrate models to the database
   405  func (s *SqlWrapper[T]) Migrate(models ...interface{}) error {
   406  	err := s.databaseInstance.AutoMigrate(models...)
   407  	if err != nil {
   408  		return NewMigrateErr(err)
   409  	}
   410  	return nil
   411  }
   412  
   413  // AttachMigrationFunc -  attach migration function to be called by end user
   414  func (s *SqlWrapper[T]) AttachMigrationFunc(f func(migrator gorm.Migrator) error) error {
   415  	err := f(s.databaseInstance.Migrator())
   416  	if err != nil {
   417  		return NewMigrateErr(err)
   418  	}
   419  	return nil
   420  }
   421  
   422  // NewSqlWrapper - create a new instance of SqlWrapper and returns it
   423  func NewSqlWrapper[T SqlConfigurable](name string, dbType string) (*SqlWrapper[T], error) {
   424  	if strings.ToLower(dbType) == "sqlite" ||
   425  		strings.ToLower(dbType) == "mysql" ||
   426  		strings.ToLower(dbType) == "postgresql" {
   427  		wrapper := &SqlWrapper[T]{}
   428  		err := wrapper.init(name)
   429  		if err != nil {
   430  			return nil, NewCreateSqlWrapperErr(err)
   431  		}
   432  
   433  		return wrapper, nil
   434  	}
   435  
   436  	return nil, NewNotSupportedDbTypeErr(dbType)
   437  }