github.com/angryronald/go-kit@v0.0.0-20240505173814-ff2bd9c79dbf/generic/repository/nosql/generic.repository.mutable.go (about)

     1  package nosql
     2  
import (
	"context"
	"fmt"
	"reflect"
	"sort"
	"strings"

	"github.com/gocql/gocql"
	"github.com/google/uuid"

	"github.com/angryronald/go-kit/appcontext"
	"github.com/angryronald/go-kit/cast"
	"github.com/angryronald/go-kit/generic/repository"
)
    16  
    17  type GenericRepository struct {
    18  	db       *gocql.Session
    19  	keySpace string
    20  }
    21  
    22  // FindAll for IN params please pass the object directly without parsing to string
    23  func (r *GenericRepository) FindAll(ctx context.Context, params map[string]interface{}, conditionalOperations []repository.ConditionalOperation, relationalOperations []repository.RelationalOperation, page int, limit int, result interface{}) (interface{}, error) {
    24  	tableName := repository.GetTableName(result)
    25  	if !repository.IsValidOperations(conditionalOperations, relationalOperations, params) {
    26  		return nil, repository.ErrBadParameters
    27  	}
    28  
    29  	where := ""
    30  	i := 0
    31  	args := []interface{}{}
    32  	for k, v := range params {
    33  		switch conditionalOperations[i] {
    34  		case repository.IN:
    35  			encapsulation := ""
    36  			if !repository.AreAllNumbers(v.([]interface{})) && !repository.AreAllUUIDs(v.([]interface{})) {
    37  				encapsulation = "'"
    38  			}
    39  
    40  			where = fmt.Sprintf(`%s %s %s (%s)`, where, k, conditionalOperations[i], repository.StringJoin(v.([]interface{}), encapsulation))
    41  
    42  		case repository.ILIKE, repository.LIKE:
    43  			return nil, repository.ErrNotImplement
    44  
    45  		default:
    46  			where = fmt.Sprintf(`%s %s %s ?`, where, k, conditionalOperations[i])
    47  			args = append(args, v)
    48  		}
    49  
    50  		if i < len(relationalOperations) {
    51  			where = fmt.Sprintf(`%s %s`, where, relationalOperations[i])
    52  		}
    53  		i++
    54  	}
    55  
    56  	query := fmt.Sprintf("SELECT * FROM %s.%s WHERE %s", r.keySpace, tableName, where)
    57  	// due to unable to use pagination this part will be comment out for now
    58  	// if page != 0 || limit != 0 {
    59  	// 	query = fmt.Sprintf("%s LIMIT %d OFFSET %d", query, limit, (page-1)*limit)
    60  	// }
    61  	query = fmt.Sprintf("%s ALLOW FILTERING;", query)
    62  	iter := r.db.Query(query, args...).Iter()
    63  
    64  	resultRaw := make([]interface{}, 0)
    65  	for {
    66  		m := make(map[string]interface{})
    67  		if !iter.MapScan(m) {
    68  			break
    69  		}
    70  		resultRaw = append(resultRaw, m)
    71  	}
    72  
    73  	cast.TransformObject(resultRaw, &result)
    74  
    75  	return result, nil
    76  }
    77  
    78  func (r *GenericRepository) FindOne(ctx context.Context, key string, value interface{}, result interface{}) (interface{}, error) {
    79  	where := fmt.Sprintf(`%s %s ?`, key, repository.EQUAL_WITH)
    80  	args := []interface{}{
    81  		value,
    82  	}
    83  	tableName := repository.GetTableName(result)
    84  	query := fmt.Sprintf("SELECT * FROM %s.%s WHERE %s ALLOW FILTERING;", r.keySpace, tableName, where)
    85  	iter := r.db.Query(query, args...).Iter()
    86  
    87  	m := make(map[string]interface{})
    88  	if !iter.MapScan(m) {
    89  		return nil, repository.ErrNotFound
    90  	}
    91  
    92  	cast.TransformObject(m, result)
    93  
    94  	return result, nil
    95  }
    96  
    97  func (r *GenericRepository) FindByID(ctx context.Context, id uuid.UUID, result interface{}) (interface{}, error) {
    98  	where := fmt.Sprintf(`ID %s ?`, repository.EQUAL_WITH)
    99  	args := []interface{}{
   100  		id.String(),
   101  	}
   102  	tableName := repository.GetTableName(result)
   103  	query := fmt.Sprintf("SELECT * FROM %s.%s WHERE %s ALLOW FILTERING;", r.keySpace, tableName, where)
   104  	iter := r.db.Query(query, args...).Iter()
   105  
   106  	m := make(map[string]interface{})
   107  	if !iter.MapScan(m) {
   108  		return nil, repository.ErrNotFound
   109  	}
   110  
   111  	cast.TransformObject(m, result)
   112  
   113  	return result, nil
   114  }
   115  
   116  func (r *GenericRepository) Insert(ctx context.Context, data interface{}) (interface{}, error) {
   117  	newID, err := uuid.NewRandom()
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	if err = repository.UpdatePropertyValue(data, "ID", newID); err != nil {
   122  		return nil, err
   123  	}
   124  	if err = repository.UpdatePropertyValue(data, "CreatedBy", appcontext.UserID(ctx)); err != nil {
   125  		return nil, err
   126  	}
   127  	if err = repository.UpdatePropertyValue(data, "UpdatedBy", appcontext.UserID(ctx)); err != nil && err != repository.ErrPropertyNotFound {
   128  		return nil, err
   129  	}
   130  
   131  	tableName := repository.GetTableName(data)
   132  	columns, values, err := GetPropertyNamesAndValues(data)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	paramString := ""
   137  	for i := 0; i < len(values); i++ {
   138  		paramString += "?,"
   139  	}
   140  
   141  	if len(paramString) > 0 {
   142  		paramString = paramString[:len(paramString)-1]
   143  	}
   144  
   145  	query := fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", r.keySpace, tableName, strings.Join(columns, ","), paramString)
   146  	if err = r.db.Query(query, values...).Exec(); err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	return data, nil
   151  }
   152  
   153  func (r *GenericRepository) Update(ctx context.Context, data interface{}) (interface{}, error) {
   154  	var err error
   155  
   156  	newData, _ := repository.CopyObject(data)
   157  	ID, _ := repository.GetStructPropertyAsString(newData, "ID")
   158  
   159  	if _, err = r.FindByID(ctx, uuid.MustParse(ID), newData); err != nil {
   160  		return nil, repository.ErrNotFound
   161  	}
   162  
   163  	if err = repository.UpdatePropertyValue(data, "UpdatedBy", appcontext.UserID(ctx)); err != nil && err != repository.ErrPropertyNotFound {
   164  		return nil, err
   165  	}
   166  
   167  	tableName := repository.GetTableName(data)
   168  	columns, values, err := GetPropertyNamesAndValues(data)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	setQuery := []string{}
   173  	newArgs := []interface{}{}
   174  	var tempIDValue interface{}
   175  	for i := 0; i < len(columns); i++ {
   176  		if columns[i] == "ID" {
   177  			tempIDValue = values[i]
   178  			continue
   179  		}
   180  		setQuery = append(setQuery, fmt.Sprintf("%s=?", columns[i]))
   181  
   182  		newArgs = append(newArgs, values[i])
   183  	}
   184  	newArgs = append(newArgs, tempIDValue)
   185  
   186  	query := fmt.Sprintf("UPDATE %s.%s SET %s WHERE ID = ?;", r.keySpace, tableName, strings.Join(setQuery, ","))
   187  	if err := r.db.Query(query, newArgs...).Exec(); err != nil {
   188  		return nil, err
   189  	}
   190  
   191  	return data, nil
   192  }
   193  
   194  func (r *GenericRepository) Delete(ctx context.Context, data interface{}) (interface{}, error) {
   195  	ID, _ := repository.GetStructPropertyAsString(data, "ID")
   196  
   197  	tableName := repository.GetTableName(data)
   198  	query := fmt.Sprintf("DELETE FROM %s.%s WHERE ID = ?;", r.keySpace, tableName)
   199  	if err := r.db.Query(query, ID).Exec(); err != nil {
   200  		return nil, err
   201  	}
   202  
   203  	return data, nil
   204  }
   205  
   206  func (r *GenericRepository) Upsert(ctx context.Context, data interface{}) (interface{}, error) {
   207  	return nil, repository.ErrNotImplement
   208  }
   209  
   210  func (r *GenericRepository) BulkInsert(ctx context.Context, data interface{}) (interface{}, error) {
   211  	return nil, repository.ErrNotImplement
   212  }
   213  
   214  func (r *GenericRepository) BulkUpsert(ctx context.Context, data interface{}) (interface{}, error) {
   215  	return nil, repository.ErrNotImplement
   216  }
   217  
   218  func (r *GenericRepository) Query(ctx context.Context, query string, params []interface{}, result interface{}) (interface{}, error) {
   219  	iter := r.db.Query(query, params...).Iter()
   220  
   221  	resultRaw := make([]interface{}, 0)
   222  	for {
   223  		m := make(map[string]interface{})
   224  		if !iter.MapScan(m) {
   225  			break
   226  		}
   227  		resultRaw = append(resultRaw, m)
   228  	}
   229  
   230  	cast.TransformObject(resultRaw, &result)
   231  
   232  	return result, nil
   233  }
   234  
   235  // DO NOT USE Pointer for nosql (Cassandra) object due to migration issue, please use cast library to casting object do not rely on result due type mismatch, deleted process need to be adjusted later, for now it would be hard delete
   236  func NewRepository(db *gocql.Session, keySpace string) repository.GenericRepositoryInterface {
   237  	return &GenericRepository{
   238  		db:       db,
   239  		keySpace: keySpace,
   240  	}
   241  }
   242  
   243  // DO NOT USE Pointer for nosql (Cassandra) object due to migration issue, please use cast library to casting object do not rely on result due type mismatch, deleted process need to be adjusted later, for now it would be hard delete
   244  func NewMutableRepository(db *gocql.Session, keySpace string) repository.MutableGenericRepositoryInterface {
   245  	return &GenericRepository{
   246  		db:       db,
   247  		keySpace: keySpace,
   248  	}
   249  }
   250  
   251  func MigrateTable(session *gocql.Session, keySpace string, entity interface{}) error {
   252  	tableName := repository.GetTableName(entity)
   253  	entityType := reflect.TypeOf(entity).Elem()
   254  
   255  	createQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (", keySpace, tableName)
   256  	for i := 0; i < entityType.NumField(); i++ {
   257  		field := entityType.Field(i)
   258  		fieldName := field.Name
   259  		fieldType := field.Type
   260  
   261  		// If the field type is a pointer, dereference it
   262  		if field.Type.Kind() == reflect.Ptr {
   263  			fieldType = field.Type.Elem()
   264  		} else {
   265  			fieldType = field.Type
   266  		}
   267  
   268  		// Skip unexported fields
   269  		if field.PkgPath != "" {
   270  			continue
   271  		}
   272  
   273  		createQuery += fmt.Sprintf("%s %s,", fieldName, GoTypeMapToCQLType[fieldType.Name()])
   274  		if fieldName == "ID" {
   275  			createQuery = fmt.Sprintf("%s PRIMARY KEY,", createQuery[:len(createQuery)-1])
   276  		}
   277  	}
   278  	createQuery = createQuery[:len(createQuery)-1] + ") WITH CLUSTERING ORDER BY (CreatedAt DESC);"
   279  
   280  	if err := session.Query(createQuery).Exec(); err != nil {
   281  		return err
   282  	}
   283  
   284  	return nil
   285  }