github.com/blend/go-sdk@v1.20220411.3/db/column_collection.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"reflect"
    12  	"strings"
    13  	"sync"
    14  
    15  	"github.com/blend/go-sdk/stringutil"
    16  )
    17  
    18  var (
    19  	metaCacheMu sync.RWMutex
    20  	metaCache   = make(map[string]*ColumnCollection)
    21  )
    22  
    23  // --------------------------------------------------------------------------------
    24  // Common helpers
    25  // --------------------------------------------------------------------------------
    26  
    27  // Columns returns the cached column metadata for an object.
    28  func Columns(object DatabaseMapped) *ColumnCollection {
    29  	objectType := reflect.TypeOf(object)
    30  	return ColumnsFromType(newColumnCacheKey(objectType), objectType)
    31  }
    32  
    33  // ColumnsFromType reflects a reflect.Type into a column collection.
    34  // The results of this are cached for speed.
    35  func ColumnsFromType(identifier string, t reflect.Type) *ColumnCollection {
    36  	// check with read lock ...
    37  	metaCacheMu.RLock()
    38  	if value, ok := metaCache[identifier]; ok {
    39  		metaCacheMu.RUnlock()
    40  		return value
    41  	}
    42  	metaCacheMu.RUnlock()
    43  
    44  	// grab write lock ...
    45  	metaCacheMu.Lock()
    46  	defer metaCacheMu.Unlock()
    47  
    48  	// double checked lock
    49  	if value, ok := metaCache[identifier]; ok {
    50  		return value
    51  	}
    52  
    53  	metadata := NewColumnCollection(generateColumnsForType(nil, t)...)
    54  	metaCache[identifier] = metadata
    55  	return metadata
    56  }
    57  
    58  // --------------------------------------------------------------------------------
    59  // Utility
    60  // --------------------------------------------------------------------------------
    61  
    62  // ColumnNamesCSV returns a csv of column names.
    63  func ColumnNamesCSV(object DatabaseMapped) string {
    64  	return Columns(object).ColumnNamesCSV()
    65  }
    66  
    67  // --------------------------------------------------------------------------------
    68  // Column Collection
    69  // --------------------------------------------------------------------------------
    70  
    71  // NewColumnCollection returns a new empty column collection.
    72  func NewColumnCollection(columns ...Column) *ColumnCollection {
    73  	cc := ColumnCollection{
    74  		columns: columns,
    75  	}
    76  	lookup := make(map[string]*Column)
    77  	for i := 0; i < len(columns); i++ {
    78  		col := &columns[i]
    79  		lookup[col.ColumnName] = col
    80  	}
    81  	cc.lookup = lookup
    82  	return &cc
    83  }
    84  
    85  // NewColumnCollectionWithPrefix makes a new column collection with a column prefix.
    86  func NewColumnCollectionWithPrefix(columnPrefix string, columns ...Column) *ColumnCollection {
    87  	cc := ColumnCollection{
    88  		columns: columns,
    89  	}
    90  	lookup := make(map[string]*Column)
    91  	for i := 0; i < len(columns); i++ {
    92  		col := &columns[i]
    93  		lookup[col.ColumnName] = col
    94  	}
    95  	cc.lookup = lookup
    96  	cc.columnPrefix = columnPrefix
    97  	return &cc
    98  }
    99  
