github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/api/cassandra_repository.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	cass "github.com/hellofresh/janus/cassandra"
     7  	"github.com/hellofresh/janus/cassandra/wrapper"
     8  	"github.com/opentracing/opentracing-go"
     9  	log "github.com/sirupsen/logrus"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  )
    14  
// CassandraRepository represents a cassandra repository
type CassandraRepository struct {
	// TODO: we need to expose this so the plugins can use the same Session. We should abstract the database and provide
	// the plugins with a simple interface to search, insert, update and remove data from whatever backend implementation
	Session     wrapper.Holder
	refreshTime time.Duration // interval of the ticker that drives Watch
}
    22  
    23  func NewCassandraRepository(dsn string, refreshTime time.Duration) (*CassandraRepository, error) {
    24  	log.Debugf("getting new api cassandra repo")
    25  	span := opentracing.StartSpan("NewCassandraRepository")
    26  	defer span.Finish()
    27  	span.SetTag("Interface", "CassandraRepository")
    28  
    29  	// parse the dsn string for the cluster host, system key space, app key space and connection timeout.
    30  	log.Infof("dsn is %s", dsn)
    31  	clusterHost, systemKeyspace, appKeyspace, connectionTimeout := parseDSN(dsn)
    32  	if clusterHost == "" {
    33  		clusterHost = cass.ClusterHostName
    34  	}
    35  	if systemKeyspace == "" {
    36  		systemKeyspace = cass.SystemKeyspace
    37  	}
    38  	if appKeyspace == "" {
    39  		appKeyspace = cass.AppKeyspace
    40  	}
    41  	if connectionTimeout == 0 {
    42  		connectionTimeout = cass.Timeout
    43  	}
    44  
    45  	// Wait for Cassandra to start, setup Cassandra keyspace if required
    46  	wrapper.Initialize(clusterHost, systemKeyspace, appKeyspace, time.Duration(connectionTimeout)*time.Second)
    47  
    48  	// Getting a cassandra connection initializer
    49  	initializer := wrapper.New(clusterHost, appKeyspace)
    50  
    51  	// Starting a new cassandra Session
    52  	sessionHolder, err := initializer.NewSession()
    53  	if err != nil {
    54  		panic(err)
    55  	}
    56  	// api cassandra repo Session
    57  	cass.SetSessionHolder(sessionHolder)
    58  
    59  	return &CassandraRepository{
    60  		Session:     sessionHolder,
    61  		refreshTime: refreshTime,
    62  	}, nil
    63  
    64  }
    65  
