github.com/rudderlabs/rudder-go-kit@v0.30.0/kafkaclient/config.go (about)

     1  package client
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/segmentio/kafka-go/sasl"
    10  	"github.com/segmentio/kafka-go/sasl/plain"
    11  	"github.com/segmentio/kafka-go/sasl/scram"
    12  )
    13  
    14  type ScramHashGenerator uint8
    15  
    16  const (
    17  	ScramPlainText ScramHashGenerator = iota
    18  	ScramSHA256
    19  	ScramSHA512
    20  )
    21  
    22  func (s ScramHashGenerator) String() string {
    23  	switch s {
    24  	case ScramPlainText:
    25  		return "plain"
    26  	case ScramSHA256:
    27  		return "sha256"
    28  	case ScramSHA512:
    29  		return "sha512"
    30  	default:
    31  		panic(fmt.Errorf("scram hash generator out of the known domain %d", s))
    32  	}
    33  }
    34  
    35  // ScramHashGeneratorFromString returns the proper ScramHashGenerator from its string counterpart
    36  func ScramHashGeneratorFromString(s string) (ScramHashGenerator, error) {
    37  	switch s {
    38  	case "plain":
    39  		return ScramPlainText, nil
    40  	case "sha256":
    41  		return ScramSHA256, nil
    42  	case "sha512":
    43  		return ScramSHA512, nil
    44  	}
    45  	var hg ScramHashGenerator
    46  	return hg, fmt.Errorf("scram hash generator out of the known domain: %s", s)
    47  }
    48  
    49  type Config struct {
    50  	ClientID    string
    51  	DialTimeout time.Duration
    52  	TLS         *TLS
    53  	SASL        *SASL
    54  	SSHConfig   *SSHConfig
    55  }
    56  
    57  type SSHConfig struct {
    58  	User, Host, PrivateKey string
    59  }
    60  
    61  func (c *Config) defaults() {
    62  	if c.DialTimeout < 1 {
    63  		c.DialTimeout = 10 * time.Second
    64  	}
    65  }
    66  
    67  type TLS struct {
    68  	Cert, Key,
    69  	CACertificate []byte
    70  	WithSystemCertPool,
    71  	InsecureSkipVerify bool
    72  }
    73  
    74  func (c *TLS) build() (*tls.Config, error) {
    75  	if len(c.CACertificate) == 0 && !c.InsecureSkipVerify && !c.WithSystemCertPool {
    76  		return nil, fmt.Errorf("invalid TLS configuration, either provide certificates or skip validation")
    77  	}
    78  
    79  	conf := &tls.Config{ // skipcq: GSC-G402
    80  		MinVersion: tls.VersionTLS11,
    81  		MaxVersion: tls.VersionTLS12,
    82  	}
    83  
    84  	if c.InsecureSkipVerify {
    85  		conf.InsecureSkipVerify = true // skipcq: GSC-G402
    86  	}
    87  
    88  	if c.WithSystemCertPool {
    89  		caCertPool, err := x509.SystemCertPool()
    90  		if err != nil {
    91  			return nil, fmt.Errorf("could not copy of the system cert pool: %w", err)
    92  		}
    93  
    94  		conf.RootCAs = caCertPool
    95  	}
    96  
    97  	if len(c.CACertificate) > 0 {
    98  		if conf.RootCAs == nil {
    99  			conf.RootCAs = x509.NewCertPool()
   100  		}
   101  		if ok := conf.RootCAs.AppendCertsFromPEM(c.CACertificate); !ok {
   102  			return nil, fmt.Errorf("could not append certs from PEM")
   103  		}
   104  	}
   105  
   106  	if len(c.Cert) > 0 && len(c.Key) > 0 {
   107  		certificate, err := tls.X509KeyPair(c.Cert, c.Key)
   108  		if err != nil {
   109  			return nil, fmt.Errorf("could not get TLS certificate: %w", err)
   110  		}
   111  
   112  		conf.Certificates = []tls.Certificate{certificate}
   113  	}
   114  
   115  	return conf, nil
   116  }
   117  
   118  type SASL struct {
   119  	ScramHashGen       ScramHashGenerator
   120  	Username, Password string
   121  }
   122  
   123  func (c *SASL) build() (mechanism sasl.Mechanism, err error) {
   124  	switch c.ScramHashGen {
   125  	case ScramPlainText:
   126  		mechanism = plain.Mechanism{
   127  			Username: c.Username,
   128  			Password: c.Password,
   129  		}
   130  		return
   131  	case ScramSHA256, ScramSHA512:
   132  		algo := scram.SHA256
   133  		if c.ScramHashGen == ScramSHA512 {
   134  			algo = scram.SHA512
   135  		}
   136  		mechanism, err = scram.Mechanism(algo, c.Username, c.Password)
   137  		return
   138  	default:
   139  		return nil, fmt.Errorf("scram hash generator out of the known domain: %v", c.ScramHashGen)
   140  	}
   141  }