github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/mongodb_repository.go (about)

     1  package oauth2
     2  
import (
	"context"
	"errors"
	"time"

	"github.com/asaskevich/govalidator"
	log "github.com/sirupsen/logrus"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)
    13  
    14  const (
    15  	collectionName = "oauth_servers"
    16  
    17  	mongoQueryTimeout = 10 * time.Second
    18  )
    19  
    20  // Repository defines the behavior of a OAuth Server repo
    21  type Repository interface {
    22  	FindAll() ([]*OAuth, error)
    23  	FindByName(name string) (*OAuth, error)
    24  	Add(oauth *OAuth) error
    25  	Save(oauth *OAuth) error
    26  	Remove(id string) error
    27  }
    28  
    29  // MongoRepository represents a mongodb repository
    30  type MongoRepository struct {
    31  	collection *mongo.Collection
    32  }
    33  
    34  // NewMongoRepository creates a mongodb OAuth Server repo
    35  func NewMongoRepository(db *mongo.Database) (*MongoRepository, error) {
    36  	return &MongoRepository{db.Collection(collectionName)}, nil
    37  }
    38  
    39  // FindAll fetches all the OAuth Servers available
    40  func (r *MongoRepository) FindAll() ([]*OAuth, error) {
    41  	var result []*OAuth
    42  
    43  	ctx, cancel := context.WithTimeout(context.Background(), mongoQueryTimeout)
    44  	defer cancel()
    45  
    46  	cur, err := r.collection.Find(ctx, bson.M{}, options.Find().SetSort(bson.D{{Key: "name", Value: 1}}))
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	defer cur.Close(ctx)
    51  
    52  	for cur.Next(ctx) {
    53  		o := new(OAuth)
    54  		if err := cur.Decode(o); err != nil {
    55  			return nil, err
    56  		}
    57  
    58  		result = append(result, o)
    59  	}
    60  
    61  	return result, cur.Err()
    62  }
    63  
    64  // FindByName find an OAuth Server by name
    65  func (r *MongoRepository) FindByName(name string) (*OAuth, error) {
    66  	ctx, cancel := context.WithTimeout(context.Background(), mongoQueryTimeout)
    67  	defer cancel()
    68  
    69  	result := NewOAuth()
    70  	err := r.collection.FindOne(ctx, bson.M{"name": name}).Decode(result)
    71  	if err == mongo.ErrNoDocuments {
    72  		return nil, ErrOauthServerNotFound
    73  	}
    74  
    75  	return result, err
    76  }
    77  
    78  // Add add a new OAuth Server to the repository
    79  func (r *MongoRepository) Add(oauth *OAuth) error {
    80  	isValid, err := govalidator.ValidateStruct(oauth)
    81  	if !isValid && err != nil {
    82  		log.WithField("errors", err.Error()).Error("Validation errors")
    83  		return err
    84  	}
    85  
    86  	ctx, cancel := context.WithTimeout(context.Background(), mongoQueryTimeout)
    87  	defer cancel()
    88  
    89  	_, err = r.collection.InsertOne(ctx, oauth)
    90  	if err != nil {
    91  		if isDuplicateKeyError(err) {
    92  			return ErrOauthServerNameExists
    93  		}
    94  		log.WithField("name", oauth.Name).WithError(err).Error("There was an error persisting the resource")
    95  		return err
    96  	}
    97  
    98  	log.WithField("name", oauth.Name).Debug("Resource persisted")
    99  	return nil
   100  }
   101  
   102  // Save saves OAuth Server to the repository
   103  func (r *MongoRepository) Save(oauth *OAuth) error {
   104  	isValid, err := govalidator.ValidateStruct(oauth)
   105  	if !isValid && err != nil {
   106  		log.WithField("errors", err.Error()).Error("Validation errors")
   107  		return err
   108  	}
   109  
   110  	ctx, cancel := context.WithTimeout(context.Background(), mongoQueryTimeout)
   111  	defer cancel()
   112  
   113  	if err := r.collection.FindOneAndUpdate(
   114  		ctx,
   115  		bson.M{"name": oauth.Name},
   116  		bson.M{"$set": oauth},
   117  		options.FindOneAndUpdate().SetUpsert(true),
   118  	).Err(); err != nil {
   119  		log.WithField("name", oauth.Name).WithError(err).Error("There was an error adding the resource")
   120  		return err
   121  	}
   122  
   123  	log.WithField("name", oauth.Name).Debug("Resource added")
   124  	return nil
   125  }
   126  
   127  // Remove removes an OAuth Server from the repository
   128  func (r *MongoRepository) Remove(name string) error {
   129  	ctx, cancel := context.WithTimeout(context.Background(), mongoQueryTimeout)
   130  	defer cancel()
   131  
   132  	_, err := r.collection.DeleteOne(ctx, bson.M{"name": name})
   133  	if err != nil {
   134  		log.WithField("name", name).Error("There was an error removing the resource")
   135  		return err
   136  	}
   137  
   138  	log.WithField("name", name).Debug("Resource removed")
   139  	return nil
   140  }
   141  
   142  func isDuplicateKeyError(err error) bool {
   143  	// TODO: maybe there is (or will be) a better way of checking duplicate key error
   144  	// this one is based on https://github.com/mongodb/mongo-go-driver/blob/master/mongo/integration/collection_test.go#L54-L65
   145  	we, ok := err.(mongo.WriteException)
   146  	if !ok {
   147  		return false
   148  	}
   149  
   150  	return len(we.WriteErrors) > 0 && we.WriteErrors[0].Code == 11000
   151  }