github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/actions/lua/storage/aws/glue.go (about)

     1  package aws
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  
     8  	"github.com/Shopify/go-lua"
     9  	"github.com/aws/aws-sdk-go-v2/aws"
    10  	"github.com/aws/aws-sdk-go-v2/config"
    11  	"github.com/aws/aws-sdk-go-v2/credentials"
    12  	"github.com/aws/aws-sdk-go-v2/service/glue"
    13  	"github.com/aws/aws-sdk-go-v2/service/glue/types"
    14  	"github.com/treeverse/lakefs/pkg/actions/lua/util"
    15  )
    16  
    17  func newGlueClient(ctx context.Context) lua.Function {
    18  	return func(l *lua.State) int {
    19  		accessKeyID := lua.CheckString(l, 1)
    20  		secretAccessKey := lua.CheckString(l, 2)
    21  		var region string
    22  		if !l.IsNone(3) {
    23  			region = lua.CheckString(l, 3)
    24  		}
    25  		var endpoint string
    26  		if !l.IsNone(4) {
    27  			endpoint = lua.CheckString(l, 4)
    28  		}
    29  		c := &GlueClient{
    30  			AccessKeyID:     accessKeyID,
    31  			SecretAccessKey: secretAccessKey,
    32  			Endpoint:        endpoint,
    33  			Region:          region,
    34  			ctx:             ctx,
    35  		}
    36  
    37  		l.NewTable()
    38  		for name, goFn := range glueFunctions {
    39  			// -1: tbl
    40  			l.PushGoFunction(goFn(c))
    41  			// -1: fn, -2:tbl
    42  			l.SetField(-2, name)
    43  		}
    44  
    45  		return 1
    46  	}
    47  }
    48  
    49  type GlueClient struct {
    50  	AccessKeyID     string
    51  	SecretAccessKey string
    52  	Endpoint        string
    53  	Region          string
    54  	ctx             context.Context
    55  }
    56  
    57  var glueFunctions = map[string]func(client *GlueClient) lua.Function{
    58  	"get_table":    getTable,
    59  	"create_table": createTable,
    60  	"update_table": updateTable,
    61  	"delete_table": deleteTable,
    62  }
    63  
    64  func (c *GlueClient) client() *glue.Client {
    65  	cfg, err := config.LoadDefaultConfig(c.ctx,
    66  		config.WithRegion(c.Region),
    67  		config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, "")),
    68  	)
    69  	if err != nil {
    70  		panic(err)
    71  	}
    72  	return glue.NewFromConfig(cfg, func(o *glue.Options) {
    73  		if c.Endpoint != "" {
    74  			o.BaseEndpoint = aws.String(c.Endpoint)
    75  		}
    76  	})
    77  }
    78  
    79  func deleteTable(c *GlueClient) lua.Function {
    80  	return func(l *lua.State) int {
    81  		client := c.client()
    82  		database := lua.CheckString(l, 1)
    83  		tableName := lua.CheckString(l, 2)
    84  		// check if catalog ID provided
    85  		var catalogID *string
    86  		if !l.IsNone(3) {
    87  			catalogID = aws.String(lua.CheckString(l, 3))
    88  		}
    89  		_, err := client.DeleteTable(c.ctx, &glue.DeleteTableInput{
    90  			DatabaseName: aws.String(database),
    91  			Name:         aws.String(tableName),
    92  			CatalogId:    catalogID,
    93  		})
    94  		if err != nil {
    95  			lua.Errorf(l, "%s", err.Error())
    96  			panic("unreachable")
    97  		}
    98  
    99  		return 0
   100  	}
   101  }
   102  
   103  func updateTable(c *GlueClient) lua.Function {
   104  	return func(l *lua.State) int {
   105  		client := c.client()
   106  		database := lua.CheckString(l, 1)
   107  		tableInputJSON := lua.CheckString(l, 2)
   108  
   109  		// check if catalog ID provided
   110  		var catalogID *string
   111  		if !l.IsNone(3) {
   112  			catalogID = aws.String(lua.CheckString(l, 3))
   113  		}
   114  
   115  		// version Id optional
   116  		var versionID *string
   117  		if !l.IsNone(4) {
   118  			versionID = aws.String(lua.CheckString(l, 4))
   119  		}
   120  
   121  		// glue skip-archive optional
   122  		var skipArchive *bool
   123  		if !l.IsNone(5) {
   124  			lua.CheckType(l, 5, lua.TypeBoolean)
   125  			skipArchive = aws.Bool(l.ToBoolean(5))
   126  		}
   127  
   128  		// parse table input JSON
   129  		var tableInput types.TableInput
   130  		err := json.Unmarshal([]byte(tableInputJSON), &tableInput)
   131  		if err != nil {
   132  			lua.Errorf(l, "%s", err.Error())
   133  			panic("unreachable")
   134  		}
   135  
   136  		_, err = client.UpdateTable(c.ctx, &glue.UpdateTableInput{
   137  			DatabaseName: &database,
   138  			TableInput:   &tableInput,
   139  			CatalogId:    catalogID,
   140  			VersionId:    versionID,
   141  			SkipArchive:  skipArchive,
   142  		})
   143  
   144  		if err != nil {
   145  			lua.Errorf(l, "%s", err.Error())
   146  			panic("unreachable")
   147  		}
   148  		return 0
   149  	}
   150  }
   151  
   152  func createTable(c *GlueClient) lua.Function {
   153  	return func(l *lua.State) int {
   154  		client := c.client()
   155  		database := lua.CheckString(l, 1)
   156  		tableInputJSON := lua.CheckString(l, 2)
   157  		// parse table input JSON
   158  		var tableInput types.TableInput
   159  		err := json.Unmarshal([]byte(tableInputJSON), &tableInput)
   160  		if err != nil {
   161  			lua.Errorf(l, "%s", err.Error())
   162  			panic("unreachable")
   163  		}
   164  		// check if catalog ID provided
   165  		var catalogID *string
   166  		if !l.IsNone(3) {
   167  			catalogID = aws.String(lua.CheckString(l, 3))
   168  		}
   169  		// TODO(isan) Additional input params: partition index and iceberg table format
   170  		// AWS API call
   171  		_, err = client.CreateTable(c.ctx, &glue.CreateTableInput{
   172  			DatabaseName: aws.String(database),
   173  			TableInput:   &tableInput,
   174  			CatalogId:    catalogID,
   175  		})
   176  		if err != nil {
   177  			lua.Errorf(l, "%s", err.Error())
   178  			panic("unreachable")
   179  		}
   180  		return 0
   181  	}
   182  }
   183  
   184  func getTable(c *GlueClient) lua.Function {
   185  	return func(l *lua.State) int {
   186  		client := c.client()
   187  		database := lua.CheckString(l, 1)
   188  		table := lua.CheckString(l, 2)
   189  		var catalogID *string
   190  		if !l.IsNone(3) {
   191  			catalogID = aws.String(lua.CheckString(l, 3))
   192  		}
   193  		resp, err := client.GetTable(c.ctx, &glue.GetTableInput{
   194  			DatabaseName: aws.String(database),
   195  			Name:         aws.String(table),
   196  			CatalogId:    catalogID,
   197  		})
   198  		if err != nil {
   199  			var notFoundErr *types.EntityNotFoundException
   200  			if errors.As(err, &notFoundErr) {
   201  				l.PushString("")
   202  				l.PushBoolean(false) // exists
   203  				return 2
   204  			}
   205  			lua.Errorf(l, "%s", err.Error())
   206  			panic("unreachable")
   207  		}
   208  		// Marshal the GetTableOutput struct to JSON.
   209  		jsonBytes, err := json.Marshal(resp)
   210  		if err != nil {
   211  			lua.Errorf(l, "%s", err.Error())
   212  			panic("unreachable")
   213  		}
   214  		// Unmarshal the JSON to a map.
   215  		var itemMap map[string]interface{}
   216  		err = json.Unmarshal(jsonBytes, &itemMap)
   217  		if err != nil {
   218  			lua.Errorf(l, "%s", err.Error())
   219  			panic("unreachable")
   220  		}
   221  		util.DeepPush(l, itemMap)
   222  		l.PushBoolean(true)
   223  		return 2
   224  	}
   225  }