github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/redis/redis.go (about)

     1  // Copyright 2023 Northern.tech AS
     2  //
     3  //    Licensed under the Apache License, Version 2.0 (the "License");
     4  //    you may not use this file except in compliance with the License.
     5  //    You may obtain a copy of the License at
     6  //
     7  //        http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //    Unless required by applicable law or agreed to in writing, software
    10  //    distributed under the License is distributed on an "AS IS" BASIS,
    11  //    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //    See the License for the specific language governing permissions and
    13  //    limitations under the License.
    14  
    15  package redis
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"net"
    22  	"net/url"
    23  	"strconv"
    24  	"strings"
    25  
    26  	"github.com/redis/go-redis/v9"
    27  )
    28  
    29  // NewClient creates a new redis client (Cmdable) from the parameters in the
    30  // connectionString URL format:
    31  // Standalone mode:
    32  // (redis|rediss|unix)://[<user>:<password>@](<host>|<socket path>)[:<port>[/<db_number>]][?option=value]
    33  // Cluster mode:
    34  // (redis|rediss|unix)[+srv]://[<user>:<password>@]<host1>[,<host2>[,...]][:<port>][?option=value]
    35  //
    36  // The following query parameters are also available:
    37  // client_name         string
    38  // conn_max_idle_time  duration
    39  // conn_max_lifetime   duration
    40  // dial_timeout        duration
    41  // max_idle_conns      int
    42  // max_retries         int
    43  // max_retry_backoff   duration
    44  // min_idle_conns      int
    45  // min_retry_backoff   duration
    46  // pool_fifo           bool
    47  // pool_size           int
    48  // pool_timeout        duration
    49  // protocol            int
    50  // read_timeout        duration
    51  // tls                 bool
    52  // write_timeout       duration
    53  func ClientFromConnectionString(
    54  	ctx context.Context,
    55  	connectionString string,
    56  ) (redis.Cmdable, error) {
    57  	var (
    58  		redisurl   *url.URL
    59  		tlsOptions *tls.Config
    60  		rdb        redis.Cmdable
    61  	)
    62  	redisurl, err := url.Parse(connectionString)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	// in case connection string was provided in form of host:port
    67  	// add scheme and parse again
    68  	if redisurl.Host == "" {
    69  		redisurl, err = url.Parse("redis://" + connectionString)
    70  		if err != nil {
    71  			return nil, err
    72  		}
    73  	}
    74  	q := redisurl.Query()
    75  	scheme := redisurl.Scheme
    76  	cname := redisurl.Hostname()
    77  	if strings.HasSuffix(scheme, "+srv") {
    78  		scheme = strings.TrimSuffix(redisurl.Scheme, "+srv")
    79  		var srv []*net.SRV
    80  		cname, srv, err = net.DefaultResolver.LookupSRV(ctx, scheme, "tcp", redisurl.Host)
    81  		if err != nil {
    82  			return nil, err
    83  		}
    84  		addrs := make([]string, 0, len(srv))
    85  		for i := range srv {
    86  			if srv[i] == nil {
    87  				continue
    88  			}
    89  			host := strings.TrimSuffix(srv[i].Target, ".")
    90  			addrs = append(addrs, fmt.Sprintf("%s:%d", host, srv[i].Port))
    91  		}
    92  		redisurl.Host = strings.Join(addrs, ",")
    93  		// cleanup the scheme with one known to Redis
    94  		// to avoid: invalid URL scheme: tcp-redis+srv
    95  		redisurl.Scheme = "redis"
    96  
    97  	} else if scheme == "" {
    98  		redisurl.Scheme = "redis"
    99  	}
   100  	// To allow more flexibility for the srv record service
   101  	// name we use "tls" query parameter to determine if we
   102  	// should use TLS, otherwise we test if the service
   103  	// name contains "rediss" before falling back to no TLS.
   104  	var useTLS bool
   105  	if scheme == "rediss" {
   106  		useTLS = true
   107  	} else {
   108  		useTLS, _ = strconv.ParseBool(q.Get("tls"))
   109  	}
   110  	if useTLS {
   111  		tlsOptions = &tls.Config{ServerName: cname}
   112  	}
   113  	// Allow host to be a comma-separated list of hosts.
   114  	if idx := strings.LastIndexByte(redisurl.Host, ','); idx > 0 {
   115  		nodeAddrs := strings.Split(redisurl.Host[:idx], ",")
   116  		q["addr"] = nodeAddrs
   117  		redisurl.RawQuery = q.Encode()
   118  		redisurl.Host = redisurl.Host[idx+1:]
   119  	}
   120  	var cluster bool
   121  	if _, ok := q["addr"]; ok {
   122  		cluster = true
   123  	}
   124  	if cluster {
   125  		var redisOpts *redis.ClusterOptions
   126  		redisOpts, err = redis.ParseClusterURL(redisurl.String())
   127  		if err == nil {
   128  			if tlsOptions != nil {
   129  				redisOpts.TLSConfig = tlsOptions
   130  			}
   131  			rdb = redis.NewClusterClient(redisOpts)
   132  		}
   133  	} else {
   134  		var redisOpts *redis.Options
   135  		redisOpts, err = redis.ParseURL(redisurl.String())
   136  		if err == nil {
   137  			rdb = redis.NewClient(redisOpts)
   138  		}
   139  	}
   140  	if err != nil {
   141  		return nil, fmt.Errorf("redis: invalid connection string: %w", err)
   142  	}
   143  	_, err = rdb.
   144  		Ping(ctx).
   145  		Result()
   146  	return rdb, err
   147  }