github.com/ipfans/trojan-go@v0.11.0/statistic/mysql/mysql.go (about)

     1  package mysql
     2  
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"strings"
	"time"

	// MySQL Driver
	_ "github.com/go-sql-driver/mysql"

	"github.com/ipfans/trojan-go/common"
	"github.com/ipfans/trojan-go/config"
	"github.com/ipfans/trojan-go/log"
	"github.com/ipfans/trojan-go/statistic"
	"github.com/ipfans/trojan-go/statistic/memory"
)
    19  
    20  const Name = "MYSQL"
    21  
    22  type Authenticator struct {
    23  	*memory.Authenticator
    24  	db             *sql.DB
    25  	updateDuration time.Duration
    26  	ctx            context.Context
    27  }
    28  
    29  func (a *Authenticator) updater() {
    30  	for {
    31  		for _, user := range a.ListUsers() {
    32  			// swap upload and download for users
    33  			hash := user.Hash()
    34  			sent, recv := user.ResetTraffic()
    35  
    36  			s, err := a.db.Exec("UPDATE `users` SET `upload`=`upload`+?, `download`=`download`+? WHERE `password`=?;", recv, sent, hash)
    37  			if err != nil {
    38  				log.Error(common.NewError("failed to update data to user table").Base(err))
    39  				continue
    40  			}
    41  			if r, err := s.RowsAffected(); err != nil {
    42  				if r == 0 {
    43  					a.DelUser(hash)
    44  				}
    45  			}
    46  		}
    47  		log.Info("buffered data has been written into the database")
    48  
    49  		// update memory
    50  		rows, err := a.db.Query("SELECT password,quota,download,upload FROM users")
    51  		if err != nil || rows.Err() != nil {
    52  			log.Error(common.NewError("failed to pull data from the database").Base(err))
    53  			time.Sleep(a.updateDuration)
    54  			continue
    55  		}
    56  		for rows.Next() {
    57  			var hash string
    58  			var quota, download, upload int64
    59  			err := rows.Scan(&hash, &quota, &download, &upload)
    60  			if err != nil {
    61  				log.Error(common.NewError("failed to obtain data from the query result").Base(err))
    62  				break
    63  			}
    64  			if download+upload < quota || quota < 0 {
    65  				a.AddUser(hash)
    66  			} else {
    67  				a.DelUser(hash)
    68  			}
    69  		}
    70  
    71  		select {
    72  		case <-time.After(a.updateDuration):
    73  		case <-a.ctx.Done():
    74  			log.Debug("MySQL daemon exiting...")
    75  			return
    76  		}
    77  	}
    78  }
    79  
    80  func connectDatabase(driverName, username, password, ip string, port int, dbName string) (*sql.DB, error) {
    81  	path := strings.Join([]string{username, ":", password, "@tcp(", ip, ":", fmt.Sprintf("%d", port), ")/", dbName, "?charset=utf8"}, "")
    82  	return sql.Open(driverName, path)
    83  }
    84  
    85  func NewAuthenticator(ctx context.Context) (statistic.Authenticator, error) {
    86  	cfg := config.FromContext(ctx, Name).(*Config)
    87  	db, err := connectDatabase(
    88  		"mysql",
    89  		cfg.MySQL.Username,
    90  		cfg.MySQL.Password,
    91  		cfg.MySQL.ServerHost,
    92  		cfg.MySQL.ServerPort,
    93  		cfg.MySQL.Database,
    94  	)
    95  	if err != nil {
    96  		return nil, common.NewError("Failed to connect to database server").Base(err)
    97  	}
    98  	memoryAuth, err := memory.NewAuthenticator(ctx)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	a := &Authenticator{
   103  		db:             db,
   104  		ctx:            ctx,
   105  		updateDuration: time.Duration(cfg.MySQL.CheckRate) * time.Second,
   106  		Authenticator:  memoryAuth.(*memory.Authenticator),
   107  	}
   108  	go a.updater()
   109  	log.Debug("mysql authenticator created")
   110  	return a, nil
   111  }
   112  
   113  func init() {
   114  	statistic.RegisterAuthenticatorCreator(Name, NewAuthenticator)
   115  }