// ColumnCollection represents the column metadata for a given struct.
//
// Besides the columns themselves it memoizes a number of derived (filtered)
// collections; those are filled in on first access by the accessor of the same name.
type ColumnCollection struct {
	// columns is the ordered list of columns in the collection.
	columns []Column
	// lookup indexes columns by their bare (unprefixed) ColumnName.
	lookup map[string]*Column
	// columnPrefix, when non-empty, is prepended to column names by
	// ColumnNames and Lookup and used as the alias prefix in ColumnNamesFromAlias.
	columnPrefix string

	// memoized results of the filter accessors; nil until first computed.
	autos          *ColumnCollection
	notAutos       *ColumnCollection
	readOnly       *ColumnCollection
	notReadOnly    *ColumnCollection
	primaryKeys    *ColumnCollection
	notPrimaryKeys *ColumnCollection
	uniqueKeys     *ColumnCollection
	notUniqueKeys  *ColumnCollection
	insertColumns  *ColumnCollection
	updateColumns  *ColumnCollection
}
   117  
   118  // Len returns the number of columns.
   119  func (cc *ColumnCollection) Len() int {
   120  	if cc == nil {
   121  		return 0
   122  	}
   123  	return len(cc.columns)
   124  }
   125  
   126  // Add adds a column.
   127  func (cc *ColumnCollection) Add(c Column) {
   128  	cc.columns = append(cc.columns, c)
   129  	cc.lookup[c.ColumnName] = &c
   130  }
   131  
   132  // Remove removes a column (by column name) from the collection.
   133  func (cc *ColumnCollection) Remove(columnName string) {
   134  	var newColumns []Column
   135  	for _, c := range cc.columns {
   136  		if c.ColumnName != columnName {
   137  			newColumns = append(newColumns, c)
   138  		}
   139  	}
   140  	cc.columns = newColumns
   141  	delete(cc.lookup, columnName)
   142  }
   143  
   144  // HasColumn returns if a column name is present in the collection.
   145  func (cc *ColumnCollection) HasColumn(columnName string) bool {
   146  	_, hasColumn := cc.lookup[columnName]
   147  	return hasColumn
   148  }
   149  
   150  // Copy creates a new column collection instance and carries over an existing column prefix.
   151  func (cc ColumnCollection) Copy() *ColumnCollection {
   152  	return NewColumnCollectionWithPrefix(cc.columnPrefix, cc.columns...)
   153  }
   154  
   155  // CopyWithColumnPrefix applies a column prefix to column names and returns a new column collection.
   156  func (cc ColumnCollection) CopyWithColumnPrefix(prefix string) *ColumnCollection {
   157  	return NewColumnCollectionWithPrefix(prefix, cc.columns...)
   158  }
   159  
   160  // InsertColumns are non-auto, non-readonly columns.
   161  func (cc *ColumnCollection) InsertColumns() *ColumnCollection {
   162  	if cc.insertColumns != nil {
   163  		return cc.insertColumns
   164  	}
   165  
   166  	cc.insertColumns = cc.NotReadOnly().NotAutos()
   167  	return cc.insertColumns
   168  }
   169  
   170  // UpdateColumns are non-primary key, non-readonly columns.
   171  func (cc *ColumnCollection) UpdateColumns() *ColumnCollection {
   172  	if cc.updateColumns != nil {
   173  		return cc.updateColumns
   174  	}
   175  
   176  	cc.updateColumns = cc.NotReadOnly().NotPrimaryKeys()
   177  	return cc.updateColumns
   178  }
   179  
   180  // PrimaryKeys are columns we use as where predicates and can't update.
   181  func (cc *ColumnCollection) PrimaryKeys() *ColumnCollection {
   182  	if cc.primaryKeys != nil {
   183  		return cc.primaryKeys
   184  	}
   185  
   186  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   187  	for _, c := range cc.columns {
   188  		if c.IsPrimaryKey {
   189  			newCC.Add(c)
   190  		}
   191  	}
   192  
   193  	cc.primaryKeys = newCC
   194  	return cc.primaryKeys
   195  }
   196  
   197  // NotPrimaryKeys are columns we can update.
   198  func (cc *ColumnCollection) NotPrimaryKeys() *ColumnCollection {
   199  	if cc.notPrimaryKeys != nil {
   200  		return cc.notPrimaryKeys
   201  	}
   202  
   203  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   204  
   205  	for _, c := range cc.columns {
   206  		if !c.IsPrimaryKey {
   207  			newCC.Add(c)
   208  		}
   209  	}
   210  
   211  	cc.notPrimaryKeys = newCC
   212  	return cc.notPrimaryKeys
   213  }
   214  
   215  // UniqueKeys are columns we use as where predicates and can't update.
   216  func (cc *ColumnCollection) UniqueKeys() *ColumnCollection {
   217  	if cc.uniqueKeys != nil {
   218  		return cc.uniqueKeys
   219  	}
   220  
   221  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   222  	for _, c := range cc.columns {
   223  		if c.IsUniqueKey {
   224  			newCC.Add(c)
   225  		}
   226  	}
   227  
   228  	cc.uniqueKeys = newCC
   229  	return cc.uniqueKeys
   230  }
   231  
   232  // NotUniqueKeys are columns we can update.
   233  func (cc *ColumnCollection) NotUniqueKeys() *ColumnCollection {
   234  	if cc.notUniqueKeys != nil {
   235  		return cc.notUniqueKeys
   236  	}
   237  
   238  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   239  	for _, c := range cc.columns {
   240  		if !c.IsUniqueKey {
   241  			newCC.Add(c)
   242  		}
   243  	}
   244  
   245  	cc.notUniqueKeys = newCC
   246  	return cc.notUniqueKeys
   247  }
   248  
   249  // Autos are columns we have to return the id of.
   250  func (cc *ColumnCollection) Autos() *ColumnCollection {
   251  	if cc.autos != nil {
   252  		return cc.autos
   253  	}
   254  
   255  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   256  	for _, c := range cc.columns {
   257  		if c.IsAuto {
   258  			newCC.Add(c)
   259  		}
   260  	}
   261  
   262  	cc.autos = newCC
   263  	return cc.autos
   264  }
   265  
   266  // NotAutos are columns we don't have to return the id of.
   267  func (cc *ColumnCollection) NotAutos() *ColumnCollection {
   268  	if cc.notAutos != nil {
   269  		return cc.notAutos
   270  	}
   271  
   272  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   273  	for _, c := range cc.columns {
   274  		if !c.IsAuto {
   275  			newCC.Add(c)
   276  		}
   277  	}
   278  	cc.notAutos = newCC
   279  	return cc.notAutos
   280  }
   281  
   282  // ReadOnly are columns that we don't have to insert upon Create().
   283  func (cc *ColumnCollection) ReadOnly() *ColumnCollection {
   284  	if cc.readOnly != nil {
   285  		return cc.readOnly
   286  	}
   287  
   288  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   289  	for _, c := range cc.columns {
   290  		if c.IsReadOnly {
   291  			newCC.Add(c)
   292  		}
   293  	}
   294  
   295  	cc.readOnly = newCC
   296  	return cc.readOnly
   297  }
   298  
   299  // NotReadOnly are columns that we have to insert upon Create().
   300  func (cc *ColumnCollection) NotReadOnly() *ColumnCollection {
   301  	if cc.notReadOnly != nil {
   302  		return cc.notReadOnly
   303  	}
   304  
   305  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   306  	for _, c := range cc.columns {
   307  		if !c.IsReadOnly {
   308  			newCC.Add(c)
   309  		}
   310  	}
   311  
   312  	cc.notReadOnly = newCC
   313  	return cc.notReadOnly
   314  }
   315  
   316  // Zero returns unset fields on an instance that correspond to fields in the column collection.
   317  func (cc *ColumnCollection) Zero(instance interface{}) *ColumnCollection {
   318  	objValue := ReflectValue(instance)
   319  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   320  	var fieldValue reflect.Value
   321  	for _, c := range cc.columns {
   322  		fieldValue = objValue.Field(c.Index)
   323  		if fieldValue.IsZero() {
   324  			newCC.Add(c)
   325  		}
   326  	}
   327  	return newCC
   328  }
   329  
   330  // NotZero returns set fields on an instance that correspond to fields in the column collection.
   331  func (cc *ColumnCollection) NotZero(instance interface{}) *ColumnCollection {
   332  	objValue := ReflectValue(instance)
   333  	newCC := NewColumnCollectionWithPrefix(cc.columnPrefix)
   334  	var fieldValue reflect.Value
   335  	for _, c := range cc.columns {
   336  		fieldValue = objValue.Field(c.Index)
   337  		if !fieldValue.IsZero() {
   338  			newCC.Add(c)
   339  		}
   340  	}
   341  	return newCC
   342  }
   343  
   344  // ColumnNames returns the string names for all the columns in the collection.
   345  func (cc *ColumnCollection) ColumnNames() []string {
   346  	if cc == nil {
   347  		return nil
   348  	}
   349  	names := make([]string, len(cc.columns))
   350  	for x := 0; x < len(cc.columns); x++ {
   351  		c := cc.columns[x]
   352  		if len(cc.columnPrefix) != 0 {
   353  			names[x] = cc.columnPrefix + c.ColumnName
   354  		} else {
   355  			names[x] = c.ColumnName
   356  		}
   357  	}
   358  	return names
   359  }
   360  
   361  // Columns returns the colummns
   362  func (cc *ColumnCollection) Columns() []Column {
   363  	return cc.columns
   364  }
   365  
   366  // Lookup gets the column name lookup.
   367  func (cc *ColumnCollection) Lookup() map[string]*Column {
   368  	if len(cc.columnPrefix) != 0 {
   369  		lookup := map[string]*Column{}
   370  		for key, value := range cc.lookup {
   371  			lookup[cc.columnPrefix+key] = value
   372  		}
   373  		return lookup
   374  	}
   375  	return cc.lookup
   376  }
   377  
   378  // ColumnNamesFromAlias returns the string names for all the columns in the collection.
   379  func (cc *ColumnCollection) ColumnNamesFromAlias(tableAlias string) []string {
   380  	names := make([]string, len(cc.columns))
   381  	for x := 0; x < len(cc.columns); x++ {
   382  		c := cc.columns[x]
   383  		if cc.columnPrefix != "" {
   384  			names[x] = tableAlias + "." + c.ColumnName + " as " + cc.columnPrefix + c.ColumnName
   385  		} else {
   386  			names[x] = tableAlias + "." + c.ColumnName
   387  		}
   388  	}
   389  	return names
   390  }
   391  
   392  // ColumnNamesCSVFromAlias returns the string names for all the columns in the collection.
   393  func (cc *ColumnCollection) ColumnNamesCSVFromAlias(tableAlias string) string {
   394  	return stringutil.CSV(cc.ColumnNamesFromAlias(tableAlias))
   395  }
   396  
   397  // ColumnValues returns the reflected value for all the columns on a given instance.
   398  func (cc *ColumnCollection) ColumnValues(instance interface{}) []interface{} {
   399  	value := ReflectValue(instance)
   400  
   401  	values := make([]interface{}, len(cc.columns))
   402  	for x := 0; x < len(cc.columns); x++ {
   403  		c := cc.columns[x]
   404  		valueField := value.FieldByName(c.FieldName)
   405  		if c.IsJSON {
   406  			values[x] = JSON(valueField.Interface())
   407  		} else {
   408  			values[x] = valueField.Interface()
   409  		}
   410  	}
   411  	return values
   412  }
   413  
   414  // FirstOrDefault returns the first column in the collection or `nil` if the collection is empty.
   415  func (cc *ColumnCollection) FirstOrDefault() *Column {
   416  	if len(cc.columns) > 0 {
   417  		return &cc.columns[0]
   418  	}
   419  	return nil
   420  }
   421  
   422  // ConcatWith merges a collection with another collection.
   423  func (cc *ColumnCollection) ConcatWith(other *ColumnCollection) *ColumnCollection {
   424  	total := make([]Column, len(cc.columns)+len(other.columns))
   425  	var x int
   426  	for ; x < len(cc.columns); x++ {
   427  		total[x] = cc.columns[x]
   428  	}
   429  	for y := 0; y < len(other.columns); y++ {
   430  		total[x+y] = other.columns[y]
   431  	}
   432  	return NewColumnCollection(total...)
   433  }
   434  
   435  func (cc *ColumnCollection) String() string {
   436  	return strings.Join(cc.ColumnNames(), ", ")
   437  }
   438  
   439  // ColumnNamesCSV returns a csv of column names.
   440  func (cc *ColumnCollection) ColumnNamesCSV() string {
   441  	return stringutil.CSV(cc.ColumnNames())
   442  }
   443  
   444  //
   445  // helpers
   446  //
   447  
   448  // newColumnCacheKey creates a cache key for a type.
   449  func newColumnCacheKey(objectType reflect.Type) string {
   450  	typeName := objectType.String()
   451  	instance := reflect.New(objectType).Interface()
   452  	if typed, ok := instance.(ColumnMetaCacheKeyProvider); ok {
   453  		return typeName + "_" + typed.ColumnMetaCacheKey()
   454  	}
   455  	if typed, ok := instance.(TableNameProvider); ok {
   456  		return typeName + "_" + typed.TableName()
   457  	}
   458  	return typeName
   459  }
   460  
   461  // generateColumnsForType generates a column list for a given type.
   462  func generateColumnsForType(parent *Column, t reflect.Type) []Column {
   463  	for t.Kind() == reflect.Ptr {
   464  		t = t.Elem()
   465  	}
   466  
   467  	var tableName string
   468  	if parent != nil {
   469  		tableName = parent.TableName
   470  	} else {
   471  		tableName = TableNameByType(t)
   472  	}
   473  
   474  	numFields := t.NumField()
   475  
   476  	var cols []Column
   477  	for index := 0; index < numFields; index++ {
   478  		field := t.Field(index)
   479  		col := NewColumnFromFieldTag(field)
   480  		if col != nil {
   481  			col.Parent = parent
   482  			col.Index = index
   483  			col.TableName = tableName
   484  			if col.Inline && field.Anonymous { // if it's not anonymous, whatchu doin
   485  				cols = append(cols, generateColumnsForType(col, col.FieldType)...)
   486  			} else if !field.Anonymous {
   487  				cols = append(cols, *col)
   488  			}
   489  		}
   490  	}
   491  
   492  	return cols
   493  }