github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/cassandra_repository.go (about)

     1  package oauth2
     2  
     3  import (
     4  	"encoding/json"
     5  	"github.com/hellofresh/janus/cassandra/wrapper"
     6  	log "github.com/sirupsen/logrus"
     7  )
     8  
     9  // CassandraRepository represents a cassandra repository
    10  type CassandraRepository struct {
    11  	session wrapper.Holder
    12  }
    13  
    14  func NewCassandraRepository(session wrapper.Holder) (*CassandraRepository, error) {
    15  	return &CassandraRepository{session: session}, nil
    16  
    17  }
    18  
    19  // FindAll fetches all the OAuth Servers available
    20  func (r *CassandraRepository) FindAll() ([]*OAuth, error) {
    21  	log.Debugf("finding all oauth servers")
    22  
    23  	var results []*OAuth
    24  
    25  	iter := r.session.GetSession().Query("SELECT name, oauth FROM oauth").Iter()
    26  
    27  	var savedOauth string
    28  
    29  	err := iter.ScanAndClose(func() bool {
    30  		var oauth *OAuth
    31  		err := json.Unmarshal([]byte(savedOauth), &oauth)
    32  		if err != nil {
    33  			log.Errorf("error trying to unmarshal oauth json: %v", err)
    34  			return false
    35  		}
    36  		results = append(results, oauth)
    37  		return true
    38  	}, &savedOauth)
    39  	if err != nil {
    40  		log.Errorf("error getting all oauths: %v", err)
    41  	}
    42  	return results, err
    43  }
    44  
    45  // FindByName find an OAuth Server by name
    46  func (r *CassandraRepository) FindByName(name string) (*OAuth, error) {
    47  	log.Debugf("finding: %s", name)
    48  
    49  	var savedOauth string
    50  	var oauth *OAuth
    51  
    52  	err := r.session.GetSession().Query(
    53  		"SELECT oauth " +
    54  			"FROM oauth " +
    55  			"WHERE name = ?",
    56  		name).Scan(&savedOauth)
    57  
    58  	err = json.Unmarshal([]byte(savedOauth), &oauth)
    59  
    60  	if err != nil {
    61  		log.Errorf("error selecting oauth %s: %v", name, err)
    62  	} else {
    63  		log.Debugf("successfully found oauth %s", name)
    64  	}
    65  
    66  	return oauth, err
    67  }
    68  
    69  // Add add a new OAuth Server to the repository
    70  // Add is the same as Save because Cassandra only upserts and I didn't want to write an existence checker
    71  func (r *CassandraRepository) Add(oauth *OAuth) error {
    72  	log.Debugf("adding: %s", oauth.Name)
    73  
    74  	saveOauth, err := json.Marshal(oauth)
    75  	if err != nil {
    76  		log.Errorf("error marshaling oauth: %v", err)
    77  		return err
    78  	}
    79  	err = r.session.GetSession().Query(
    80  		"UPDATE oauth " +
    81  			"SET oauth = ? " +
    82  			"WHERE name = ?",
    83  		saveOauth, oauth.Name).Exec()
    84  
    85  	if err != nil {
    86  		log.Errorf("error saving oauth %s: %v", oauth.Name, err)
    87  	} else {
    88  		log.Debugf("successfully saved oauth %s", oauth.Name)
    89  	}
    90  
    91  	return err
    92  }
    93  
    94  // Save saves OAuth Server to the repository
    95  func (r *CassandraRepository) Save(oauth *OAuth) error {
    96  	log.Debugf("adding: %s", oauth.Name)
    97  
    98  	saveOauth, err := json.Marshal(oauth)
    99  	if err != nil {
   100  		log.Errorf("error marshaling oauth: %v", err)
   101  		return err
   102  	}
   103  	err = r.session.GetSession().Query(
   104  		"UPDATE oauth " +
   105  			"SET oauth = ? " +
   106  			"WHERE name = ?",
   107  		saveOauth, oauth.Name).Exec()
   108  
   109  	if err != nil {
   110  		log.Errorf("error saving oauth %s: %v", oauth.Name, err)
   111  	} else {
   112  		log.Debugf("successfully saved oauth %s", oauth.Name)
   113  	}
   114  
   115  	return err
   116  }
   117  
   118  // Remove removes an OAuth Server from the repository
   119  func (r *CassandraRepository) Remove(name string) error {
   120  	log.Debugf("removing: %s", name)
   121  
   122  	err := r.session.GetSession().Query(
   123  		"DELETE FROM oauth WHERE name = ?", name).Exec()
   124  
   125  	if err != nil {
   126  		log.Errorf("error removing oauth %s: %v", name, err)
   127  	} else {
   128  		log.Debugf("successfully removed oauth %s", name)
   129  	}
   130  
   131  	return err
   132  }