github.com/orofarne/hammy@v0.0.0-20130409105742-374fadfd6ecb/src/hammy/mysql_state.go (about) 1 package hammy 2 3 import ( 4 "fmt" 5 "strings" 6 "encoding/json" 7 "database/sql" 8 _ "github.com/ziutek/mymysql/godrv" 9 ) 10 11 // Driver for retriving and saving state in MySQL database 12 // It's assumes the table structure like this: 13 // 14 // CREATE TABLE `states` ( 15 // `host` varchar(255) NOT NULL, 16 // `state` text, 17 // `cas` BIGINT NOT NULL DEFAULT 0, 18 // PRIMARY KEY (`host`) 19 // ) ENGINE=InnoDB DEFAULT CHARSET=utf8 20 // 21 type MySQLStateKeeper struct { 22 db *sql.DB 23 tableName string 24 pool chan int 25 } 26 27 func NewMySQLStateKeeper(cfg Config) (sk *MySQLStateKeeper, err error) { 28 sk = new(MySQLStateKeeper) 29 sk.db, err = sql.Open("mymysql", cfg.MySQLStates.Database + "/" + cfg.MySQLStates.User + "/" + cfg.MySQLStates.Password) 30 if err != nil { 31 return 32 } 33 34 sk.tableName = cfg.MySQLStates.Table 35 36 sk.pool = make(chan int, cfg.MySQLStates.MaxConn) 37 for i := 0; i < cfg.MySQLStates.MaxConn; i++ { 38 sk.pool <- 1 39 } 40 41 return 42 } 43 44 func (sk *MySQLStateKeeper) Get(key string) (ans StateKeeperAnswer) { 45 // Pool limits 46 <- sk.pool 47 defer func() { 48 sk.pool <- 1 49 }() 50 51 var stateRaw []byte 52 var cas uint64 53 54 sqlq := fmt.Sprintf("SELECT `state`, `cas` FROM `%s` WHERE `host` = ?", sk.tableName) 55 row := sk.db.QueryRow(sqlq, key) 56 err := row.Scan(&stateRaw, &cas) 57 58 var s State 59 switch err { 60 case nil: 61 e := json.Unmarshal(stateRaw, &s) 62 if e != nil { 63 ans.Err = e 64 return 65 } 66 case sql.ErrNoRows: 67 // Do nothing 68 default: 69 ans.Err = err 70 return 71 } 72 73 ans.State = s 74 ans.Cas = &cas 75 return 76 } 77 78 func (sk *MySQLStateKeeper) MGet(keys []string) (states map[string]StateKeeperAnswer) { 79 // Pool limits 80 <- sk.pool 81 defer func() { 82 sk.pool <- 1 83 }() 84 85 states = make(map[string]StateKeeperAnswer) 86 87 n := len(keys) 88 // Selecting states by 10 rows 89 SUBKEYS: for i := 0; i < n; i += 10 { 90 var subkeys []string 91 if (i + 10) < n { 92 subkeys = keys[i:i+10] 93 } else { 94 subkeys = keys[i:] 95 } 96 97 m := len(subkeys) 98 99 sqlq := fmt.Sprintf("SELECT `host`, `state`, `cas` FROM `%s` WHERE `host` IN (?", sk.tableName) 100 for j := 1; j < m; j++ { 101 sqlq += ", ?" 102 } 103 sqlq += ")" 104 105 args := make([]interface{}, m) 106 for k, s := range subkeys { 107 args[k] = s 108 } 109 110 rows, e := sk.db.Query(sqlq, args...) 111 if e != nil { 112 for _, k := range subkeys { 113 states[k] = StateKeeperAnswer{ 114 State: nil, 115 Cas: nil, 116 Err: fmt.Errorf("Query error: %v", e), 117 } 118 } 119 continue 120 } 121 122 for rows.Next() { 123 var hostK string 124 var stateRaw []byte 125 var cas uint64 126 127 err := rows.Scan(&hostK, &stateRaw, &cas) 128 if err != nil { 129 for _, k := range subkeys { 130 states[k] = StateKeeperAnswer{ 131 State: nil, 132 Cas: nil, 133 Err: fmt.Errorf("Query error: %v", err), 134 } 135 } 136 continue SUBKEYS 137 } 138 139 var s State 140 err = json.Unmarshal(stateRaw, &s) 141 if err != nil { 142 states[hostK] = StateKeeperAnswer{ 143 State: nil, 144 Cas: nil, 145 Err: fmt.Errorf("Unmarshal error: %v", err), 146 } 147 } else { 148 states[hostK] = StateKeeperAnswer{ 149 State: s, 150 Cas: &cas, 151 Err: nil, 152 } 153 } 154 } 155 } 156 157 for _, k := range keys { 158 if _, found := states[k]; !found { 159 states[k] = StateKeeperAnswer{ 160 State: *NewState(), 161 Cas: nil, 162 Err: nil, 163 } 164 } 165 } 166 167 return 168 } 169 170 func (sk *MySQLStateKeeper) Set(key string, data State, cas *uint64) (retry bool, err error) { 171 // Pool limits 172 <- sk.pool 173 defer func() { 174 sk.pool <- 1 175 }() 176 177 stateRaw, err := json.Marshal(data) 178 if err != nil { 179 return 180 } 181 182 if cas == nil { 183 sqlq := fmt.Sprintf("INSERT INTO `%s` SET `host` = ?, `state` = ?, `cas` = ?", sk.tableName) 184 _, e := sk.db.Exec(sqlq, key, stateRaw, 0) 185 if e != nil { 186 // Error may looks like this: 187 // Received #1062 error from MySQL server: "Duplicate entry 'foo.example.com' for key 'PRIMARY'" 188 if strings.Contains(e.Error(), "Received #1062 error from MySQL server") { 189 retry = true 190 } else { 191 err = e 192 } 193 return 194 } 195 } else { 196 newCas := *cas + 1 197 sqlq := fmt.Sprintf("UPDATE `%s` SET `state` = ?, `cas` = ? WHERE `host` = ? AND `cas` = ?", sk.tableName) 198 res, e := sk.db.Exec(sqlq, stateRaw, newCas, key, *cas) 199 if e != nil { 200 err = e 201 return 202 } 203 rowsAffected, e := res.RowsAffected() 204 if e != nil { 205 err = e 206 return 207 } 208 if rowsAffected != 1 { 209 if rowsAffected > 1 { 210 panic("More than one row has been affected") 211 } 212 retry = true 213 } 214 } 215 216 return 217 }