github.com/supabase/cli@v1.168.1/internal/db/lint/lint.go (about)

     1  package lint
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"strings"
    11  
    12  	"github.com/go-errors/errors"
    13  	"github.com/jackc/pgconn"
    14  	"github.com/jackc/pgx/v4"
    15  	"github.com/spf13/afero"
    16  	"github.com/supabase/cli/internal/db/reset"
    17  	"github.com/supabase/cli/internal/utils"
    18  )
    19  
// ENABLE_PGSQL_CHECK installs the plpgsql_check extension when it is not
// already present. LintDatabase executes it on the connection before queuing
// the per-schema check queries.
const ENABLE_PGSQL_CHECK = "CREATE EXTENSION IF NOT EXISTS plpgsql_check"

var (
	// AllowedLevels lists the recognised lint levels. A level's index in this
	// slice is its LintLevel (see toEnum), and filtering keeps issues whose
	// index is >= the requested minimum, so later entries are stricter.
	AllowedLevels = []string{
		"warning",
		"error",
	}
	// checkSchemaScript is the SQL embedded from templates/check.sql. It is
	// queued once per schema with the schema name as its only argument, and its
	// rows are scanned as (function name, JSON-encoded Result).
	//go:embed templates/check.sql
	checkSchemaScript string
)
    30  
    31  type LintLevel int
    32  
    33  func toEnum(level string) LintLevel {
    34  	for i, curr := range AllowedLevels {
    35  		if strings.HasPrefix(level, curr) {
    36  			return LintLevel(i)
    37  		}
    38  	}
    39  	return -1
    40  }
    41  
    42  func Run(ctx context.Context, schema []string, level string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    43  	// Sanity checks.
    44  	conn, err := utils.ConnectByConfig(ctx, config, options...)
    45  	if err != nil {
    46  		return err
    47  	}
    48  	defer conn.Close(context.Background())
    49  	// Run lint script
    50  	result, err := LintDatabase(ctx, conn, schema)
    51  	if err != nil {
    52  		return err
    53  	}
    54  	if len(result) == 0 {
    55  		fmt.Fprintln(os.Stderr, "\nNo schema errors found")
    56  		return nil
    57  	}
    58  	return printResultJSON(result, toEnum(level), os.Stdout)
    59  }
    60  
    61  func filterResult(result []Result, minLevel LintLevel) (filtered []Result) {
    62  	for _, r := range result {
    63  		out := Result{Function: r.Function}
    64  		for _, issue := range r.Issues {
    65  			if toEnum(issue.Level) >= minLevel {
    66  				out.Issues = append(out.Issues, issue)
    67  			}
    68  		}
    69  		if len(out.Issues) > 0 {
    70  			filtered = append(filtered, out)
    71  		}
    72  	}
    73  	return filtered
    74  }
    75  
    76  func printResultJSON(result []Result, minLevel LintLevel, stdout io.Writer) error {
    77  	filtered := filterResult(result, minLevel)
    78  	if len(filtered) == 0 {
    79  		return nil
    80  	}
    81  	// Pretty print output
    82  	enc := json.NewEncoder(stdout)
    83  	enc.SetIndent("", "  ")
    84  	if err := enc.Encode(filtered); err != nil {
    85  		return errors.Errorf("failed to print result json: %w", err)
    86  	}
    87  	return nil
    88  }
    89  
    90  func LintDatabase(ctx context.Context, conn *pgx.Conn, schema []string) ([]Result, error) {
    91  	tx, err := conn.Begin(ctx)
    92  	if err != nil {
    93  		return nil, errors.Errorf("failed to begin transaction: %w", err)
    94  	}
    95  	if len(schema) == 0 {
    96  		schema, err = reset.LoadUserSchemas(ctx, conn)
    97  		if err != nil {
    98  			return nil, err
    99  		}
   100  	}
   101  	// Always rollback since lint should not have side effects
   102  	defer func() {
   103  		if err := tx.Rollback(context.Background()); err != nil {
   104  			fmt.Fprintln(os.Stderr, err)
   105  		}
   106  	}()
   107  	if _, err := conn.Exec(ctx, ENABLE_PGSQL_CHECK); err != nil {
   108  		return nil, errors.Errorf("failed to enable pgsql_check: %w", err)
   109  	}
   110  	// Batch prepares statements
   111  	batch := pgx.Batch{}
   112  	for _, s := range schema {
   113  		batch.Queue(checkSchemaScript, s)
   114  	}
   115  	br := conn.SendBatch(ctx, &batch)
   116  	defer br.Close()
   117  	var result []Result
   118  	for _, s := range schema {
   119  		fmt.Fprintln(os.Stderr, "Linting schema:", s)
   120  		rows, err := br.Query()
   121  		if err != nil {
   122  			return nil, errors.Errorf("failed to query rows: %w", err)
   123  		}
   124  		// Parse result row
   125  		for rows.Next() {
   126  			var name string
   127  			var data []byte
   128  			if err := rows.Scan(&name, &data); err != nil {
   129  				return nil, errors.Errorf("failed to scan rows: %w", err)
   130  			}
   131  			var r Result
   132  			if err := json.Unmarshal(data, &r); err != nil {
   133  				return nil, errors.Errorf("failed to marshal json: %w", err)
   134  			}
   135  			// Update function name
   136  			r.Function = s + "." + name
   137  			result = append(result, r)
   138  		}
   139  		err = rows.Err()
   140  		if err != nil {
   141  			return nil, errors.Errorf("failed to parse rows: %w", err)
   142  		}
   143  	}
   144  	return result, nil
   145  }
   146  
// Query is the "query" object of an issue as decoded from the check script's
// JSON output.
// NOTE(review): presumably mirrors plpgsql_check's JSON report (Position being
// the offset of the problem within Text) — confirm against templates/check.sql.
type Query struct {
	Position string `json:"position"`
	Text     string `json:"text"`
}

// Statement is the "statement" object of an issue as decoded from the check
// script's JSON output: the reported line number (kept as a string) and the
// statement text.
type Statement struct {
	LineNumber string `json:"lineNumber"`
	Text       string `json:"text"`
}

// Issue is a single diagnostic reported for a function. Level is ranked with
// toEnum when filtering by minimum level; the optional fields are omitted from
// the printed JSON when they are empty or nil.
type Issue struct {
	Level     string     `json:"level"`
	Message   string     `json:"message"`
	Statement *Statement `json:"statement,omitempty"`
	Query     *Query     `json:"query,omitempty"`
	Hint      string     `json:"hint,omitempty"`
	Detail    string     `json:"detail,omitempty"`
	Context   string     `json:"context,omitempty"`
	SQLState  string     `json:"sqlState,omitempty"`
}

// Result groups the issues reported for one function. When results are loaded
// from the database, Function is overwritten with "<schema>.<function name>".
type Result struct {
	Function string  `json:"function"`
	Issues   []Issue `json:"issues"`
}