github.com/15mga/kiwi@v0.0.2-0.20240324021231-b95d5c3ac751/util/dynamo/dynamo.go (about)

     1  package dynamo
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/15mga/kiwi"
     8  	"github.com/15mga/kiwi/util"
     9  	"github.com/aws/aws-sdk-go-v2/aws"
    10  	"github.com/aws/aws-sdk-go-v2/config"
    11  	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
    12  	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
    13  	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
    14  	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
    15  	"github.com/aws/smithy-go/logging"
    16  )
    17  
    18  var (
    19  	_Client      *dynamodb.Client
    20  	_NameToTable = make(map[string]ITable)
    21  )
    22  
    23  func Client() *dynamodb.Client {
    24  	return _Client
    25  }
    26  
    27  func ConnLocal(url string) *util.Err {
    28  	cfg, err := config.LoadDefaultConfig(util.Ctx(),
    29  		config.WithRegion("ap-east-1"),
    30  		config.WithEndpointResolverWithOptions(localEndpointResolver{
    31  			url: url,
    32  		}),
    33  		//config.WithCredentialsProvider(localCredentialsProvider{}),
    34  		config.WithLogger(logger{}),
    35  	)
    36  	if err != nil {
    37  		return util.NewErr(util.EcParamsErr, util.M{
    38  			"err": err.Error(),
    39  		})
    40  	}
    41  	_Client = dynamodb.NewFromConfig(cfg)
    42  	return nil
    43  }
    44  
    45  func ConnAWS(region string) *util.Err {
    46  	cfg, err := config.LoadDefaultConfig(util.Ctx(),
    47  		config.WithRegion(region),
    48  		config.WithLogger(logger{}),
    49  	)
    50  	if err != nil {
    51  		return util.NewErr(util.EcParamsErr, util.M{
    52  			"err": err.Error(),
    53  		})
    54  	}
    55  	_Client = dynamodb.NewFromConfig(cfg)
    56  	return nil
    57  }
    58  
    59  type localEndpointResolver struct {
    60  	url string
    61  }
    62  
    63  func (l localEndpointResolver) ResolveEndpoint(service, region string, options ...interface{}) (aws.Endpoint, error) {
    64  	return aws.Endpoint{
    65  		URL: l.url,
    66  	}, nil
    67  }
    68  
    69  type localCredentialsProvider struct {
    70  }
    71  
    72  func (l localCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
    73  	return aws.Credentials{
    74  		AccessKeyID: "dummy", SecretAccessKey: "dummy", SessionToken: "dummy",
    75  		Source: "Hard-coded credentials; values are irrelevant for local DynamoDB",
    76  	}, nil
    77  }
    78  
    79  type logger struct {
    80  }
    81  
    82  func (l logger) Logf(classification logging.Classification, format string, v ...interface{}) {
    83  	switch classification {
    84  	case logging.Debug:
    85  		kiwi.Debug(fmt.Sprintf(format, v...), nil)
    86  	case logging.Warn:
    87  		kiwi.Warn(util.NewErr(util.EcDbErr, util.M{
    88  			"error": fmt.Sprintf(format, v...),
    89  		}))
    90  	}
    91  }
    92  
    93  func BindTable(table ITable) {
    94  	_NameToTable[table.Name()] = table
    95  }
    96  
    97  func CreateTable(tableName string) *util.Err {
    98  	table, ok := _NameToTable[tableName]
    99  	if !ok {
   100  		return util.NewErr(util.EcNotExist, util.M{
   101  			"table": tableName,
   102  		})
   103  	}
   104  	return creatTable(table)
   105  }
   106  
   107  func creatTable(table ITable) *util.Err {
   108  	params := &dynamodb.CreateTableInput{
   109  		TableName:            aws.String(table.Name()),
   110  		AttributeDefinitions: table.AttributeDefinitions(),
   111  		KeySchema:            table.KeySchema(),
   112  	}
   113  	_, e := _Client.CreateTable(util.Ctx(), params)
   114  	if e != nil {
   115  		return util.WrapErr(util.EcDbErr, e)
   116  	}
   117  	return nil
   118  }
   119  
   120  func MigrateTables() {
   121  	for name, table := range _NameToTable {
   122  		ad := table.AttributeDefinitions()
   123  		if ad == nil || len(ad) == 0 {
   124  			continue
   125  		}
   126  		exist, err := IsTableExist(name)
   127  		if err != nil {
   128  			kiwi.Error(err)
   129  			continue
   130  		}
   131  		if exist {
   132  			continue
   133  		}
   134  		err = creatTable(table)
   135  		if err == nil {
   136  			kiwi.Error(err)
   137  		}
   138  	}
   139  }
   140  
   141  func IsTableExist(table string) (bool, *util.Err) {
   142  	_, e := _Client.DescribeTable(util.Ctx(), &dynamodb.DescribeTableInput{TableName: aws.String(table)})
   143  	if e != nil {
   144  		var notFoundEx *types.ResourceNotFoundException
   145  		if errors.As(e, &notFoundEx) {
   146  			return false, nil
   147  		}
   148  		return false, util.WrapErr(util.EcDbErr, e)
   149  	}
   150  	return true, nil
   151  }
   152  
   153  type MAV map[string]types.AttributeValue
   154  
   155  func mToAv(data util.M, avs MAV) {
   156  	for attr, val := range data {
   157  		av, e := attributevalue.Marshal(val)
   158  		if e != nil {
   159  			kiwi.Warn(util.WrapErr(util.EcMarshallErr, e))
   160  			continue
   161  		}
   162  		avs[attr] = av
   163  	}
   164  }
   165  
   166  func GetEntity[T any](table string, filter util.M, consistent bool, entity T, projectAttrs ...string) *util.Err {
   167  	mav := make(MAV)
   168  	mToAv(filter, mav)
   169  	params := &dynamodb.GetItemInput{
   170  		Key:            mav,
   171  		TableName:      aws.String(table),
   172  		ConsistentRead: &consistent,
   173  	}
   174  	if len(projectAttrs) > 0 {
   175  		pb := expression.ProjectionBuilder{}
   176  		for _, attr := range projectAttrs {
   177  			pb.AddNames(expression.Name(attr))
   178  		}
   179  		expr, e := expression.NewBuilder().WithProjection(pb).Build()
   180  		if e != nil {
   181  			return util.WrapErr(util.EcParamsErr, e)
   182  		}
   183  		params.ExpressionAttributeNames = expr.Names()
   184  		params.ProjectionExpression = expr.Projection()
   185  	}
   186  	res, e := _Client.GetItem(util.Ctx(), params)
   187  	if e != nil {
   188  		return util.WrapErr(util.EcDbErr, e)
   189  	}
   190  	e = attributevalue.UnmarshalMap(res.Item, entity)
   191  	return util.WrapErr(util.EcDbErr, e)
   192  }
   193  
   194  func PutNewEntity(table string, entity util.M, unique ...string) *util.Err {
   195  	item, e := attributevalue.MarshalMap(entity)
   196  	if e != nil {
   197  		return util.WrapErr(util.EcMarshallErr, e)
   198  	}
   199  	ce := ""
   200  	i := 0
   201  	c := len(unique) - 1
   202  	for _, key := range unique {
   203  		ce += "attribute_not_exists(" + key + ")"
   204  		if i < c {
   205  			ce += " AND "
   206  		}
   207  	}
   208  	params := &dynamodb.PutItemInput{
   209  		TableName:           aws.String(table),
   210  		Item:                item,
   211  		ConditionExpression: &ce,
   212  	}
   213  	_, e = _Client.PutItem(util.Ctx(), params)
   214  	return util.WrapErr(util.EcDbErr, e)
   215  }
   216  
   217  func PutOrReplaceEntity(table string, entity any) *util.Err {
   218  	item, e := attributevalue.MarshalMap(entity)
   219  	if e != nil {
   220  		return util.WrapErr(util.EcMarshallErr, e)
   221  	}
   222  	params := &dynamodb.PutItemInput{
   223  		TableName: aws.String(table),
   224  		Item:      item,
   225  	}
   226  	_, e = _Client.PutItem(util.Ctx(), params)
   227  	return util.WrapErr(util.EcDbErr, e)
   228  }
   229  
   230  type UpdateResult map[string]map[string]any
   231  
   232  func buildUpdateParams(table string, params *dynamodb.UpdateItemInput, filter, data util.M) *util.Err {
   233  	l := len(data)
   234  	if l == 0 {
   235  		return util.NewErr(util.EcParamsErr, util.M{
   236  			"error": "attrs length is zero",
   237  		})
   238  	}
   239  	mav := make(MAV)
   240  	mToAv(filter, mav)
   241  	ub := expression.UpdateBuilder{}
   242  	for attr, val := range data {
   243  		ub.Set(expression.Name(attr), expression.Value(val))
   244  	}
   245  	expr, e := expression.NewBuilder().WithUpdate(ub).Build()
   246  	if e != nil {
   247  		return util.WrapErr(util.EcParamsErr, e)
   248  	}
   249  	params.Key = mav
   250  	params.TableName = aws.String(table)
   251  	params.ExpressionAttributeNames = expr.Names()
   252  	params.ExpressionAttributeValues = expr.Values()
   253  	params.UpdateExpression = expr.Update()
   254  	return nil
   255  }
   256  
   257  func UpdateEntity(table string, filter, data util.M) *util.Err {
   258  	params := &dynamodb.UpdateItemInput{}
   259  	err := buildUpdateParams(table, params, filter, data)
   260  	if err != nil {
   261  		return err
   262  	}
   263  	_, e := _Client.UpdateItem(util.Ctx(), params)
   264  	return util.WrapErr(util.EcDbErr, e)
   265  }
   266  
   267  func UpdateEntityWithResult(table string, filter, data util.M, resultValue types.ReturnValue, result UpdateResult) *util.Err {
   268  	params := &dynamodb.UpdateItemInput{}
   269  	err := buildUpdateParams(table, params, filter, data)
   270  	if err != nil {
   271  		return err
   272  	}
   273  	params.ReturnValues = resultValue
   274  	res, e := _Client.UpdateItem(util.Ctx(), params)
   275  	if e != nil {
   276  		return util.WrapErr(util.EcDbErr, e)
   277  	}
   278  	e = attributevalue.UnmarshalMap(res.Attributes, &result)
   279  	return util.WrapErr(util.EcUnmarshallErr, e)
   280  }
   281  
   282  func getQueryEqualExpr(table string, filter util.M, params *dynamodb.QueryInput, projectAttrs []string) *util.Err {
   283  	b := expression.NewBuilder()
   284  	for attr, val := range filter {
   285  		kcb := expression.Key(attr).Equal(expression.Value(val))
   286  		b.WithKeyCondition(kcb)
   287  	}
   288  
   289  	hasProject := buildProjection(&b, projectAttrs)
   290  	expr, e := b.Build()
   291  	if e != nil {
   292  		return util.WrapErr(util.EcParamsErr, e)
   293  	}
   294  
   295  	params.ExpressionAttributeNames = expr.Names()
   296  	if hasProject {
   297  		params.ProjectionExpression = expr.Projection()
   298  	}
   299  
   300  	params.TableName = aws.String(table)
   301  	params.ExpressionAttributeNames = expr.Names()
   302  	params.ExpressionAttributeValues = expr.Values()
   303  	params.KeyConditionExpression = expr.KeyCondition()
   304  	return nil
   305  }
   306  
   307  func buildProjection(builder *expression.Builder, projectAttrs []string) bool {
   308  	if projectAttrs == nil || len(projectAttrs) == 0 {
   309  		return false
   310  	}
   311  	pb := expression.ProjectionBuilder{}
   312  	for _, attr := range projectAttrs {
   313  		pb.AddNames(expression.Name(attr))
   314  	}
   315  	builder.WithProjection(pb)
   316  	return true
   317  }
   318  
   319  func getQueryBetweenExpr(table string, params *dynamodb.QueryInput, key string, start, end any, limit int32,
   320  	forward bool, projectAttrs []string) *util.Err {
   321  	b := expression.NewBuilder()
   322  	kcb := expression.Key(key).Between(expression.Value(start), expression.Value(end))
   323  	b.WithKeyCondition(kcb)
   324  
   325  	hasProject := buildProjection(&b, projectAttrs)
   326  	expr, e := b.Build()
   327  	if e != nil {
   328  		return util.WrapErr(util.EcParamsErr, e)
   329  	}
   330  
   331  	params.ExpressionAttributeNames = expr.Names()
   332  	if hasProject {
   333  		params.ProjectionExpression = expr.Projection()
   334  	}
   335  
   336  	params.TableName = aws.String(table)
   337  	params.ExpressionAttributeNames = expr.Names()
   338  	params.ExpressionAttributeValues = expr.Values()
   339  	params.KeyConditionExpression = expr.KeyCondition()
   340  	params.ScanIndexForward = &forward
   341  	if limit > 0 {
   342  		params.Limit = &limit
   343  	}
   344  	return nil
   345  }
   346  
   347  func QueryEntities[T any](table string, filter util.M, items *[]T, projectAttrs ...string) *util.Err {
   348  	params := &dynamodb.QueryInput{}
   349  	err := getQueryEqualExpr(table, filter, params, projectAttrs)
   350  	if err != nil {
   351  		return err
   352  	}
   353  
   354  	res, e := _Client.Query(util.Ctx(), params)
   355  	if e != nil {
   356  		return util.WrapErr(util.EcDbErr, e)
   357  	}
   358  	e = attributevalue.UnmarshalListOfMaps(res.Items, items)
   359  	return util.WrapErr(util.EcUnmarshallErr, e)
   360  }
   361  
   362  func QueryEntitiesBetween[T any](table, key string, start, end any, limit int32,
   363  	forward bool, items *[]T, projectAttrs ...string) *util.Err {
   364  	params := &dynamodb.QueryInput{}
   365  	err := getQueryBetweenExpr(table, params, key, start, end, limit, forward, projectAttrs)
   366  	if err != nil {
   367  		return err
   368  	}
   369  
   370  	res, e := _Client.Query(util.Ctx(), params)
   371  	if e != nil {
   372  		return util.WrapErr(util.EcDbErr, e)
   373  	}
   374  	e = attributevalue.UnmarshalListOfMaps(res.Items, items)
   375  	return util.WrapErr(util.EcUnmarshallErr, e)
   376  }