github.com/openfga/openfga@v1.5.4-rc1/pkg/storage/sqlcommon/sqlcommon.go (about)

     1  package sqlcommon
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"strings"
    10  	"time"
    11  
    12  	sq "github.com/Masterminds/squirrel"
    13  	"github.com/go-sql-driver/mysql"
    14  	"github.com/oklog/ulid/v2"
    15  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
    16  	"github.com/pressly/goose/v3"
    17  	"google.golang.org/protobuf/proto"
    18  	"google.golang.org/protobuf/types/known/structpb"
    19  
    20  	"github.com/openfga/openfga/internal/build"
    21  	"github.com/openfga/openfga/pkg/logger"
    22  	"github.com/openfga/openfga/pkg/storage"
    23  	tupleUtils "github.com/openfga/openfga/pkg/tuple"
    24  )
    25  
    26  // Config defines the configuration parameters
    27  // for setting up and managing a sql connection.
    28  type Config struct {
    29  	Username               string
    30  	Password               string
    31  	Logger                 logger.Logger
    32  	MaxTuplesPerWriteField int
    33  	MaxTypesPerModelField  int
    34  
    35  	MaxOpenConns    int
    36  	MaxIdleConns    int
    37  	ConnMaxIdleTime time.Duration
    38  	ConnMaxLifetime time.Duration
    39  
    40  	ExportMetrics bool
    41  }
    42  
    43  // DatastoreOption defines a function type
    44  // used for configuring a Config object.
    45  type DatastoreOption func(*Config)
    46  
    47  // WithUsername returns a DatastoreOption that sets the username in the Config.
    48  func WithUsername(username string) DatastoreOption {
    49  	return func(config *Config) {
    50  		config.Username = username
    51  	}
    52  }
    53  
    54  // WithPassword returns a DatastoreOption that sets the password in the Config.
    55  func WithPassword(password string) DatastoreOption {
    56  	return func(config *Config) {
    57  		config.Password = password
    58  	}
    59  }
    60  
    61  // WithLogger returns a DatastoreOption that sets the Logger in the Config.
    62  func WithLogger(l logger.Logger) DatastoreOption {
    63  	return func(cfg *Config) {
    64  		cfg.Logger = l
    65  	}
    66  }
    67  
    68  // WithMaxTuplesPerWrite returns a DatastoreOption that sets
    69  // the maximum number of tuples per write in the Config.
    70  func WithMaxTuplesPerWrite(maxTuples int) DatastoreOption {
    71  	return func(cfg *Config) {
    72  		cfg.MaxTuplesPerWriteField = maxTuples
    73  	}
    74  }
    75  
    76  // WithMaxTypesPerAuthorizationModel returns a DatastoreOption that sets
    77  // the maximum number of types per authorization model in the Config.
    78  func WithMaxTypesPerAuthorizationModel(maxTypes int) DatastoreOption {
    79  	return func(cfg *Config) {
    80  		cfg.MaxTypesPerModelField = maxTypes
    81  	}
    82  }
    83  
    84  // WithMaxOpenConns returns a DatastoreOption that sets the
    85  // maximum number of open connections in the Config.
    86  func WithMaxOpenConns(c int) DatastoreOption {
    87  	return func(cfg *Config) {
    88  		cfg.MaxOpenConns = c
    89  	}
    90  }
    91  
    92  // WithMaxIdleConns returns a DatastoreOption that sets the
    93  // maximum number of idle connections in the Config.
    94  func WithMaxIdleConns(c int) DatastoreOption {
    95  	return func(cfg *Config) {
    96  		cfg.MaxIdleConns = c
    97  	}
    98  }
    99  
   100  // WithConnMaxIdleTime returns a DatastoreOption that sets
   101  // the maximum idle time for a connection in the Config.
   102  func WithConnMaxIdleTime(d time.Duration) DatastoreOption {
   103  	return func(cfg *Config) {
   104  		cfg.ConnMaxIdleTime = d
   105  	}
   106  }
   107  
   108  // WithConnMaxLifetime returns a DatastoreOption that sets
   109  // the maximum lifetime for a connection in the Config.
   110  func WithConnMaxLifetime(d time.Duration) DatastoreOption {
   111  	return func(cfg *Config) {
   112  		cfg.ConnMaxLifetime = d
   113  	}
   114  }
   115  
   116  // WithMetrics returns a DatastoreOption that
   117  // enables the export of metrics in the Config.
   118  func WithMetrics() DatastoreOption {
   119  	return func(cfg *Config) {
   120  		cfg.ExportMetrics = true
   121  	}
   122  }
   123  
   124  // NewConfig creates a new Config instance with default values
   125  // and applies any provided DatastoreOption modifications.
   126  func NewConfig(opts ...DatastoreOption) *Config {
   127  	cfg := &Config{}
   128  
   129  	for _, opt := range opts {
   130  		opt(cfg)
   131  	}
   132  
   133  	if cfg.Logger == nil {
   134  		cfg.Logger = logger.NewNoopLogger()
   135  	}
   136  
   137  	if cfg.MaxTuplesPerWriteField == 0 {
   138  		cfg.MaxTuplesPerWriteField = storage.DefaultMaxTuplesPerWrite
   139  	}
   140  
   141  	if cfg.MaxTypesPerModelField == 0 {
   142  		cfg.MaxTypesPerModelField = storage.DefaultMaxTypesPerAuthorizationModel
   143  	}
   144  
   145  	return cfg
   146  }
   147  
   148  // ContToken represents a continuation token structure used in pagination.
   149  type ContToken struct {
   150  	Ulid       string `json:"ulid"`
   151  	ObjectType string `json:"ObjectType"`
   152  }
   153  
   154  // NewContToken creates a new instance of ContToken
   155  // with the provided ULID and object type.
   156  func NewContToken(ulid, objectType string) *ContToken {
   157  	return &ContToken{
   158  		Ulid:       ulid,
   159  		ObjectType: objectType,
   160  	}
   161  }
   162  
   163  // UnmarshallContToken takes a string representation of a continuation
   164  // token and attempts to unmarshal it into a ContToken struct.
   165  func UnmarshallContToken(from string) (*ContToken, error) {
   166  	var token ContToken
   167  	if err := json.Unmarshal([]byte(from), &token); err != nil {
   168  		return nil, storage.ErrInvalidContinuationToken
   169  	}
   170  	return &token, nil
   171  }
   172  
   173  // SQLTupleIterator is a struct that implements the storage.TupleIterator
   174  // interface for iterating over tuples fetched from a SQL database.
   175  type SQLTupleIterator struct {
   176  	rows     *sql.Rows
   177  	resultCh chan *storage.TupleRecord
   178  	errCh    chan error
   179  }
   180  
   181  // Ensures that SQLTupleIterator implements the TupleIterator interface.
   182  var _ storage.TupleIterator = (*SQLTupleIterator)(nil)
   183  
   184  // NewSQLTupleIterator returns a SQL tuple iterator.
   185  func NewSQLTupleIterator(rows *sql.Rows) *SQLTupleIterator {
   186  	return &SQLTupleIterator{
   187  		rows:     rows,
   188  		resultCh: make(chan *storage.TupleRecord, 1),
   189  		errCh:    make(chan error, 1),
   190  	}
   191  }
   192  
   193  func (t *SQLTupleIterator) next() (*storage.TupleRecord, error) {
   194  	if !t.rows.Next() {
   195  		if err := t.rows.Err(); err != nil {
   196  			return nil, err
   197  		}
   198  		return nil, storage.ErrIteratorDone
   199  	}
   200  
   201  	var conditionName sql.NullString
   202  	var conditionContext []byte
   203  	var record storage.TupleRecord
   204  	err := t.rows.Scan(
   205  		&record.Store,
   206  		&record.ObjectType,
   207  		&record.ObjectID,
   208  		&record.Relation,
   209  		&record.User,
   210  		&conditionName,
   211  		&conditionContext,
   212  		&record.Ulid,
   213  		&record.InsertedAt,
   214  	)
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  
   219  	record.ConditionName = conditionName.String
   220  
   221  	if conditionContext != nil {
   222  		var conditionContextStruct structpb.Struct
   223  		if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
   224  			return nil, err
   225  		}
   226  		record.ConditionContext = &conditionContextStruct
   227  	}
   228  
   229  	return &record, nil
   230  }
   231  
   232  // ToArray converts the tupleIterator to an []*openfgav1.Tuple and a possibly empty continuation token.
   233  // If the continuation token exists it is the ulid of the last element of the returned array.
   234  func (t *SQLTupleIterator) ToArray(
   235  	opts storage.PaginationOptions,
   236  ) ([]*openfgav1.Tuple, []byte, error) {
   237  	var res []*openfgav1.Tuple
   238  	for i := 0; i < opts.PageSize; i++ {
   239  		tupleRecord, err := t.next()
   240  		if err != nil {
   241  			if err == storage.ErrIteratorDone {
   242  				return res, nil, nil
   243  			}
   244  			return nil, nil, err
   245  		}
   246  		res = append(res, tupleRecord.AsTuple())
   247  	}
   248  
   249  	// Check if we are at the end of the iterator.
   250  	// If we are then we do not need to return a continuation token.
   251  	// This is why we have LIMIT+1 in the query.
   252  	tupleRecord, err := t.next()
   253  	if err != nil {
   254  		if errors.Is(err, storage.ErrIteratorDone) {
   255  			return res, nil, nil
   256  		}
   257  		return nil, nil, err
   258  	}
   259  
   260  	contToken, err := json.Marshal(NewContToken(tupleRecord.Ulid, ""))
   261  	if err != nil {
   262  		return nil, nil, err
   263  	}
   264  
   265  	return res, contToken, nil
   266  }
   267  
   268  // Next will return the next available item.
   269  func (t *SQLTupleIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
   270  	if ctx.Err() != nil {
   271  		return nil, ctx.Err()
   272  	}
   273  
   274  	record, err := t.next()
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	return record.AsTuple(), nil
   280  }
   281  
   282  // Stop terminates iteration.
   283  func (t *SQLTupleIterator) Stop() {
   284  	t.rows.Close()
   285  }
   286  
   287  // HandleSQLError processes an SQL error and converts it into a more
   288  // specific error type based on the nature of the SQL error.
   289  func HandleSQLError(err error, args ...interface{}) error {
   290  	if errors.Is(err, sql.ErrNoRows) {
   291  		return storage.ErrNotFound
   292  	} else if errors.Is(err, storage.ErrIteratorDone) {
   293  		return err
   294  	} else if strings.Contains(err.Error(), "duplicate key value") { // Postgres.
   295  		if len(args) > 0 {
   296  			if tk, ok := args[0].(*openfgav1.TupleKey); ok {
   297  				return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
   298  			}
   299  		}
   300  		return storage.ErrCollision
   301  	} else if me, ok := err.(*mysql.MySQLError); ok && me.Number == 1062 {
   302  		if len(args) > 0 {
   303  			if tk, ok := args[0].(*openfgav1.TupleKey); ok {
   304  				return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
   305  			}
   306  		}
   307  		return storage.ErrCollision
   308  	}
   309  
   310  	return fmt.Errorf("sql error: %w", err)
   311  }
   312  
   313  // DBInfo encapsulates DB information for use in common method.
   314  type DBInfo struct {
   315  	db      *sql.DB
   316  	stbl    sq.StatementBuilderType
   317  	sqlTime interface{}
   318  }
   319  
   320  // NewDBInfo constructs a [DBInfo] object.
   321  func NewDBInfo(db *sql.DB, stbl sq.StatementBuilderType, sqlTime interface{}) *DBInfo {
   322  	return &DBInfo{
   323  		db:      db,
   324  		stbl:    stbl,
   325  		sqlTime: sqlTime,
   326  	}
   327  }
   328  
   329  // Write provides the common method for writing to database across sql storage.
   330  func Write(
   331  	ctx context.Context,
   332  	dbInfo *DBInfo,
   333  	store string,
   334  	deletes storage.Deletes,
   335  	writes storage.Writes,
   336  	now time.Time,
   337  ) error {
   338  	txn, err := dbInfo.db.BeginTx(ctx, nil)
   339  	if err != nil {
   340  		return HandleSQLError(err)
   341  	}
   342  	defer func() {
   343  		_ = txn.Rollback()
   344  	}()
   345  
   346  	changelogBuilder := dbInfo.stbl.
   347  		Insert("changelog").
   348  		Columns(
   349  			"store", "object_type", "object_id", "relation", "_user",
   350  			"condition_name", "condition_context", "operation", "ulid", "inserted_at",
   351  		)
   352  
   353  	deleteBuilder := dbInfo.stbl.Delete("tuple")
   354  
   355  	for _, tk := range deletes {
   356  		id := ulid.MustNew(ulid.Timestamp(now), ulid.DefaultEntropy()).String()
   357  		objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
   358  
   359  		res, err := deleteBuilder.
   360  			Where(sq.Eq{
   361  				"store":       store,
   362  				"object_type": objectType,
   363  				"object_id":   objectID,
   364  				"relation":    tk.GetRelation(),
   365  				"_user":       tk.GetUser(),
   366  				"user_type":   tupleUtils.GetUserTypeFromUser(tk.GetUser()),
   367  			}).
   368  			RunWith(txn). // Part of a txn.
   369  			ExecContext(ctx)
   370  		if err != nil {
   371  			return HandleSQLError(err, tk)
   372  		}
   373  
   374  		rowsAffected, err := res.RowsAffected()
   375  		if err != nil {
   376  			return HandleSQLError(err)
   377  		}
   378  
   379  		if rowsAffected != 1 {
   380  			return storage.InvalidWriteInputError(
   381  				tk,
   382  				openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
   383  			)
   384  		}
   385  
   386  		changelogBuilder = changelogBuilder.Values(
   387  			store, objectType, objectID,
   388  			tk.GetRelation(), tk.GetUser(),
   389  			"", nil, // Redact condition info for deletes since we only need the base triplet (object, relation, user).
   390  			openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
   391  			id, dbInfo.sqlTime,
   392  		)
   393  	}
   394  
   395  	insertBuilder := dbInfo.stbl.
   396  		Insert("tuple").
   397  		Columns(
   398  			"store", "object_type", "object_id", "relation", "_user", "user_type",
   399  			"condition_name", "condition_context", "ulid", "inserted_at",
   400  		)
   401  
   402  	for _, tk := range writes {
   403  		id := ulid.MustNew(ulid.Timestamp(now), ulid.DefaultEntropy()).String()
   404  		objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
   405  
   406  		conditionName, conditionContext, err := marshalRelationshipCondition(tk.GetCondition())
   407  		if err != nil {
   408  			return err
   409  		}
   410  
   411  		_, err = insertBuilder.
   412  			Values(
   413  				store,
   414  				objectType,
   415  				objectID,
   416  				tk.GetRelation(),
   417  				tk.GetUser(),
   418  				tupleUtils.GetUserTypeFromUser(tk.GetUser()),
   419  				conditionName,
   420  				conditionContext,
   421  				id,
   422  				dbInfo.sqlTime,
   423  			).
   424  			RunWith(txn). // Part of a txn.
   425  			ExecContext(ctx)
   426  		if err != nil {
   427  			return HandleSQLError(err, tk)
   428  		}
   429  
   430  		changelogBuilder = changelogBuilder.Values(
   431  			store,
   432  			objectType,
   433  			objectID,
   434  			tk.GetRelation(),
   435  			tk.GetUser(),
   436  			conditionName,
   437  			conditionContext,
   438  			openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
   439  			id,
   440  			dbInfo.sqlTime,
   441  		)
   442  	}
   443  
   444  	if len(writes) > 0 || len(deletes) > 0 {
   445  		_, err := changelogBuilder.RunWith(txn).ExecContext(ctx) // Part of a txn.
   446  		if err != nil {
   447  			return HandleSQLError(err)
   448  		}
   449  	}
   450  
   451  	if err := txn.Commit(); err != nil {
   452  		return HandleSQLError(err)
   453  	}
   454  
   455  	return nil
   456  }
   457  
   458  // WriteAuthorizationModel writes an authorization model for the given store.
   459  func WriteAuthorizationModel(
   460  	ctx context.Context,
   461  	dbInfo *DBInfo,
   462  	store string,
   463  	model *openfgav1.AuthorizationModel,
   464  ) error {
   465  	schemaVersion := model.GetSchemaVersion()
   466  	typeDefinitions := model.GetTypeDefinitions()
   467  
   468  	if len(typeDefinitions) < 1 {
   469  		return nil
   470  	}
   471  
   472  	pbdata, err := proto.Marshal(model)
   473  	if err != nil {
   474  		return err
   475  	}
   476  
   477  	_, err = dbInfo.stbl.
   478  		Insert("authorization_model").
   479  		Columns("store", "authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
   480  		Values(store, model.GetId(), schemaVersion, "", nil, pbdata).
   481  		ExecContext(ctx)
   482  	if err != nil {
   483  		return HandleSQLError(err)
   484  	}
   485  
   486  	return nil
   487  }
   488  
   489  func constructAuthorizationModelFromSQLRows(rows *sql.Rows) (*openfgav1.AuthorizationModel, error) {
   490  	var modelID string
   491  	var schemaVersion string
   492  	var typeDefs []*openfgav1.TypeDefinition
   493  	for rows.Next() {
   494  		var typeName string
   495  		var marshalledTypeDef []byte
   496  		var marshalledModel []byte
   497  		err := rows.Scan(&modelID, &schemaVersion, &typeName, &marshalledTypeDef, &marshalledModel)
   498  		if err != nil {
   499  			return nil, HandleSQLError(err)
   500  		}
   501  
   502  		if len(marshalledModel) > 0 {
   503  			// Prefer building an authorization model from the first row that has it available.
   504  			var model openfgav1.AuthorizationModel
   505  			if err := proto.Unmarshal(marshalledModel, &model); err != nil {
   506  				return nil, err
   507  			}
   508  
   509  			return &model, nil
   510  		}
   511  
   512  		var typeDef openfgav1.TypeDefinition
   513  		if err := proto.Unmarshal(marshalledTypeDef, &typeDef); err != nil {
   514  			return nil, err
   515  		}
   516  
   517  		typeDefs = append(typeDefs, &typeDef)
   518  	}
   519  
   520  	if err := rows.Err(); err != nil {
   521  		return nil, HandleSQLError(err)
   522  	}
   523  
   524  	if len(typeDefs) == 0 {
   525  		return nil, storage.ErrNotFound
   526  	}
   527  
   528  	return &openfgav1.AuthorizationModel{
   529  		SchemaVersion:   schemaVersion,
   530  		Id:              modelID,
   531  		TypeDefinitions: typeDefs,
   532  	}, nil
   533  }
   534  
   535  // FindLatestAuthorizationModel reads the latest authorization model corresponding to the store.
   536  func FindLatestAuthorizationModel(
   537  	ctx context.Context,
   538  	dbInfo *DBInfo,
   539  	store string,
   540  ) (*openfgav1.AuthorizationModel, error) {
   541  	rows, err := dbInfo.stbl.
   542  		Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
   543  		From("authorization_model").
   544  		Where(sq.Eq{"store": store}).
   545  		OrderBy("authorization_model_id desc").
   546  		Limit(1).
   547  		QueryContext(ctx)
   548  	if err != nil {
   549  		return nil, HandleSQLError(err)
   550  	}
   551  	defer rows.Close()
   552  	return constructAuthorizationModelFromSQLRows(rows)
   553  }
   554  
   555  // ReadAuthorizationModel reads the model corresponding to store and model ID.
   556  func ReadAuthorizationModel(
   557  	ctx context.Context,
   558  	dbInfo *DBInfo,
   559  	store, modelID string,
   560  ) (*openfgav1.AuthorizationModel, error) {
   561  	rows, err := dbInfo.stbl.
   562  		Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
   563  		From("authorization_model").
   564  		Where(sq.Eq{
   565  			"store":                  store,
   566  			"authorization_model_id": modelID,
   567  		}).
   568  		QueryContext(ctx)
   569  	if err != nil {
   570  		return nil, HandleSQLError(err)
   571  	}
   572  	defer rows.Close()
   573  	return constructAuthorizationModelFromSQLRows(rows)
   574  }
   575  
   576  // IsReady returns true if the connection to the datastore is successful
   577  // and the datastore has the latest migration applied.
   578  func IsReady(ctx context.Context, db *sql.DB) (storage.ReadinessStatus, error) {
   579  	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
   580  	defer cancel()
   581  
   582  	if err := db.PingContext(ctx); err != nil {
   583  		return storage.ReadinessStatus{}, err
   584  	}
   585  
   586  	revision, err := goose.GetDBVersion(db)
   587  	if err != nil {
   588  		return storage.ReadinessStatus{}, err
   589  	}
   590  
   591  	if revision < build.MinimumSupportedDatastoreSchemaRevision {
   592  		return storage.ReadinessStatus{
   593  			Message: fmt.Sprintf("datastore requires migrations: at revision '%d', but requires '%d'. Run 'openfga migrate'.", revision, build.MinimumSupportedDatastoreSchemaRevision),
   594  			IsReady: false,
   595  		}, nil
   596  	}
   597  
   598  	return storage.ReadinessStatus{
   599  		IsReady: true,
   600  	}, nil
   601  }