github.com/justinjmoses/evergreen@v0.0.0-20170530173719-1d50e381ff0d/db/db_utils.go (about)

     1  package db
     2  
     3  import (
     4  	"io"
     5  
     6  	"github.com/mongodb/grip"
     7  	"github.com/pkg/errors"
     8  
     9  	"gopkg.in/mgo.v2"
    10  	"gopkg.in/mgo.v2/bson"
    11  )
    12  
    13  var (
    14  	NoProjection = bson.M{}
    15  	NoSort       = []string{}
    16  	NoSkip       = 0
    17  	NoLimit      = 0
    18  )
    19  
    20  // Insert inserts the specified item into the specified collection.
    21  func Insert(collection string, item interface{}) error {
    22  	session, db, err := GetGlobalSessionFactory().GetSession()
    23  	if err != nil {
    24  		return nil
    25  	}
    26  	defer session.Close()
    27  
    28  	return db.C(collection).Insert(item)
    29  }
    30  
    31  // Clear removes all documents from a specified collection.
    32  func Clear(collection string) error {
    33  	session, db, err := GetGlobalSessionFactory().GetSession()
    34  	if err != nil {
    35  		return err
    36  	}
    37  	defer session.Close()
    38  	_, err = db.C(collection).RemoveAll(bson.M{})
    39  	return err
    40  }
    41  
    42  // ClearCollections clears all documents from all the specified collections, returning an error
    43  // immediately if clearing any one of them fails.
    44  func ClearCollections(collections ...string) error {
    45  	session, db, err := GetGlobalSessionFactory().GetSession()
    46  	if err != nil {
    47  		return err
    48  	}
    49  	defer session.Close()
    50  	for _, collection := range collections {
    51  		_, err = db.C(collection).RemoveAll(bson.M{})
    52  		if err != nil {
    53  			return errors.Wrapf(err, "Couldn't clear collection '%v'", collection)
    54  		}
    55  	}
    56  	return nil
    57  }
    58  
    59  // EnsureIndex takes in a collection and ensures that the
    60  func EnsureIndex(collection string, index mgo.Index) error {
    61  	session, db, err := GetGlobalSessionFactory().GetSession()
    62  	if err != nil {
    63  		return errors.WithStack(err)
    64  	}
    65  	defer session.Close()
    66  	return db.C(collection).EnsureIndex(index)
    67  }
    68  
    69  // DropIndex takes in a collection and a slice of keys and drops those indexes
    70  func DropIndex(collection string, key ...string) error {
    71  	session, db, err := GetGlobalSessionFactory().GetSession()
    72  	if err != nil {
    73  		return err
    74  	}
    75  	defer session.Close()
    76  	return db.C(collection).DropIndex(key...)
    77  }
    78  
    79  // Remove removes one item matching the query from the specified collection.
    80  func Remove(collection string, query interface{}) error {
    81  	session, db, err := GetGlobalSessionFactory().GetSession()
    82  	if err != nil {
    83  		return err
    84  	}
    85  	defer session.Close()
    86  
    87  	return db.C(collection).Remove(query)
    88  }
    89  
    90  // RemoveAll removes all items matching the query from the specified collection.
    91  func RemoveAll(collection string, query interface{}) error {
    92  	session, db, err := GetGlobalSessionFactory().GetSession()
    93  	if err != nil {
    94  		return err
    95  	}
    96  	defer session.Close()
    97  
    98  	_, err = db.C(collection).RemoveAll(query)
    99  	return err
   100  }
   101  
   102  // FindOne finds one item from the specified collection and unmarshals it into the
   103  // provided interface, which must be a pointer.
   104  func FindOne(collection string, query interface{},
   105  	projection interface{}, sort []string, out interface{}) error {
   106  
   107  	session, db, err := GetGlobalSessionFactory().GetSession()
   108  	if err != nil {
   109  		grip.Errorf("error establishing db connection: %+v", err)
   110  		return err
   111  	}
   112  	defer session.Close()
   113  
   114  	q := db.C(collection).Find(query).Select(projection)
   115  	if len(sort) != 0 {
   116  		q = q.Sort(sort...)
   117  	}
   118  	return q.One(out)
   119  }
   120  
   121  // FindAll finds the items from the specified collection and unmarshals them into the
   122  // provided interface, which must be a slice.
   123  func FindAll(collection string, query interface{},
   124  	projection interface{}, sort []string, skip int, limit int,
   125  	out interface{}) error {
   126  
   127  	session, db, err := GetGlobalSessionFactory().GetSession()
   128  	if err != nil {
   129  		grip.Errorf("error establishing db connection: %+v", err)
   130  
   131  		return err
   132  	}
   133  	defer session.Close()
   134  
   135  	q := db.C(collection).Find(query).Select(projection)
   136  	if len(sort) != 0 {
   137  		q = q.Sort(sort...)
   138  	}
   139  	return q.Skip(skip).Limit(limit).All(out)
   140  }
   141  
   142  // Update updates one matching document in the collection.
   143  func Update(collection string, query interface{},
   144  	update interface{}) error {
   145  
   146  	session, db, err := GetGlobalSessionFactory().GetSession()
   147  	if err != nil {
   148  		grip.Errorf("error establishing db connection: %+v", err)
   149  
   150  		return err
   151  	}
   152  	defer session.Close()
   153  
   154  	return db.C(collection).Update(query, update)
   155  }
   156  
   157  // UpdateId updates one _id-matching document in the collection.
   158  func UpdateId(collection string, id, update interface{}) error {
   159  
   160  	session, db, err := GetGlobalSessionFactory().GetSession()
   161  	if err != nil {
   162  		grip.Errorf("error establishing db connection: %+v", err)
   163  
   164  		return err
   165  	}
   166  	defer session.Close()
   167  
   168  	return db.C(collection).UpdateId(id, update)
   169  }
   170  
   171  // UpdateAll updates all matching documents in the collection.
   172  func UpdateAll(collection string, query interface{},
   173  	update interface{}) (*mgo.ChangeInfo, error) {
   174  
   175  	session, db, err := GetGlobalSessionFactory().GetSession()
   176  	if err != nil {
   177  		grip.Errorf("error establishing db connection: %+v", err)
   178  
   179  		return nil, err
   180  	}
   181  	defer session.Close()
   182  
   183  	return db.C(collection).UpdateAll(query, update)
   184  }
   185  
   186  // Upsert run the specified update against the collection as an upsert operation.
   187  func Upsert(collection string, query interface{},
   188  	update interface{}) (*mgo.ChangeInfo, error) {
   189  
   190  	session, db, err := GetGlobalSessionFactory().GetSession()
   191  	if err != nil {
   192  		grip.Errorf("error establishing db connection: %+v", err)
   193  
   194  		return nil, err
   195  	}
   196  	defer session.Close()
   197  
   198  	return db.C(collection).Upsert(query, update)
   199  }
   200  
   201  // Count run a count command with the specified query against the collection.
   202  func Count(collection string, query interface{}) (int, error) {
   203  
   204  	session, db, err := GetGlobalSessionFactory().GetSession()
   205  	if err != nil {
   206  		grip.Errorf("error establishing db connection: %+v", err)
   207  
   208  		return 0, err
   209  	}
   210  	defer session.Close()
   211  
   212  	return db.C(collection).Find(query).Count()
   213  }
   214  
   215  // FindAndModify runs the specified query and change against the collection,
   216  // unmarshaling the result into the specified interface.
   217  func FindAndModify(collection string, query interface{}, sort []string,
   218  	change mgo.Change, out interface{}) (*mgo.ChangeInfo, error) {
   219  
   220  	session, db, err := GetGlobalSessionFactory().GetSession()
   221  	if err != nil {
   222  		grip.Errorf("error establishing db connection: %+v", err)
   223  
   224  		return nil, err
   225  	}
   226  	defer session.Close()
   227  	return db.C(collection).Find(query).Sort(sort...).Apply(change, out)
   228  }
   229  
   230  // WriteGridFile writes the data in the source Reader to a GridFS collection with
   231  // the given prefix and filename.
   232  func WriteGridFile(fsPrefix, name string, source io.Reader) error {
   233  	session, db, err := GetGlobalSessionFactory().GetSession()
   234  	if err != nil {
   235  		return err
   236  	}
   237  	defer session.Close()
   238  
   239  	file, err := db.GridFS(fsPrefix).Create(name)
   240  	if err != nil {
   241  		return err
   242  	}
   243  	defer file.Close()
   244  	_, err = io.Copy(file, source)
   245  	return err
   246  }
   247  
   248  type sessionBackedGridFile struct {
   249  	*mgo.GridFile
   250  	session *mgo.Session
   251  }
   252  
   253  func (sbgf *sessionBackedGridFile) Close() error {
   254  	err := sbgf.GridFile.Close()
   255  	sbgf.session.Close()
   256  	return err
   257  }
   258  
   259  // GetGridFile returns a ReadCloser for a file stored with the given name under the GridFS prefix.
   260  func GetGridFile(fsPrefix, name string) (io.ReadCloser, error) {
   261  	session, db, err := GetGlobalSessionFactory().GetSession()
   262  	if err != nil {
   263  		err = errors.Wrap(err, "error establishing db connection")
   264  		grip.Error(err)
   265  		return nil, err
   266  	}
   267  	file, err := db.GridFS(fsPrefix).Open(name)
   268  	if err != nil {
   269  		return nil, errors.WithStack(err)
   270  	}
   271  	return &sessionBackedGridFile{file, session}, nil
   272  }
   273  
   274  // Aggregate runs an aggregation pipeline on a collection and unmarshals
   275  // the results to the given "out" interface (usually a pointer
   276  // to an array of structs/bson.M)
   277  func Aggregate(collection string, pipeline interface{}, out interface{}) error {
   278  	session, db, err := GetGlobalSessionFactory().GetSession()
   279  	if err != nil {
   280  		err = errors.Wrap(err, "error establishing db connection")
   281  		grip.Error(err)
   282  		return err
   283  	}
   284  	defer session.Close()
   285  
   286  	session.SetSocketTimeout(0)
   287  	pipe := db.C(collection).Pipe(pipeline).AllowDiskUse()
   288  	return errors.WithStack(pipe.All(out))
   289  }