github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/mysql/mysql.go (about)

     1  package mysql
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"fmt"
     6  	"net/url"
     7  	"reflect"
     8  	"regexp"
     9  	"time"
    10  
    11  	"github.com/fatih/color"
    12  	"github.com/sirupsen/logrus"
    13  
    14  	"github.com/artisanhe/gorm"
    15  
    16  	"github.com/artisanhe/tools/conf"
    17  	"github.com/artisanhe/tools/conf/presets"
    18  )
    19  
// MySQL holds the settings for one MySQL connection and, after Connect (or
// Init) succeeds, the resulting *gorm.DB handle returned by Get.
type MySQL struct {
	// Name is not used by the connection logic in this file; presumably a
	// logical name for this database (TODO confirm against the conf package).
	Name            string
	// Host is the server address placed inside tcp(...) by GetConnect. The
	// conf/validate tags are interpreted outside this file.
	Host            string `conf:"upstream" validate:"@hostname"`
	// Port defaults to 3306 via MarshalDefaults when left at 0.
	Port            int
	// User and Password are inserted verbatim into the DSN by GetConnect.
	// NOTE(review): `@string[1,)` presumably requires a non-empty value — confirm.
	User            string           `conf:"env" validate:"@string[1,)"`
	Password        presets.Password `conf:"env" validate:"@string[1,)"`
	// Extra is the DSN query string appended after "?" by GetConnect;
	// MarshalDefaults fills it when empty.
	Extra           string
	// PoolSize is passed to connectMysql; defaults to 10 via MarshalDefaults.
	PoolSize        int
	// ConnMaxLifetime is passed to connectMysql; defaults to 4h via MarshalDefaults.
	ConnMaxLifetime time.Duration
	// Retry provides Do, which Init uses to run Connect.
	presets.Retry
	// db is set by Connect on success and returned by Get.
	db *gorm.DB
}
    32  
    33  func (m MySQL) DockerDefaults() conf.DockerDefaults {
    34  	return conf.DockerDefaults{
    35  		"Host": "db-master", //conf.RancherInternal("tool-dbs", m.Name),
    36  		"Port": 33306,
    37  	}
    38  }
    39  
    40  func (m MySQL) MarshalDefaults(v interface{}) {
    41  	if mysql, ok := v.(*MySQL); ok {
    42  		mysql.Retry.MarshalDefaults(&mysql.Retry)
    43  
    44  		if mysql.Port == 0 {
    45  			mysql.Port = 3306
    46  		}
    47  
    48  		if mysql.PoolSize == 0 {
    49  			mysql.PoolSize = 10
    50  		}
    51  
    52  		if mysql.ConnMaxLifetime == 0 {
    53  			mysql.ConnMaxLifetime = 4 * time.Hour
    54  		}
    55  
    56  		if mysql.Extra == "" {
    57  			values := url.Values{}
    58  			values.Set("charset", "utf8")
    59  			values.Set("parseTime", "true")
    60  			values.Set("interpolateParams", "true")
    61  			values.Set("autocommit", "true")
    62  			values.Set("loc", "Local")
    63  			mysql.Extra = values.Encode()
    64  		}
    65  	}
    66  }
    67  
    68  func (m MySQL) GetConnect() string {
    69  	return fmt.Sprintf("%s:%s@tcp(%s:%d)/?%s", m.User, m.Password, m.Host, m.Port, m.Extra)
    70  }
    71  
    72  func (m *MySQL) Connect() error {
    73  	m.MarshalDefaults(m)
    74  	db, err := connectMysql(m.GetConnect(), m.PoolSize, m.ConnMaxLifetime)
    75  	if err != nil {
    76  		return err
    77  	}
    78  	m.db = db
    79  	return nil
    80  }
    81  
    82  func (m *MySQL) Init() {
    83  	if m.db == nil {
    84  		m.Do(m.Connect)
    85  		m.db.SetLogger(&logger{})
    86  	}
    87  }
    88  
    89  func (m *MySQL) Get() *gorm.DB {
    90  	return m.db
    91  }
    92  
// DBGetter is implemented by anything that can hand out a *gorm.DB.
// *MySQL satisfies it through its Get method.
type DBGetter interface {
	Get() *gorm.DB
}
    96  
    97  type logger struct {
    98  }
    99  
   100  var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
   101  
   102  func (l *logger) Print(values ...interface{}) {
   103  	if len(values) > 1 {
   104  		level := values[0]
   105  		messages := []interface{}{}
   106  		if level == "sql" {
   107  			// sql
   108  			var formatedValues []interface{}
   109  			for _, value := range values[4].([]interface{}) {
   110  				indirectValue := reflect.Indirect(reflect.ValueOf(value))
   111  				if indirectValue.IsValid() {
   112  					value = indirectValue.Interface()
   113  					if t, ok := value.(time.Time); ok {
   114  						formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
   115  					} else if b, ok := value.([]byte); ok {
   116  						formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
   117  					} else if r, ok := value.(driver.Valuer); ok {
   118  						if value, err := r.Value(); err == nil && value != nil {
   119  							formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
   120  						} else {
   121  							formatedValues = append(formatedValues, "NULL")
   122  						}
   123  					} else {
   124  						formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
   125  					}
   126  				} else {
   127  					formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
   128  				}
   129  			}
   130  			messages = append(messages, color.RedString(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
   131  			// duration
   132  			messages = append(messages, color.MagentaString(" [%fms]", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
   133  		} else {
   134  			messages = append(messages, values[2:]...)
   135  		}
   136  
   137  		logrus.WithField("tag", "gorm").Debug(messages...)
   138  	}
   139  }