github.com/orofarne/hammy@v0.0.0-20130409105742-374fadfd6ecb/src/hammy/mysql_state.go (about)

     1  package hammy
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"encoding/json"
     7  	"database/sql"
     8  	_ "github.com/ziutek/mymysql/godrv"
     9  )
    10  
    11  // Driver for retriving and saving state in MySQL database
    12  // It's assumes the table structure like this:
    13  //
    14  //  CREATE TABLE `states` (
    15  //    `host` varchar(255) NOT NULL,
    16  //    `state` text,
    17  //    `cas` BIGINT NOT NULL DEFAULT 0,
    18  //    PRIMARY KEY (`host`)
    19  //  ) ENGINE=InnoDB DEFAULT CHARSET=utf8
    20  //
    21  type MySQLStateKeeper struct {
    22  	db *sql.DB
    23  	tableName string
    24  	pool chan int
    25  }
    26  
    27  func NewMySQLStateKeeper(cfg Config) (sk *MySQLStateKeeper, err error) {
    28  	sk = new(MySQLStateKeeper)
    29  	sk.db, err = sql.Open("mymysql", cfg.MySQLStates.Database + "/" + cfg.MySQLStates.User + "/" + cfg.MySQLStates.Password)
    30  	if err != nil {
    31  		return
    32  	}
    33  
    34  	sk.tableName = cfg.MySQLStates.Table
    35  
    36  	sk.pool = make(chan int, cfg.MySQLStates.MaxConn)
    37  	for i := 0; i < cfg.MySQLStates.MaxConn; i++ {
    38  		sk.pool <- 1
    39  	}
    40  
    41  	return
    42  }
    43  
    44  func (sk *MySQLStateKeeper) Get(key string) (ans StateKeeperAnswer) {
    45  	// Pool limits
    46  	<- sk.pool
    47  	defer func() {
    48  		sk.pool <- 1
    49  	}()
    50  
    51  	var stateRaw []byte
    52  	var cas uint64
    53  
    54  	sqlq := fmt.Sprintf("SELECT `state`, `cas` FROM `%s` WHERE `host` = ?", sk.tableName)
    55  	row := sk.db.QueryRow(sqlq, key)
    56  	err := row.Scan(&stateRaw, &cas)
    57  
    58  	var s State
    59  	switch err {
    60  		case nil:
    61  			e := json.Unmarshal(stateRaw, &s)
    62  			if e != nil {
    63  				ans.Err = e
    64  				return
    65  			}
    66  		case sql.ErrNoRows:
    67  			// Do nothing
    68  		default:
    69  			ans.Err = err
    70  			return
    71  	}
    72  
    73  	ans.State = s
    74  	ans.Cas = &cas
    75  	return
    76  }
    77  
    78  func (sk *MySQLStateKeeper) MGet(keys []string) (states map[string]StateKeeperAnswer) {
    79  	// Pool limits
    80  	<- sk.pool
    81  	defer func() {
    82  		sk.pool <- 1
    83  	}()
    84  
    85  	states = make(map[string]StateKeeperAnswer)
    86  
    87  	n := len(keys)
    88  	// Selecting states by 10 rows
    89  SUBKEYS:	for i := 0; i < n; i += 10 {
    90  		var subkeys []string
    91  		if (i + 10) < n {
    92  			subkeys = keys[i:i+10]
    93  		} else {
    94  			subkeys = keys[i:]
    95  		}
    96  
    97  		m := len(subkeys)
    98  
    99  		sqlq := fmt.Sprintf("SELECT `host`, `state`, `cas` FROM `%s` WHERE `host` IN (?", sk.tableName)
   100  		for j := 1; j < m; j++ {
   101  			sqlq += ", ?"
   102  		}
   103  		sqlq += ")"
   104  
   105  		args := make([]interface{}, m)
   106  		for k, s := range subkeys {
   107  			args[k] = s
   108  		}
   109  
   110  		rows, e := sk.db.Query(sqlq, args...)
   111  		if e != nil {
   112  			for _, k := range subkeys {
   113  				states[k] = StateKeeperAnswer{
   114  					State: nil,
   115  					Cas: nil,
   116  					Err: fmt.Errorf("Query error: %v", e),
   117  				}
   118  			}
   119  			continue
   120  		}
   121  
   122  		for rows.Next() {
   123  			var hostK string
   124  			var stateRaw []byte
   125  			var cas uint64
   126  
   127  			err := rows.Scan(&hostK, &stateRaw, &cas)
   128  			if err != nil {
   129  				for _, k := range subkeys {
   130  					states[k] = StateKeeperAnswer{
   131  						State: nil,
   132  						Cas: nil,
   133  						Err: fmt.Errorf("Query error: %v", err),
   134  					}
   135  				}
   136  				continue SUBKEYS
   137  			}
   138  
   139  			var s State
   140  			err = json.Unmarshal(stateRaw, &s)
   141  			if err != nil {
   142  				states[hostK] = StateKeeperAnswer{
   143  					State: nil,
   144  					Cas: nil,
   145  					Err: fmt.Errorf("Unmarshal error: %v", err),
   146  				}
   147  			} else {
   148  				states[hostK] = StateKeeperAnswer{
   149  					State: s,
   150  					Cas: &cas,
   151  					Err: nil,
   152  				}
   153  			}
   154  		}
   155  	}
   156  
   157  	for _, k := range keys {
   158  		if _, found := states[k]; !found {
   159  			states[k] = StateKeeperAnswer{
   160  				State: *NewState(),
   161  				Cas: nil,
   162  				Err: nil,
   163  			}
   164  		}
   165  	}
   166  
   167  	return
   168  }
   169  
   170  func (sk *MySQLStateKeeper) Set(key string, data State, cas *uint64) (retry bool, err error) {
   171  	// Pool limits
   172  	<- sk.pool
   173  	defer func() {
   174  		sk.pool <- 1
   175  	}()
   176  
   177  	stateRaw, err := json.Marshal(data)
   178  	if err != nil {
   179  		return
   180  	}
   181  
   182  	if cas == nil {
   183  		sqlq := fmt.Sprintf("INSERT INTO `%s` SET `host` = ?, `state` = ?, `cas` = ?", sk.tableName)
   184  		_, e := sk.db.Exec(sqlq, key, stateRaw, 0)
   185  		if e != nil {
   186  			// Error may looks like this:
   187  			//  Received #1062 error from MySQL server: "Duplicate entry 'foo.example.com' for key 'PRIMARY'"
   188  			if strings.Contains(e.Error(), "Received #1062 error from MySQL server") {
   189  				retry = true
   190  			} else {
   191  				err = e
   192  			}
   193  			return
   194  		}
   195  	} else {
   196  		newCas := *cas + 1
   197  		sqlq := fmt.Sprintf("UPDATE `%s` SET `state` = ?, `cas` = ? WHERE `host` = ? AND `cas` = ?", sk.tableName)
   198  		res, e := sk.db.Exec(sqlq, stateRaw, newCas, key, *cas)
   199  		if e != nil {
   200  			err = e
   201  			return
   202  		}
   203  		rowsAffected, e := res.RowsAffected()
   204  		if e != nil {
   205  			err = e
   206  			return
   207  		}
   208  		if rowsAffected != 1 {
   209  			if rowsAffected > 1 {
   210  				panic("More than one row has been affected")
   211  			}
   212  			retry = true
   213  		}
   214  	}
   215  
   216  	return
   217  }