github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/actions/lua/storage/aws/s3.go (about)

     1  package aws
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"sort"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/Shopify/go-lua"
    13  	"github.com/aws/aws-sdk-go-v2/aws"
    14  	"github.com/aws/aws-sdk-go-v2/config"
    15  	"github.com/aws/aws-sdk-go-v2/credentials"
    16  	"github.com/aws/aws-sdk-go-v2/service/s3"
    17  	"github.com/aws/aws-sdk-go-v2/service/s3/types"
    18  	"github.com/treeverse/lakefs/pkg/actions/lua/util"
    19  )
    20  
    21  var errDeleteObject = errors.New("delete object failed")
    22  
    23  func newS3Client(ctx context.Context) lua.Function {
    24  	return func(l *lua.State) int {
    25  		accessKeyID := lua.CheckString(l, 1)
    26  		secretAccessKey := lua.CheckString(l, 2)
    27  		var region string
    28  		if !l.IsNone(3) {
    29  			region = lua.CheckString(l, 3)
    30  		}
    31  		var endpoint string
    32  		if !l.IsNone(4) {
    33  			endpoint = lua.CheckString(l, 4)
    34  		}
    35  
    36  		c := &S3Client{
    37  			AccessKeyID:     accessKeyID,
    38  			SecretAccessKey: secretAccessKey,
    39  			Endpoint:        endpoint,
    40  			Region:          region,
    41  			ctx:             ctx,
    42  		}
    43  		l.NewTable()
    44  		functions := map[string]lua.Function{
    45  			"get_object":       c.GetObject,
    46  			"put_object":       c.PutObject,
    47  			"list_objects":     c.ListObjects,
    48  			"delete_object":    c.DeleteObject,
    49  			"delete_recursive": c.DeleteRecursive,
    50  		}
    51  		for name, goFn := range functions {
    52  			l.PushGoFunction(goFn)
    53  			l.SetField(-2, name)
    54  		}
    55  
    56  		return 1
    57  	}
    58  }
    59  
    60  type S3Client struct {
    61  	AccessKeyID     string
    62  	SecretAccessKey string
    63  	Endpoint        string
    64  	Region          string
    65  	ctx             context.Context
    66  }
    67  
    68  func (c *S3Client) client() *s3.Client {
    69  	cfg, err := config.LoadDefaultConfig(c.ctx,
    70  		config.WithRegion(c.Region),
    71  		config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, "")),
    72  	)
    73  	if err != nil {
    74  		panic(err)
    75  	}
    76  	return s3.NewFromConfig(cfg, func(o *s3.Options) {
    77  		if c.Endpoint != "" {
    78  			o.BaseEndpoint = aws.String(c.Endpoint)
    79  		}
    80  	})
    81  }
    82  
    83  func (c *S3Client) DeleteRecursive(l *lua.State) int {
    84  	bucketName := lua.CheckString(l, 1)
    85  	prefix := lua.CheckString(l, 2)
    86  
    87  	client := c.client()
    88  	input := &s3.ListObjectsV2Input{
    89  		Bucket: aws.String(bucketName),
    90  		Prefix: aws.String(prefix),
    91  	}
    92  
    93  	var errs error
    94  	for {
    95  		// list objects to delete and delete them
    96  		listObjects, err := client.ListObjectsV2(c.ctx, input)
    97  		if err != nil {
    98  			lua.Errorf(l, "%s", err.Error())
    99  			panic("unreachable")
   100  		}
   101  
   102  		deleteInput := &s3.DeleteObjectsInput{
   103  			Bucket: &bucketName,
   104  			Delete: &types.Delete{},
   105  		}
   106  		for _, content := range listObjects.Contents {
   107  			deleteInput.Delete.Objects = append(deleteInput.Delete.Objects, types.ObjectIdentifier{Key: content.Key})
   108  		}
   109  		deleteObjects, err := client.DeleteObjects(c.ctx, deleteInput)
   110  		if err != nil {
   111  			errs = errors.Join(errs, err)
   112  			break
   113  		}
   114  		for _, deleteError := range deleteObjects.Errors {
   115  			errDel := fmt.Errorf("%w '%s', %s",
   116  				errDeleteObject, aws.ToString(deleteError.Key), aws.ToString(deleteError.Message))
   117  			errs = errors.Join(errs, errDel)
   118  		}
   119  
   120  		if !aws.ToBool(listObjects.IsTruncated) {
   121  			break
   122  		}
   123  		input.ContinuationToken = listObjects.NextContinuationToken
   124  	}
   125  	if errs != nil {
   126  		lua.Errorf(l, "%s", errs.Error())
   127  		panic("unreachable")
   128  	}
   129  	return 0
   130  }
   131  
   132  func (c *S3Client) GetObject(l *lua.State) int {
   133  	client := c.client()
   134  	key := lua.CheckString(l, 2)
   135  	bucket := lua.CheckString(l, 1)
   136  	resp, err := client.GetObject(c.ctx, &s3.GetObjectInput{
   137  		Bucket: aws.String(bucket),
   138  		Key:    aws.String(key),
   139  	})
   140  	if err != nil {
   141  		var (
   142  			noSuchBucket *types.NoSuchBucket
   143  			noSuchKey    *types.NoSuchKey
   144  		)
   145  		if errors.As(err, &noSuchBucket) || errors.As(err, &noSuchKey) {
   146  			l.PushString("")
   147  			l.PushBoolean(false) // exists
   148  			return 2
   149  		}
   150  		lua.Errorf(l, "%s", err.Error())
   151  		panic("unreachable")
   152  	}
   153  	defer resp.Body.Close()
   154  	data, err := io.ReadAll(resp.Body)
   155  	if err != nil {
   156  		lua.Errorf(l, "%s", err.Error())
   157  		panic("unreachable")
   158  	}
   159  	l.PushString(string(data))
   160  	l.PushBoolean(true) // exists
   161  	return 2
   162  }
   163  
   164  func (c *S3Client) PutObject(l *lua.State) int {
   165  	client := c.client()
   166  	buf := strings.NewReader(lua.CheckString(l, 3))
   167  	_, err := client.PutObject(c.ctx, &s3.PutObjectInput{
   168  		Body:   buf,
   169  		Bucket: aws.String(lua.CheckString(l, 1)),
   170  		Key:    aws.String(lua.CheckString(l, 2)),
   171  	})
   172  	if err != nil {
   173  		lua.Errorf(l, "%s", err.Error())
   174  		panic("unreachable")
   175  	}
   176  	return 0
   177  }
   178  
   179  func (c *S3Client) DeleteObject(l *lua.State) int {
   180  	client := c.client()
   181  	_, err := client.DeleteObject(c.ctx, &s3.DeleteObjectInput{
   182  		Bucket: aws.String(lua.CheckString(l, 1)),
   183  		Key:    aws.String(lua.CheckString(l, 2)),
   184  	})
   185  	if err != nil {
   186  		lua.Errorf(l, "%s", err.Error())
   187  		panic("unreachable")
   188  	}
   189  	return 0
   190  }
   191  
   192  func (c *S3Client) ListObjects(l *lua.State) int {
   193  	client := c.client()
   194  
   195  	var prefix, delimiter, continuationToken *string
   196  	if !l.IsNone(2) {
   197  		prefix = aws.String(lua.CheckString(l, 2))
   198  	}
   199  	if !l.IsNone(3) {
   200  		continuationToken = aws.String(lua.CheckString(l, 3))
   201  	}
   202  	if !l.IsNone(4) {
   203  		delimiter = aws.String(lua.CheckString(l, 4))
   204  	} else {
   205  		delimiter = aws.String("/")
   206  	}
   207  
   208  	resp, err := client.ListObjectsV2(c.ctx, &s3.ListObjectsV2Input{
   209  		Bucket:            aws.String(lua.CheckString(l, 1)),
   210  		ContinuationToken: continuationToken,
   211  		Delimiter:         delimiter,
   212  		Prefix:            prefix,
   213  	})
   214  	if err != nil {
   215  		lua.Errorf(l, "%s", err.Error())
   216  		panic("unreachable")
   217  	}
   218  	results := make([]map[string]interface{}, 0)
   219  	for _, prefix := range resp.CommonPrefixes {
   220  		results = append(results, map[string]interface{}{
   221  			"key":  *prefix.Prefix,
   222  			"type": "prefix",
   223  		})
   224  	}
   225  	for _, obj := range resp.Contents {
   226  		results = append(results, map[string]interface{}{
   227  			"key":           *obj.Key,
   228  			"type":          "object",
   229  			"etag":          *obj.ETag,
   230  			"size":          obj.Size,
   231  			"last_modified": obj.LastModified.Format(time.RFC3339),
   232  		})
   233  	}
   234  
   235  	// sort it
   236  	sort.Slice(results, func(i, j int) bool {
   237  		return results[i]["key"].(string) > results[j]["key"].(string)
   238  	})
   239  
   240  	response := map[string]interface{}{
   241  		"is_truncated":            resp.IsTruncated,
   242  		"next_continuation_token": aws.ToString(resp.NextContinuationToken),
   243  		"results":                 results,
   244  	}
   245  
   246  	return util.DeepPush(l, response)
   247  }