github.com/openfga/openfga@v1.5.4-rc1/pkg/storage/sqlcommon/sqlcommon.go

package sqlcommon

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	sq "github.com/Masterminds/squirrel"
	"github.com/go-sql-driver/mysql"
	"github.com/oklog/ulid/v2"
	openfgav1 "github.com/openfga/api/proto/openfga/v1"
	"github.com/pressly/goose/v3"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/openfga/openfga/internal/build"
	"github.com/openfga/openfga/pkg/logger"
	"github.com/openfga/openfga/pkg/storage"
	tupleUtils "github.com/openfga/openfga/pkg/tuple"
)

// Config defines the configuration parameters
// for setting up and managing a SQL connection.
type Config struct {
	Username               string
	Password               string
	Logger                 logger.Logger
	MaxTuplesPerWriteField int
	MaxTypesPerModelField  int

	MaxOpenConns    int
	MaxIdleConns    int
	ConnMaxIdleTime time.Duration
	ConnMaxLifetime time.Duration

	ExportMetrics bool
}

// DatastoreOption defines a function type
// used for configuring a Config object.
type DatastoreOption func(*Config)

// WithUsername returns a DatastoreOption that sets the username in the Config.
func WithUsername(username string) DatastoreOption {
	return func(config *Config) {
		config.Username = username
	}
}

// WithPassword returns a DatastoreOption that sets the password in the Config.
func WithPassword(password string) DatastoreOption {
	return func(config *Config) {
		config.Password = password
	}
}

// WithLogger returns a DatastoreOption that sets the Logger in the Config.
func WithLogger(l logger.Logger) DatastoreOption {
	return func(cfg *Config) {
		cfg.Logger = l
	}
}

// WithMaxTuplesPerWrite returns a DatastoreOption that sets
// the maximum number of tuples per write in the Config.
func WithMaxTuplesPerWrite(maxTuples int) DatastoreOption {
	return func(cfg *Config) {
		cfg.MaxTuplesPerWriteField = maxTuples
	}
}

// WithMaxTypesPerAuthorizationModel returns a DatastoreOption that sets
// the maximum number of types per authorization model in the Config.
func WithMaxTypesPerAuthorizationModel(maxTypes int) DatastoreOption {
	return func(cfg *Config) {
		cfg.MaxTypesPerModelField = maxTypes
	}
}

// WithMaxOpenConns returns a DatastoreOption that sets the
// maximum number of open connections in the Config.
func WithMaxOpenConns(c int) DatastoreOption {
	return func(cfg *Config) {
		cfg.MaxOpenConns = c
	}
}

// WithMaxIdleConns returns a DatastoreOption that sets the
// maximum number of idle connections in the Config.
func WithMaxIdleConns(c int) DatastoreOption {
	return func(cfg *Config) {
		cfg.MaxIdleConns = c
	}
}

// WithConnMaxIdleTime returns a DatastoreOption that sets
// the maximum idle time for a connection in the Config.
func WithConnMaxIdleTime(d time.Duration) DatastoreOption {
	return func(cfg *Config) {
		cfg.ConnMaxIdleTime = d
	}
}

// WithConnMaxLifetime returns a DatastoreOption that sets
// the maximum lifetime for a connection in the Config.
func WithConnMaxLifetime(d time.Duration) DatastoreOption {
	return func(cfg *Config) {
		cfg.ConnMaxLifetime = d
	}
}

// WithMetrics returns a DatastoreOption that
// enables the export of metrics in the Config.
func WithMetrics() DatastoreOption {
	return func(cfg *Config) {
		cfg.ExportMetrics = true
	}
}

// NewConfig creates a new Config instance with default values
// and applies any provided DatastoreOption modifications.
func NewConfig(opts ...DatastoreOption) *Config {
	cfg := &Config{}

	for _, opt := range opts {
		opt(cfg)
	}

	if cfg.Logger == nil {
		cfg.Logger = logger.NewNoopLogger()
	}

	if cfg.MaxTuplesPerWriteField == 0 {
		cfg.MaxTuplesPerWriteField = storage.DefaultMaxTuplesPerWrite
	}

	if cfg.MaxTypesPerModelField == 0 {
		cfg.MaxTypesPerModelField = storage.DefaultMaxTypesPerAuthorizationModel
	}

	return cfg
}
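// Illustrative usage sketch (not part of the upstream file): wiring the
// functional options above into NewConfig. The concrete values and the
// function name are arbitrary examples, not recommended settings.
func newConfigSketch(log logger.Logger) *Config {
	return NewConfig(
		WithLogger(log),
		WithMaxTuplesPerWrite(100),
		WithMaxOpenConns(30),
		WithMaxIdleConns(10),
		WithConnMaxIdleTime(30*time.Second),
		WithConnMaxLifetime(5*time.Minute),
		WithMetrics(),
	)
}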
// ContToken represents a continuation token structure used in pagination.
type ContToken struct {
	Ulid       string `json:"ulid"`
	ObjectType string `json:"ObjectType"`
}

// NewContToken creates a new instance of ContToken
// with the provided ULID and object type.
func NewContToken(ulid, objectType string) *ContToken {
	return &ContToken{
		Ulid:       ulid,
		ObjectType: objectType,
	}
}

// UnmarshallContToken takes a string representation of a continuation
// token and attempts to unmarshal it into a ContToken struct.
func UnmarshallContToken(from string) (*ContToken, error) {
	var token ContToken
	if err := json.Unmarshal([]byte(from), &token); err != nil {
		return nil, storage.ErrInvalidContinuationToken
	}
	return &token, nil
}

// SQLTupleIterator is a struct that implements the storage.TupleIterator
// interface for iterating over tuples fetched from a SQL database.
type SQLTupleIterator struct {
	rows     *sql.Rows
	resultCh chan *storage.TupleRecord
	errCh    chan error
}

// Ensures that SQLTupleIterator implements the TupleIterator interface.
var _ storage.TupleIterator = (*SQLTupleIterator)(nil)

// NewSQLTupleIterator returns a SQL tuple iterator.
func NewSQLTupleIterator(rows *sql.Rows) *SQLTupleIterator {
	return &SQLTupleIterator{
		rows:     rows,
		resultCh: make(chan *storage.TupleRecord, 1),
		errCh:    make(chan error, 1),
	}
}

func (t *SQLTupleIterator) next() (*storage.TupleRecord, error) {
	if !t.rows.Next() {
		if err := t.rows.Err(); err != nil {
			return nil, err
		}
		return nil, storage.ErrIteratorDone
	}

	var conditionName sql.NullString
	var conditionContext []byte
	var record storage.TupleRecord
	err := t.rows.Scan(
		&record.Store,
		&record.ObjectType,
		&record.ObjectID,
		&record.Relation,
		&record.User,
		&conditionName,
		&conditionContext,
		&record.Ulid,
		&record.InsertedAt,
	)
	if err != nil {
		return nil, err
	}

	record.ConditionName = conditionName.String

	if conditionContext != nil {
		var conditionContextStruct structpb.Struct
		if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
			return nil, err
		}
		record.ConditionContext = &conditionContextStruct
	}

	return &record, nil
}

// ToArray converts the SQLTupleIterator to a []*openfgav1.Tuple and a possibly empty continuation token.
// If the continuation token exists, it is the ULID of the last element of the returned array.
func (t *SQLTupleIterator) ToArray(
	opts storage.PaginationOptions,
) ([]*openfgav1.Tuple, []byte, error) {
	var res []*openfgav1.Tuple
	for i := 0; i < opts.PageSize; i++ {
		tupleRecord, err := t.next()
		if err != nil {
			if err == storage.ErrIteratorDone {
				return res, nil, nil
			}
			return nil, nil, err
		}
		res = append(res, tupleRecord.AsTuple())
	}

	// Check if we are at the end of the iterator.
	// If we are, then we do not need to return a continuation token.
	// This is why we have LIMIT+1 in the query.
	tupleRecord, err := t.next()
	if err != nil {
		if errors.Is(err, storage.ErrIteratorDone) {
			return res, nil, nil
		}
		return nil, nil, err
	}

	contToken, err := json.Marshal(NewContToken(tupleRecord.Ulid, ""))
	if err != nil {
		return nil, nil, err
	}

	return res, contToken, nil
}

// Next will return the next available item.
func (t *SQLTupleIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
	if ctx.Err() != nil {
		return nil, ctx.Err()
	}

	record, err := t.next()
	if err != nil {
		return nil, err
	}

	return record.AsTuple(), nil
}

// Stop terminates iteration.
func (t *SQLTupleIterator) Stop() {
	t.rows.Close()
}
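// Illustrative sketch (not part of the upstream file): materializing one page
// of results from a SQLTupleIterator. The helper name is hypothetical, and it
// assumes the caller has already executed a query sized LIMIT pageSize+1 so
// that ToArray can detect whether another page exists.
func readTuplePageSketch(rows *sql.Rows, pageSize int) ([]*openfgav1.Tuple, []byte, error) {
	iter := NewSQLTupleIterator(rows)
	defer iter.Stop()

	// ToArray reads up to pageSize tuples and peeks at one extra row to decide
	// whether a continuation token should be returned.
	return iter.ToArray(storage.PaginationOptions{PageSize: pageSize})
}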
// HandleSQLError processes an SQL error and converts it into a more
// specific error type based on the nature of the SQL error.
func HandleSQLError(err error, args ...interface{}) error {
	if errors.Is(err, sql.ErrNoRows) {
		return storage.ErrNotFound
	} else if errors.Is(err, storage.ErrIteratorDone) {
		return err
	} else if strings.Contains(err.Error(), "duplicate key value") { // Postgres.
		if len(args) > 0 {
			if tk, ok := args[0].(*openfgav1.TupleKey); ok {
				return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
			}
		}
		return storage.ErrCollision
	} else if me, ok := err.(*mysql.MySQLError); ok && me.Number == 1062 {
		if len(args) > 0 {
			if tk, ok := args[0].(*openfgav1.TupleKey); ok {
				return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
			}
		}
		return storage.ErrCollision
	}

	return fmt.Errorf("sql error: %w", err)
}

// DBInfo encapsulates DB information for use in common methods.
type DBInfo struct {
	db      *sql.DB
	stbl    sq.StatementBuilderType
	sqlTime interface{}
}

// NewDBInfo constructs a [DBInfo] object.
func NewDBInfo(db *sql.DB, stbl sq.StatementBuilderType, sqlTime interface{}) *DBInfo {
	return &DBInfo{
		db:      db,
		stbl:    stbl,
		sqlTime: sqlTime,
	}
}
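// Illustrative sketch (not part of the upstream file): building a DBInfo for a
// Postgres-backed store. The dollar-sign placeholder format and the "NOW()"
// time expression are assumptions made for this example; each driver package
// in this repository supplies its own statement builder and SQL time value.
func newPostgresDBInfoSketch(db *sql.DB) *DBInfo {
	stbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).RunWith(db)
	return NewDBInfo(db, stbl, sq.Expr("NOW()"))
}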
// Write provides the common method for writing to the database across SQL storage implementations.
func Write(
	ctx context.Context,
	dbInfo *DBInfo,
	store string,
	deletes storage.Deletes,
	writes storage.Writes,
	now time.Time,
) error {
	txn, err := dbInfo.db.BeginTx(ctx, nil)
	if err != nil {
		return HandleSQLError(err)
	}
	defer func() {
		_ = txn.Rollback()
	}()

	changelogBuilder := dbInfo.stbl.
		Insert("changelog").
		Columns(
			"store", "object_type", "object_id", "relation", "_user",
			"condition_name", "condition_context", "operation", "ulid", "inserted_at",
		)

	deleteBuilder := dbInfo.stbl.Delete("tuple")

	for _, tk := range deletes {
		id := ulid.MustNew(ulid.Timestamp(now), ulid.DefaultEntropy()).String()
		objectType, objectID := tupleUtils.SplitObject(tk.GetObject())

		res, err := deleteBuilder.
			Where(sq.Eq{
				"store":       store,
				"object_type": objectType,
				"object_id":   objectID,
				"relation":    tk.GetRelation(),
				"_user":       tk.GetUser(),
				"user_type":   tupleUtils.GetUserTypeFromUser(tk.GetUser()),
			}).
			RunWith(txn). // Part of a txn.
			ExecContext(ctx)
		if err != nil {
			return HandleSQLError(err, tk)
		}

		rowsAffected, err := res.RowsAffected()
		if err != nil {
			return HandleSQLError(err)
		}

		if rowsAffected != 1 {
			return storage.InvalidWriteInputError(
				tk,
				openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
			)
		}

		changelogBuilder = changelogBuilder.Values(
			store, objectType, objectID,
			tk.GetRelation(), tk.GetUser(),
			"", nil, // Redact condition info for deletes since we only need the base triplet (object, relation, user).
			openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
			id, dbInfo.sqlTime,
		)
	}

	insertBuilder := dbInfo.stbl.
		Insert("tuple").
		Columns(
			"store", "object_type", "object_id", "relation", "_user", "user_type",
			"condition_name", "condition_context", "ulid", "inserted_at",
		)

	for _, tk := range writes {
		id := ulid.MustNew(ulid.Timestamp(now), ulid.DefaultEntropy()).String()
		objectType, objectID := tupleUtils.SplitObject(tk.GetObject())

		conditionName, conditionContext, err := marshalRelationshipCondition(tk.GetCondition())
		if err != nil {
			return err
		}

		_, err = insertBuilder.
			Values(
				store,
				objectType,
				objectID,
				tk.GetRelation(),
				tk.GetUser(),
				tupleUtils.GetUserTypeFromUser(tk.GetUser()),
				conditionName,
				conditionContext,
				id,
				dbInfo.sqlTime,
			).
			RunWith(txn). // Part of a txn.
			ExecContext(ctx)
		if err != nil {
			return HandleSQLError(err, tk)
		}

		changelogBuilder = changelogBuilder.Values(
			store,
			objectType,
			objectID,
			tk.GetRelation(),
			tk.GetUser(),
			conditionName,
			conditionContext,
			openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
			id,
			dbInfo.sqlTime,
		)
	}

	if len(writes) > 0 || len(deletes) > 0 {
		_, err := changelogBuilder.RunWith(txn).ExecContext(ctx) // Part of a txn.
		if err != nil {
			return HandleSQLError(err)
		}
	}

	if err := txn.Commit(); err != nil {
		return HandleSQLError(err)
	}

	return nil
}

// WriteAuthorizationModel writes an authorization model for the given store.
func WriteAuthorizationModel(
	ctx context.Context,
	dbInfo *DBInfo,
	store string,
	model *openfgav1.AuthorizationModel,
) error {
	schemaVersion := model.GetSchemaVersion()
	typeDefinitions := model.GetTypeDefinitions()

	if len(typeDefinitions) < 1 {
		return nil
	}

	pbdata, err := proto.Marshal(model)
	if err != nil {
		return err
	}

	_, err = dbInfo.stbl.
		Insert("authorization_model").
		Columns("store", "authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
		Values(store, model.GetId(), schemaVersion, "", nil, pbdata).
		ExecContext(ctx)
	if err != nil {
		return HandleSQLError(err)
	}

	return nil
}
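// Illustrative sketch (not part of the upstream file): one transactional Write
// call that removes a "viewer" tuple and adds an "editor" tuple. The store ID
// and tuple contents are fabricated, and the composite literals assume that
// storage.Deletes and storage.Writes are slices of tuple-key messages, as the
// loops inside Write above suggest.
func writeSketch(ctx context.Context, dbInfo *DBInfo, store string) error {
	deletes := storage.Deletes{
		{Object: "document:budget", Relation: "viewer", User: "user:anne"},
	}
	writes := storage.Writes{
		{Object: "document:budget", Relation: "editor", User: "user:anne"},
	}

	// Deletes and writes are applied in a single transaction; a delete that
	// affects no rows causes the whole call to fail.
	return Write(ctx, dbInfo, store, deletes, writes, time.Now())
}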
func constructAuthorizationModelFromSQLRows(rows *sql.Rows) (*openfgav1.AuthorizationModel, error) {
	var modelID string
	var schemaVersion string
	var typeDefs []*openfgav1.TypeDefinition
	for rows.Next() {
		var typeName string
		var marshalledTypeDef []byte
		var marshalledModel []byte
		err := rows.Scan(&modelID, &schemaVersion, &typeName, &marshalledTypeDef, &marshalledModel)
		if err != nil {
			return nil, HandleSQLError(err)
		}

		if len(marshalledModel) > 0 {
			// Prefer building an authorization model from the first row that has it available.
			var model openfgav1.AuthorizationModel
			if err := proto.Unmarshal(marshalledModel, &model); err != nil {
				return nil, err
			}

			return &model, nil
		}

		var typeDef openfgav1.TypeDefinition
		if err := proto.Unmarshal(marshalledTypeDef, &typeDef); err != nil {
			return nil, err
		}

		typeDefs = append(typeDefs, &typeDef)
	}

	if err := rows.Err(); err != nil {
		return nil, HandleSQLError(err)
	}

	if len(typeDefs) == 0 {
		return nil, storage.ErrNotFound
	}

	return &openfgav1.AuthorizationModel{
		SchemaVersion:   schemaVersion,
		Id:              modelID,
		TypeDefinitions: typeDefs,
	}, nil
}

// FindLatestAuthorizationModel reads the latest authorization model corresponding to the store.
func FindLatestAuthorizationModel(
	ctx context.Context,
	dbInfo *DBInfo,
	store string,
) (*openfgav1.AuthorizationModel, error) {
	rows, err := dbInfo.stbl.
		Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
		From("authorization_model").
		Where(sq.Eq{"store": store}).
		OrderBy("authorization_model_id desc").
		Limit(1).
		QueryContext(ctx)
	if err != nil {
		return nil, HandleSQLError(err)
	}
	defer rows.Close()
	return constructAuthorizationModelFromSQLRows(rows)
}

// ReadAuthorizationModel reads the model corresponding to store and model ID.
func ReadAuthorizationModel(
	ctx context.Context,
	dbInfo *DBInfo,
	store, modelID string,
) (*openfgav1.AuthorizationModel, error) {
	rows, err := dbInfo.stbl.
		Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
		From("authorization_model").
		Where(sq.Eq{
			"store":                  store,
			"authorization_model_id": modelID,
		}).
		QueryContext(ctx)
	if err != nil {
		return nil, HandleSQLError(err)
	}
	defer rows.Close()
	return constructAuthorizationModelFromSQLRows(rows)
}
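// Illustrative sketch (not part of the upstream file): resolving a model,
// falling back to the latest one when the caller did not pin a model ID. The
// function name and control flow are hypothetical.
func resolveModelSketch(ctx context.Context, dbInfo *DBInfo, store, modelID string) (*openfgav1.AuthorizationModel, error) {
	if modelID == "" {
		// No model pinned by the caller: use the most recently written one.
		return FindLatestAuthorizationModel(ctx, dbInfo, store)
	}
	return ReadAuthorizationModel(ctx, dbInfo, store, modelID)
}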
// IsReady reports whether the connection to the datastore succeeds and the
// datastore schema is at or above the minimum supported migration revision.
func IsReady(ctx context.Context, db *sql.DB) (storage.ReadinessStatus, error) {
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	if err := db.PingContext(ctx); err != nil {
		return storage.ReadinessStatus{}, err
	}

	revision, err := goose.GetDBVersion(db)
	if err != nil {
		return storage.ReadinessStatus{}, err
	}

	if revision < build.MinimumSupportedDatastoreSchemaRevision {
		return storage.ReadinessStatus{
			Message: fmt.Sprintf("datastore requires migrations: at revision '%d', but requires '%d'. Run 'openfga migrate'.", revision, build.MinimumSupportedDatastoreSchemaRevision),
			IsReady: false,
		}, nil
	}

	return storage.ReadinessStatus{
		IsReady: true,
	}, nil
}
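// Illustrative sketch (not part of the upstream file): how a caller might gate
// startup on IsReady, surfacing the migration hint when the schema is behind.
// The function name is hypothetical.
func requireReadySketch(ctx context.Context, db *sql.DB) error {
	status, err := IsReady(ctx, db)
	if err != nil {
		return err
	}
	if !status.IsReady {
		return fmt.Errorf("datastore is not ready: %s", status.Message)
	}
	return nil
}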