github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/actions/lua/storage/aws/glue.go (about) 1 package aws 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 8 "github.com/Shopify/go-lua" 9 "github.com/aws/aws-sdk-go-v2/aws" 10 "github.com/aws/aws-sdk-go-v2/config" 11 "github.com/aws/aws-sdk-go-v2/credentials" 12 "github.com/aws/aws-sdk-go-v2/service/glue" 13 "github.com/aws/aws-sdk-go-v2/service/glue/types" 14 "github.com/treeverse/lakefs/pkg/actions/lua/util" 15 ) 16 17 func newGlueClient(ctx context.Context) lua.Function { 18 return func(l *lua.State) int { 19 accessKeyID := lua.CheckString(l, 1) 20 secretAccessKey := lua.CheckString(l, 2) 21 var region string 22 if !l.IsNone(3) { 23 region = lua.CheckString(l, 3) 24 } 25 var endpoint string 26 if !l.IsNone(4) { 27 endpoint = lua.CheckString(l, 4) 28 } 29 c := &GlueClient{ 30 AccessKeyID: accessKeyID, 31 SecretAccessKey: secretAccessKey, 32 Endpoint: endpoint, 33 Region: region, 34 ctx: ctx, 35 } 36 37 l.NewTable() 38 for name, goFn := range glueFunctions { 39 // -1: tbl 40 l.PushGoFunction(goFn(c)) 41 // -1: fn, -2:tbl 42 l.SetField(-2, name) 43 } 44 45 return 1 46 } 47 } 48 49 type GlueClient struct { 50 AccessKeyID string 51 SecretAccessKey string 52 Endpoint string 53 Region string 54 ctx context.Context 55 } 56 57 var glueFunctions = map[string]func(client *GlueClient) lua.Function{ 58 "get_table": getTable, 59 "create_table": createTable, 60 "update_table": updateTable, 61 "delete_table": deleteTable, 62 } 63 64 func (c *GlueClient) client() *glue.Client { 65 cfg, err := config.LoadDefaultConfig(c.ctx, 66 config.WithRegion(c.Region), 67 config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, "")), 68 ) 69 if err != nil { 70 panic(err) 71 } 72 return glue.NewFromConfig(cfg, func(o *glue.Options) { 73 if c.Endpoint != "" { 74 o.BaseEndpoint = aws.String(c.Endpoint) 75 } 76 }) 77 } 78 79 func deleteTable(c *GlueClient) lua.Function { 80 return func(l *lua.State) int { 81 client := c.client() 82 database := lua.CheckString(l, 1) 83 tableName := lua.CheckString(l, 2) 84 // check if catalog ID provided 85 var catalogID *string 86 if !l.IsNone(3) { 87 catalogID = aws.String(lua.CheckString(l, 3)) 88 } 89 _, err := client.DeleteTable(c.ctx, &glue.DeleteTableInput{ 90 DatabaseName: aws.String(database), 91 Name: aws.String(tableName), 92 CatalogId: catalogID, 93 }) 94 if err != nil { 95 lua.Errorf(l, "%s", err.Error()) 96 panic("unreachable") 97 } 98 99 return 0 100 } 101 } 102 103 func updateTable(c *GlueClient) lua.Function { 104 return func(l *lua.State) int { 105 client := c.client() 106 database := lua.CheckString(l, 1) 107 tableInputJSON := lua.CheckString(l, 2) 108 109 // check if catalog ID provided 110 var catalogID *string 111 if !l.IsNone(3) { 112 catalogID = aws.String(lua.CheckString(l, 3)) 113 } 114 115 // version Id optional 116 var versionID *string 117 if !l.IsNone(4) { 118 versionID = aws.String(lua.CheckString(l, 4)) 119 } 120 121 // glue skip-archive optional 122 var skipArchive *bool 123 if !l.IsNone(5) { 124 lua.CheckType(l, 5, lua.TypeBoolean) 125 skipArchive = aws.Bool(l.ToBoolean(5)) 126 } 127 128 // parse table input JSON 129 var tableInput types.TableInput 130 err := json.Unmarshal([]byte(tableInputJSON), &tableInput) 131 if err != nil { 132 lua.Errorf(l, "%s", err.Error()) 133 panic("unreachable") 134 } 135 136 _, err = client.UpdateTable(c.ctx, &glue.UpdateTableInput{ 137 DatabaseName: &database, 138 TableInput: &tableInput, 139 CatalogId: catalogID, 140 VersionId: versionID, 141 SkipArchive: skipArchive, 142 }) 143 144 if err != nil { 145 lua.Errorf(l, "%s", err.Error()) 146 panic("unreachable") 147 } 148 return 0 149 } 150 } 151 152 func createTable(c *GlueClient) lua.Function { 153 return func(l *lua.State) int { 154 client := c.client() 155 database := lua.CheckString(l, 1) 156 tableInputJSON := lua.CheckString(l, 2) 157 // parse table input JSON 158 var tableInput types.TableInput 159 err := json.Unmarshal([]byte(tableInputJSON), &tableInput) 160 if err != nil { 161 lua.Errorf(l, "%s", err.Error()) 162 panic("unreachable") 163 } 164 // check if catalog ID provided 165 var catalogID *string 166 if !l.IsNone(3) { 167 catalogID = aws.String(lua.CheckString(l, 3)) 168 } 169 // TODO(isan) Additional input params: partition index and iceberg table format 170 // AWS API call 171 _, err = client.CreateTable(c.ctx, &glue.CreateTableInput{ 172 DatabaseName: aws.String(database), 173 TableInput: &tableInput, 174 CatalogId: catalogID, 175 }) 176 if err != nil { 177 lua.Errorf(l, "%s", err.Error()) 178 panic("unreachable") 179 } 180 return 0 181 } 182 } 183 184 func getTable(c *GlueClient) lua.Function { 185 return func(l *lua.State) int { 186 client := c.client() 187 database := lua.CheckString(l, 1) 188 table := lua.CheckString(l, 2) 189 var catalogID *string 190 if !l.IsNone(3) { 191 catalogID = aws.String(lua.CheckString(l, 3)) 192 } 193 resp, err := client.GetTable(c.ctx, &glue.GetTableInput{ 194 DatabaseName: aws.String(database), 195 Name: aws.String(table), 196 CatalogId: catalogID, 197 }) 198 if err != nil { 199 var notFoundErr *types.EntityNotFoundException 200 if errors.As(err, ¬FoundErr) { 201 l.PushString("") 202 l.PushBoolean(false) // exists 203 return 2 204 } 205 lua.Errorf(l, "%s", err.Error()) 206 panic("unreachable") 207 } 208 // Marshal the GetTableOutput struct to JSON. 209 jsonBytes, err := json.Marshal(resp) 210 if err != nil { 211 lua.Errorf(l, "%s", err.Error()) 212 panic("unreachable") 213 } 214 // Unmarshal the JSON to a map. 215 var itemMap map[string]interface{} 216 err = json.Unmarshal(jsonBytes, &itemMap) 217 if err != nil { 218 lua.Errorf(l, "%s", err.Error()) 219 panic("unreachable") 220 } 221 util.DeepPush(l, itemMap) 222 l.PushBoolean(true) 223 return 2 224 } 225 }