github.com/gitbundle/modules@v0.0.0-20231025071548-85b91c5c3b01/nosql/manager_redis.go (about) 1 // Copyright 2023 The GitBundle Inc. All rights reserved. 2 // Copyright 2017 The Gitea Authors. All rights reserved. 3 // Use of this source code is governed by a MIT-style 4 // license that can be found in the LICENSE file. 5 6 package nosql 7 8 import ( 9 "crypto/tls" 10 "net/url" 11 "path" 12 "runtime/pprof" 13 "strconv" 14 "strings" 15 16 "github.com/gitbundle/modules/log" 17 18 "github.com/redis/go-redis/v9" 19 ) 20 21 var replacer = strings.NewReplacer("_", "", "-", "") 22 23 // CloseRedisClient closes a redis client 24 func (m *Manager) CloseRedisClient(connection string) error { 25 m.mutex.Lock() 26 defer m.mutex.Unlock() 27 client, ok := m.RedisConnections[connection] 28 if !ok { 29 connection = ToRedisURI(connection).String() 30 client, ok = m.RedisConnections[connection] 31 } 32 if !ok { 33 return nil 34 } 35 36 client.count-- 37 if client.count > 0 { 38 return nil 39 } 40 41 for _, name := range client.name { 42 delete(m.RedisConnections, name) 43 } 44 return client.UniversalClient.Close() 45 } 46 47 // GetRedisClient gets a redis client for a particular connection 48 func (m *Manager) GetRedisClient(connection string) (client redis.UniversalClient) { 49 // Because we want associate any goroutines created by this call to the main nosqldb context we need to 50 // wrap this in a goroutine labelled with the nosqldb context 51 done := make(chan struct{}) 52 var recovered interface{} 53 go func() { 54 defer func() { 55 recovered = recover() 56 if recovered != nil { 57 log.Critical("PANIC during GetRedisClient: %v\nStacktrace: %s", recovered, log.Stack(2)) 58 } 59 close(done) 60 }() 61 pprof.SetGoroutineLabels(m.ctx) 62 63 client = m.getRedisClient(connection) 64 }() 65 <-done 66 if recovered != nil { 67 panic(recovered) 68 } 69 return 70 } 71 72 func (m *Manager) getRedisClient(connection string) redis.UniversalClient { 73 m.mutex.Lock() 74 defer 
m.mutex.Unlock() 75 client, ok := m.RedisConnections[connection] 76 if ok { 77 client.count++ 78 return client 79 } 80 81 uri := ToRedisURI(connection) 82 client, ok = m.RedisConnections[uri.String()] 83 if ok { 84 client.count++ 85 return client 86 } 87 client = &redisClientHolder{ 88 name: []string{connection, uri.String()}, 89 } 90 91 opts := getRedisOptions(uri) 92 tlsConfig := getRedisTLSOptions(uri) 93 94 clientName := uri.Query().Get("clientname") 95 96 if len(clientName) > 0 { 97 client.name = append(client.name, clientName) 98 } 99 100 switch uri.Scheme { 101 case "redis+sentinels": 102 fallthrough 103 case "rediss+sentinel": 104 opts.TLSConfig = tlsConfig 105 fallthrough 106 case "redis+sentinel": 107 client.UniversalClient = redis.NewFailoverClient(opts.Failover()) 108 case "redis+clusters": 109 fallthrough 110 case "rediss+cluster": 111 opts.TLSConfig = tlsConfig 112 fallthrough 113 case "redis+cluster": 114 client.UniversalClient = redis.NewClusterClient(opts.Cluster()) 115 case "redis+socket": 116 simpleOpts := opts.Simple() 117 simpleOpts.Network = "unix" 118 simpleOpts.Addr = path.Join(uri.Host, uri.Path) 119 client.UniversalClient = redis.NewClient(simpleOpts) 120 case "rediss": 121 opts.TLSConfig = tlsConfig 122 fallthrough 123 case "redis": 124 client.UniversalClient = redis.NewClient(opts.Simple()) 125 default: 126 return nil 127 } 128 129 for _, name := range client.name { 130 m.RedisConnections[name] = client 131 } 132 133 client.count++ 134 135 return client 136 } 137 138 // getRedisOptions pulls various configuration options based on the RedisUri format and converts them to go-redis's 139 // UniversalOptions fields. This function explicitly excludes fields related to TLS configuration, which is 140 // conditionally attached to this options struct before being converted to the specific type for the redis scheme being 141 // used, and only in scenarios where TLS is applicable (e.g. rediss://, redis+clusters://). 
func getRedisOptions(uri *url.URL) *redis.UniversalOptions {
	opts := &redis.UniversalOptions{}

	// Credentials from the URI userinfo. A lone value with no ':' separator
	// (e.g. "redis://secret@host") is treated as the password.
	if password, ok := uri.User.Password(); ok {
		opts.Username = uri.User.Username()
		opts.Password = password
	} else if username := uri.User.Username(); username != "" {
		// assume this is the password
		opts.Password = username
	}

	// Query parameters. Keys are matched case-insensitively with '_' and '-'
	// stripped (see replacer). Invalid integer/boolean values silently become
	// 0/false.
	for k, v := range uri.Query() {
		switch replacer.Replace(strings.ToLower(k)) {
		case "addr":
			opts.Addrs = append(opts.Addrs, v...)
		case "addrs":
			// Each "addrs" occurrence may hold a comma-separated list; honour all
			// of them, not just the first.
			for _, list := range v {
				opts.Addrs = append(opts.Addrs, strings.Split(list, ",")...)
			}
		case "username":
			opts.Username = v[0]
		case "password":
			opts.Password = v[0]
		case "database", "db":
			opts.DB, _ = strconv.Atoi(v[0])
		case "maxretries":
			opts.MaxRetries, _ = strconv.Atoi(v[0])
		case "minretrybackoff":
			opts.MinRetryBackoff = valToTimeDuration(v)
		case "maxretrybackoff":
			opts.MaxRetryBackoff = valToTimeDuration(v)
		case "timeout":
			// The generic timeout only fills dial/read timeouts that are still
			// unset, so a non-zero explicit dialtimeout/readtimeout wins whatever
			// the (random) map iteration order is.
			if timeout := valToTimeDuration(v); timeout != 0 {
				if opts.DialTimeout == 0 {
					opts.DialTimeout = timeout
				}
				if opts.ReadTimeout == 0 {
					opts.ReadTimeout = timeout
				}
			}
		case "dialtimeout":
			opts.DialTimeout = valToTimeDuration(v)
		case "readtimeout":
			opts.ReadTimeout = valToTimeDuration(v)
		case "writetimeout":
			opts.WriteTimeout = valToTimeDuration(v)
		case "poolsize":
			opts.PoolSize, _ = strconv.Atoi(v[0])
		case "minidleconns":
			opts.MinIdleConns, _ = strconv.Atoi(v[0])
		case "pooltimeout":
			opts.PoolTimeout = valToTimeDuration(v)
		case "idletimeout":
			// go-redis v9 renamed IdleTimeout to ConnMaxIdleTime; the old key is
			// kept for configuration compatibility. IdleCheckFrequency has no v9
			// equivalent, so "idlecheckfrequency" falls through as an unknown key.
			opts.ConnMaxIdleTime = valToTimeDuration(v)
		case "maxredirects":
			opts.MaxRedirects, _ = strconv.Atoi(v[0])
		case "readonly":
			opts.ReadOnly, _ = strconv.ParseBool(v[0])
		case "routebylatency":
			opts.RouteByLatency, _ = strconv.ParseBool(v[0])
		case "routerandomly":
			opts.RouteRandomly, _ = strconv.ParseBool(v[0])
		case "sentinelmasterid", "mastername":
			opts.MasterName = v[0]
		case "sentinelusername":
			opts.SentinelUsername = v[0]
		case "sentinelpassword":
			opts.SentinelPassword = v[0]
		}
	}

	// The host part may itself be a comma-separated list of addresses.
	if uri.Host != "" {
		opts.Addrs = append(opts.Addrs, strings.Split(uri.Host, ",")...)
	}

	// A redis connection string uses the path section of the URI in two different ways. In a TCP-based connection, the
	// path will be a database index to automatically have the client SELECT. In a Unix socket connection, it will be the
	// file path. We only want to try to coerce this to the database index when we're not expecting a file path so that
	// the error log stays clean. A bare "/" (e.g. "redis://host:6379/") carries no index and is skipped silently.
	if uri.Scheme != "redis+socket" {
		if dbIndex := strings.TrimPrefix(uri.Path, "/"); dbIndex != "" {
			if db, err := strconv.Atoi(dbIndex); err == nil {
				opts.DB = db
			} else {
				log.Error("Provided database identifier '%s' is not a valid integer. GitBundle will ignore this option.", uri.Path)
			}
		}
	}

	return opts
}

// getRedisTLSOptions parses RedisUri TLS configuration parameters and converts them to the equivalent fields of a
// Go tls.Config.
func getRedisTLSOptions(uri *url.URL) *tls.Config {
	tlsConfig := &tls.Config{}
	query := uri.Query()

	// Both spellings are accepted. They are evaluated in this order, so a valid
	// "insecureskipverify" value overrides "skipverify".
	for _, key := range []string{"skipverify", "insecureskipverify"} {
		raw := query.Get(key)
		if raw == "" {
			continue
		}
		// Apply the flag only when the value is a valid boolean. Malformed values
		// are ignored and leave the current setting untouched.
		if skip, err := strconv.ParseBool(raw); err == nil {
			tlsConfig.InsecureSkipVerify = skip
		}
	}

	return tlsConfig
}