github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/db/lint/lint.go (about)

     1  package lint
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"strings"
    11  
    12  	"github.com/Redstoneguy129/cli/internal/db/diff"
    13  	"github.com/Redstoneguy129/cli/internal/utils"
    14  	"github.com/jackc/pgx/v4"
    15  	"github.com/spf13/afero"
    16  )
    17  
// ENABLE_PGSQL_CHECK installs the plpgsql_check extension if it is missing.
// LintDatabase issues it inside a transaction that is always rolled back, so
// linting itself does not leave the extension installed.
const ENABLE_PGSQL_CHECK = "CREATE EXTENSION IF NOT EXISTS plpgsql_check"

var (
	// AllowedLevels lists the recognised lint levels in increasing severity.
	// A level's index in this slice is its LintLevel rank (see toEnum), and
	// output filtering keeps issues ranked at or above the requested level.
	AllowedLevels = []string{
		"warning",
		"error",
	}
	// checkSchemaScript is the lint query embedded from templates/check.sql.
	// It is queued once per schema with the schema name as its only argument,
	// and each row it returns is scanned as (function name, JSON-encoded
	// Result).
	//go:embed templates/check.sql
	checkSchemaScript string
)
    28  
    29  type LintLevel int
    30  
    31  func toEnum(level string) LintLevel {
    32  	for i, curr := range AllowedLevels {
    33  		if strings.HasPrefix(level, curr) {
    34  			return LintLevel(i)
    35  		}
    36  	}
    37  	return -1
    38  }
    39  
    40  func Run(ctx context.Context, schema []string, level string, fsys afero.Fs, opts ...func(*pgx.ConnConfig)) error {
    41  	// Sanity checks.
    42  	if err := utils.LoadConfigFS(fsys); err != nil {
    43  		return err
    44  	}
    45  	if err := utils.AssertSupabaseDbIsRunning(); err != nil {
    46  		return err
    47  	}
    48  	// Run lint script
    49  	conn, err := utils.ConnectLocalPostgres(ctx, utils.Config.Hostname, utils.Config.Db.Port, "postgres", opts...)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	defer conn.Close(context.Background())
    54  	result, err := LintDatabase(ctx, conn, schema)
    55  	if err != nil {
    56  		return err
    57  	}
    58  	return printResultJSON(result, toEnum(level), os.Stdout)
    59  }
    60  
    61  func filterResult(result []Result, minLevel LintLevel) (filtered []Result) {
    62  	for _, r := range result {
    63  		out := Result{Function: r.Function}
    64  		for _, issue := range r.Issues {
    65  			if toEnum(issue.Level) >= minLevel {
    66  				out.Issues = append(out.Issues, issue)
    67  			}
    68  		}
    69  		if len(out.Issues) > 0 {
    70  			filtered = append(filtered, out)
    71  		}
    72  	}
    73  	return filtered
    74  }
    75  
    76  func printResultJSON(result []Result, minLevel LintLevel, stdout io.Writer) error {
    77  	filtered := filterResult(result, minLevel)
    78  	if len(filtered) == 0 {
    79  		return nil
    80  	}
    81  	// Pretty print output
    82  	enc := json.NewEncoder(stdout)
    83  	enc.SetIndent("", "  ")
    84  	return enc.Encode(filtered)
    85  }
    86  
    87  func LintDatabase(ctx context.Context, conn *pgx.Conn, schema []string) ([]Result, error) {
    88  	tx, err := conn.Begin(ctx)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	if len(schema) == 0 {
    93  		schema, err = diff.LoadUserSchemas(ctx, conn, utils.InternalSchemas...)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  	}
    98  	// Always rollback since lint should not have side effects
    99  	defer func() {
   100  		if err := tx.Rollback(context.Background()); err != nil {
   101  			fmt.Fprintln(os.Stderr, err)
   102  		}
   103  	}()
   104  	if _, err := conn.Exec(ctx, ENABLE_PGSQL_CHECK); err != nil {
   105  		return nil, err
   106  	}
   107  	// Batch prepares statements
   108  	batch := pgx.Batch{}
   109  	for _, s := range schema {
   110  		batch.Queue(checkSchemaScript, s)
   111  	}
   112  	br := conn.SendBatch(ctx, &batch)
   113  	defer br.Close()
   114  	var result []Result
   115  	for _, s := range schema {
   116  		fmt.Fprintln(os.Stderr, "Linting schema:", s)
   117  		rows, err := br.Query()
   118  		if err != nil {
   119  			return nil, err
   120  		}
   121  		// Parse result row
   122  		for rows.Next() {
   123  			var name string
   124  			var data []byte
   125  			if err := rows.Scan(&name, &data); err != nil {
   126  				return nil, err
   127  			}
   128  			var r Result
   129  			if err := json.Unmarshal(data, &r); err != nil {
   130  				return nil, err
   131  			}
   132  			// Update function name
   133  			r.Function = s + "." + name
   134  			result = append(result, r)
   135  		}
   136  		err = rows.Err()
   137  		if err != nil {
   138  			return nil, err
   139  		}
   140  	}
   141  	return result, nil
   142  }
   143  
// Query is the SQL query an Issue points at, decoded from the check script's
// JSON output. Position is kept as a string exactly as it appears in that JSON.
// NOTE(review): Position is presumably the offset of the problem within Text;
// confirm against templates/check.sql.
type Query struct {
	Position string `json:"position"`
	Text     string `json:"text"`
}
   148  
// Statement is the PL/pgSQL statement an Issue points at. LineNumber is
// decoded into a string, so the check script must emit it as a JSON string;
// a bare JSON number would make json.Unmarshal fail.
// NOTE(review): LineNumber is presumably the line within the function body;
// confirm against templates/check.sql.
type Statement struct {
	LineNumber string `json:"lineNumber"`
	Text       string `json:"text"`
}
   153  
// Issue is a single diagnostic reported for a function.
//
// Level is ranked with toEnum, which prefix-matches it against AllowedLevels.
// Levels matching no allowed level rank -1 and are therefore hidden whenever
// a real minimum level is requested. Fields tagged omitempty are left out of
// the printed JSON when empty or nil.
type Issue struct {
	Level     string     `json:"level"`
	Message   string     `json:"message"`
	Statement *Statement `json:"statement,omitempty"`
	Query     *Query     `json:"query,omitempty"`
	Hint      string     `json:"hint,omitempty"`
	Detail    string     `json:"detail,omitempty"`
	Context   string     `json:"context,omitempty"`
	SQLState  string     `json:"sqlState,omitempty"`
}
   164  
// Result groups the issues reported for one function. It is both decoded
// from each row of the check script and printed as JSON output.
// LintDatabase overwrites Function with "<schema>.<name>" after decoding, so
// any "function" value in the script's JSON is discarded.
type Result struct {
	Function string  `json:"function"`
	Issues   []Issue `json:"issues"`
}