// Close calls CloseSession on the repository's session holder.
// It always returns nil.
func (r *CassandraRepository) Close() error {
	// Close the Session
	r.Session.CloseSession()
	return nil
}
    71  
    72  // Listen watches for changes on the configuration
    73  func (r *CassandraRepository) Listen(ctx context.Context, cfgChan <-chan ConfigurationMessage) {
    74  	go func() {
    75  		log.Debug("Listening for changes on the provider...")
    76  		for {
    77  			select {
    78  			case cfg := <-cfgChan:
    79  				switch cfg.Operation {
    80  				case AddedOperation:
    81  					err := r.add(cfg.Configuration)
    82  					if err != nil {
    83  						log.WithError(err).Error("Could not add the configuration on the provider")
    84  					}
    85  				case UpdatedOperation:
    86  					err := r.add(cfg.Configuration)
    87  					if err != nil {
    88  						log.WithError(err).Error("Could not update the configuration on the provider")
    89  					}
    90  				case RemovedOperation:
    91  					err := r.remove(cfg.Configuration.Name)
    92  					if err != nil {
    93  						log.WithError(err).Error("Could not remove the configuration from the provider")
    94  					}
    95  				}
    96  			case <-ctx.Done():
    97  				return
    98  			}
    99  		}
   100  	}()
   101  }
   102  
   103  // Watch watches for changes on the database
   104  func (r *CassandraRepository) Watch(ctx context.Context, cfgChan chan<- ConfigurationChanged) {
   105  	t := time.NewTicker(r.refreshTime)
   106  	go func(refreshTicker *time.Ticker) {
   107  		defer refreshTicker.Stop()
   108  		log.Debug("Watching Provider...")
   109  		for {
   110  			select {
   111  			case <-refreshTicker.C:
   112  				defs, err := r.FindAll()
   113  				if err != nil {
   114  					log.WithError(err).Error("Failed to get configurations on watch")
   115  					continue
   116  				}
   117  
   118  				cfgChan <- ConfigurationChanged{
   119  					Configurations: &Configuration{Definitions: defs},
   120  				}
   121  			case <-ctx.Done():
   122  				return
   123  			}
   124  		}
   125  	}(t)
   126  }
   127  
   128  // FindAll fetches all the API definitions available
   129  func (r *CassandraRepository) FindAll() ([]*Definition, error) {
   130  	log.Debugf("finding all definitions")
   131  
   132  	var results []*Definition
   133  
   134  	iter := r.Session.GetSession().Query(
   135  		"SELECT definition FROM api_definition").Iter()
   136  
   137  	var savedDef string
   138  
   139  	err := iter.ScanAndClose(func() bool {
   140  		var definition *Definition
   141  		err := json.Unmarshal([]byte(savedDef), &definition)
   142  		if err != nil {
   143  			log.Errorf("error trying to unmarshal definition json: %v", err)
   144  			return false
   145  		}
   146  		results = append(results, definition)
   147  		return true
   148  	}, &savedDef)
   149  
   150  	if err != nil {
   151  		log.Errorf("error getting all definitions: %v", err)
   152  	}
   153  	return results, err
   154  }
   155  
   156  // Add adds an API definition to the repository
   157  func (r *CassandraRepository) add(definition *Definition) error {
   158  	log.Debugf("adding: %s", definition.Name)
   159  
   160  	isValid, err := definition.Validate()
   161  	if false == isValid && err != nil {
   162  		log.WithError(err).Error("Validation errors")
   163  		return err
   164  	}
   165  
   166  	saveDef, err := json.Marshal(definition)
   167  	if err != nil {
   168  		log.Errorf("error marshaling oauth: %v", err)
   169  		return err
   170  	}
   171  
   172  	err = r.Session.GetSession().Query(
   173  		"UPDATE api_definition "+
   174  			"SET definition = ? "+
   175  			"WHERE name = ?",
   176  		saveDef, definition.Name).Exec()
   177  
   178  	if err != nil {
   179  		log.Errorf("error saving definition %s: %v", definition.Name, err)
   180  	} else {
   181  		log.Debugf("successfully saved definition %s", definition.Name)
   182  	}
   183  
   184  	return err
   185  }
   186  
   187  // Remove removes an API definition from the repository
   188  func (r *CassandraRepository) remove(name string) error {
   189  	log.Debugf("removing: %s", name)
   190  
   191  	err := r.Session.GetSession().Query(
   192  		"DELETE FROM api_definition WHERE name = ?", name).Exec()
   193  
   194  	if err != nil {
   195  		log.Errorf("error saving definition %s: %v", name, err)
   196  	} else {
   197  		log.Debugf("successfully saved definition %s", name)
   198  	}
   199  
   200  	return err
   201  }
   202  
   203  func parseDSN(dsn string) (clusterHost string, systemKeyspace string, appKeyspace string, connectionTimeout int) {
   204  	trimDSN := strings.TrimSpace(dsn)
   205  	if len(trimDSN) == 0 {
   206  		return "", "", "", 0
   207  	}
   208  	splitDSN := strings.Split(trimDSN, "/")
   209  	// list of info
   210  	for i, v := range splitDSN {
   211  		// start at 1 because the dsn path comes in with a leading /
   212  		switch i {
   213  		case 1:
   214  			clusterHost = v
   215  		case 2:
   216  			systemKeyspace = v
   217  		case 3:
   218  			appKeyspace = v
   219  		case 4:
   220  			timeout, err := strconv.Atoi(v)
   221  			if err != nil {
   222  				log.Error("timeout is not an int")
   223  				timeout = 0
   224  			}
   225  			connectionTimeout = timeout
   226  		}
   227  	}
   228  	return clusterHost, systemKeyspace, appKeyspace, connectionTimeout
   229  }