github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/actions/lua/databricks/client.go (about)

     1  package databricks
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/url"
     8  	"regexp"
     9  	"strings"
    10  
    11  	"github.com/Shopify/go-lua"
    12  	"github.com/databricks/databricks-sdk-go"
    13  	"github.com/databricks/databricks-sdk-go/config"
    14  	"github.com/databricks/databricks-sdk-go/service/catalog"
    15  	"github.com/databricks/databricks-sdk-go/service/sql"
    16  	luautil "github.com/treeverse/lakefs/pkg/actions/lua/util"
    17  )
    18  
    19  // identifierRegex https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html
    20  var identifierRegex = regexp.MustCompile(`\W`)
    21  
    22  var ErrInvalidTableName = errors.New("invalid table name")
    23  
    24  type Client struct {
    25  	workspaceClient *databricks.WorkspaceClient
    26  	ctx             context.Context
    27  }
    28  
    29  // validateTableName
    30  // https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html
    31  // https://docs.databricks.com/en/sql/language-manual/sql-ref-names.html
    32  func validateTableName(tableName string) error {
    33  	if identifierRegex.MatchString(tableName) {
    34  		return ErrInvalidTableName
    35  	}
    36  	if len(tableName) > 255 {
    37  		return ErrInvalidTableName
    38  	}
    39  	return nil
    40  }
    41  
    42  func validateTableLocation(tableLocation string) error {
    43  	_, err := url.ParseRequestURI(tableLocation)
    44  	return err
    45  }
    46  
    47  func validateTableInput(tableName, location string) error {
    48  	errName := validateTableName(tableName)
    49  	errLocation := validateTableLocation(location)
    50  	return errors.Join(errName, errLocation)
    51  }
    52  
    53  func (client *Client) createExternalTable(warehouseID, catalogName, schemaName, tableName, location string, metadata map[string]any) (string, error) {
    54  	if err := validateTableInput(tableName, location); err != nil {
    55  		return "", fmt.Errorf("external table \"%s\" creation failed: %w", tableName, err)
    56  	}
    57  	statement := fmt.Sprintf(`CREATE EXTERNAL TABLE %s LOCATION '%s'`, tableName, location)
    58  	if metadata != nil && metadata["description"] != "" {
    59  		if ms, ok := metadata["description"].(string); ok {
    60  			statement = fmt.Sprintf(`%s COMMENT '%s'`, statement, ms)
    61  		}
    62  	}
    63  	esr, err := client.workspaceClient.StatementExecution.ExecuteAndWait(client.ctx, sql.ExecuteStatementRequest{
    64  		WarehouseId: warehouseID,
    65  		Catalog:     catalogName,
    66  		Schema:      schemaName,
    67  		Statement:   statement,
    68  	})
    69  	if err != nil {
    70  		return "", fmt.Errorf("external table \"%s\" creation failed: %w", tableName, err)
    71  	}
    72  	return esr.Status.State.String(), nil
    73  }
    74  
    75  func tableFullName(catalogName, schemaName, tableName string) string {
    76  	return fmt.Sprintf("%s.%s.%s", catalogName, schemaName, tableName)
    77  }
    78  
    79  func (client *Client) deleteTable(catalogName, schemaName, tableName string) error {
    80  	err := client.workspaceClient.Tables.DeleteByFullName(client.ctx, tableFullName(catalogName, schemaName, tableName))
    81  	if err != nil {
    82  		return fmt.Errorf("failed deleting an existing table \"%s\": %w", tableName, err)
    83  	}
    84  	return nil
    85  }
    86  
    87  func (client *Client) createSchema(catalogName, schemaName string, getIfExists bool) (*catalog.SchemaInfo, error) {
    88  	schemaInfo, err := client.workspaceClient.Schemas.Create(client.ctx, catalog.CreateSchema{
    89  		Name:        schemaName,
    90  		CatalogName: catalogName,
    91  	})
    92  
    93  	if err == nil {
    94  		return schemaInfo, nil
    95  	}
    96  	if getIfExists && alreadyExists(err) {
    97  		// Full name of schema, in form of <catalog_name>.<schema_name>
    98  		schemaInfo, err = client.workspaceClient.Schemas.GetByFullName(client.ctx, catalogName+"."+schemaName)
    99  		if err == nil {
   100  			return schemaInfo, nil
   101  		}
   102  		return nil, fmt.Errorf("failed getting schema \"%s\": %w", schemaName, err)
   103  	}
   104  	return nil, fmt.Errorf("failed creating schema \"%s\": %w", schemaName, err)
   105  }
   106  
   107  func newDatabricksClient(l *lua.State) (*databricks.WorkspaceClient, error) {
   108  	host := lua.CheckString(l, 1)
   109  	token := lua.CheckString(l, 2)
   110  	return databricks.NewWorkspaceClient(
   111  		&databricks.Config{
   112  			Host:        host,
   113  			Token:       token,
   114  			Credentials: config.PatCredentials{},
   115  		},
   116  	)
   117  }
   118  
   119  func (client *Client) RegisterExternalTable(l *lua.State) int {
   120  	tableName := lua.CheckString(l, 1)
   121  	tableName = strings.ReplaceAll(tableName, "-", "_")
   122  	location := lua.CheckString(l, 2)
   123  	warehouseID := lua.CheckString(l, 3)
   124  	catalogName := lua.CheckString(l, 4)
   125  	schemaName := lua.CheckString(l, 5)
   126  	metadata, _ := luautil.PullTable(l, 6)
   127  	var metadataMap map[string]any
   128  	if metadata != nil {
   129  		metadataMap = metadata.(map[string]any)
   130  	}
   131  	status, err := client.createExternalTable(warehouseID, catalogName, schemaName, tableName, location, metadataMap)
   132  	if err != nil {
   133  		if alreadyExists(err) {
   134  			err = client.deleteTable(catalogName, schemaName, tableName)
   135  			if err != nil {
   136  				lua.Errorf(l, "%s", err.Error())
   137  				panic("unreachable")
   138  			}
   139  			status, err = client.createExternalTable(warehouseID, catalogName, schemaName, tableName, location, metadataMap)
   140  			if err != nil {
   141  				lua.Errorf(l, "%s", err.Error())
   142  				panic("unreachable")
   143  			}
   144  		} else {
   145  			lua.Errorf(l, "%s", err.Error())
   146  			panic("unreachable")
   147  		}
   148  	}
   149  	l.PushString(status)
   150  	return 1
   151  }
   152  
   153  func (client *Client) CreateSchema(l *lua.State) int {
   154  	ref := lua.CheckString(l, 1)
   155  	catalogName := lua.CheckString(l, 2)
   156  	getIfExists := l.ToBoolean(3)
   157  	schemaInfo, err := client.createSchema(catalogName, ref, getIfExists)
   158  	if err != nil {
   159  		lua.Errorf(l, "%s", err.Error())
   160  		panic("unreachable")
   161  	}
   162  	l.PushString(schemaInfo.Name)
   163  	return 1
   164  }
   165  
   166  func alreadyExists(e error) bool {
   167  	return strings.Contains(e.Error(), "already exists")
   168  }
   169  
   170  func newClient(ctx context.Context) lua.Function {
   171  	return func(l *lua.State) int {
   172  		workspaceClient, err := newDatabricksClient(l)
   173  		if err != nil {
   174  			lua.Errorf(l, "%s", err.Error())
   175  			panic("unreachable")
   176  		}
   177  		client := &Client{workspaceClient: workspaceClient, ctx: ctx}
   178  		l.NewTable()
   179  		functions := map[string]lua.Function{
   180  			"create_schema":           client.CreateSchema,
   181  			"register_external_table": client.RegisterExternalTable,
   182  		}
   183  		for name, goFn := range functions {
   184  			l.PushGoFunction(goFn)
   185  			l.SetField(-2, name)
   186  		}
   187  		return 1
   188  	}
   189  }