github.com/ipfans/trojan-go@v0.11.0/statistic/mysql/mysql.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "strings" 8 "time" 9 10 // MySQL Driver 11 _ "github.com/go-sql-driver/mysql" 12 13 "github.com/ipfans/trojan-go/common" 14 "github.com/ipfans/trojan-go/config" 15 "github.com/ipfans/trojan-go/log" 16 "github.com/ipfans/trojan-go/statistic" 17 "github.com/ipfans/trojan-go/statistic/memory" 18 ) 19 20 const Name = "MYSQL" 21 22 type Authenticator struct { 23 *memory.Authenticator 24 db *sql.DB 25 updateDuration time.Duration 26 ctx context.Context 27 } 28 29 func (a *Authenticator) updater() { 30 for { 31 for _, user := range a.ListUsers() { 32 // swap upload and download for users 33 hash := user.Hash() 34 sent, recv := user.ResetTraffic() 35 36 s, err := a.db.Exec("UPDATE `users` SET `upload`=`upload`+?, `download`=`download`+? WHERE `password`=?;", recv, sent, hash) 37 if err != nil { 38 log.Error(common.NewError("failed to update data to user table").Base(err)) 39 continue 40 } 41 if r, err := s.RowsAffected(); err != nil { 42 if r == 0 { 43 a.DelUser(hash) 44 } 45 } 46 } 47 log.Info("buffered data has been written into the database") 48 49 // update memory 50 rows, err := a.db.Query("SELECT password,quota,download,upload FROM users") 51 if err != nil || rows.Err() != nil { 52 log.Error(common.NewError("failed to pull data from the database").Base(err)) 53 time.Sleep(a.updateDuration) 54 continue 55 } 56 for rows.Next() { 57 var hash string 58 var quota, download, upload int64 59 err := rows.Scan(&hash, "a, &download, &upload) 60 if err != nil { 61 log.Error(common.NewError("failed to obtain data from the query result").Base(err)) 62 break 63 } 64 if download+upload < quota || quota < 0 { 65 a.AddUser(hash) 66 } else { 67 a.DelUser(hash) 68 } 69 } 70 71 select { 72 case <-time.After(a.updateDuration): 73 case <-a.ctx.Done(): 74 log.Debug("MySQL daemon exiting...") 75 return 76 } 77 } 78 } 79 80 func connectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) { 81 path := strings.Join([]string{username, ":", password, "@tcp(", ip, ":", fmt.Sprintf("%d", port), ")/", dbName, "?charset=utf8"}, "") 82 return sql.Open(driverName, path) 83 } 84 85 func NewAuthenticator(ctx context.Context) (statistic.Authenticator, error) { 86 cfg := config.FromContext(ctx, Name).(*Config) 87 db, err := connectDatabase( 88 "mysql", 89 cfg.MySQL.Username, 90 cfg.MySQL.Password, 91 cfg.MySQL.ServerHost, 92 cfg.MySQL.ServerPort, 93 cfg.MySQL.Database, 94 ) 95 if err != nil { 96 return nil, common.NewError("Failed to connect to database server").Base(err) 97 } 98 memoryAuth, err := memory.NewAuthenticator(ctx) 99 if err != nil { 100 return nil, err 101 } 102 a := &Authenticator{ 103 db: db, 104 ctx: ctx, 105 updateDuration: time.Duration(cfg.MySQL.CheckRate) * time.Second, 106 Authenticator: memoryAuth.(*memory.Authenticator), 107 } 108 go a.updater() 109 log.Debug("mysql authenticator created") 110 return a, nil 111 } 112 113 func init() { 114 statistic.RegisterAuthenticatorCreator(Name, NewAuthenticator) 115 }