github.com/cs3org/reva/v2@v2.27.7/pkg/ocm/invite/repository/sql/sql.go (about)

     1  // Copyright 2018-2023 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package sql
    20  
    21  import (
    22  	"context"
    23  	"database/sql"
    24  	"fmt"
    25  	"time"
    26  
    27  	gatewayv1beta1 "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
    28  	userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
    29  	invitepb "github.com/cs3org/go-cs3apis/cs3/ocm/invite/v1beta1"
    30  	types "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
    31  	conversions "github.com/cs3org/reva/v2/pkg/cbox/utils"
    32  	"github.com/cs3org/reva/v2/pkg/errtypes"
    33  	"github.com/cs3org/reva/v2/pkg/ocm/invite"
    34  	"github.com/cs3org/reva/v2/pkg/rgrpc/todo/pool"
    35  	"github.com/cs3org/reva/v2/pkg/utils/cfg"
    36  	"github.com/go-sql-driver/mysql"
    37  
    38  	"github.com/cs3org/reva/v2/pkg/ocm/invite/repository/registry"
    39  	"github.com/cs3org/reva/v2/pkg/sharedconf"
    40  	"github.com/pkg/errors"
    41  )
    42  
// This module implements the invite.Repository interface as a MySQL driver.
//
// The OCM invitation tokens are saved in the table:
//     ocm_tokens(*token*, initiator, expiration, description)
//
// The OCM remote users are saved in the table:
//     ocm_remote_users(*initiator*, *opaque_user_id*, *idp*, email, display_name)
    50  
    51  func init() {
    52  	registry.Register("sql", New)
    53  }
    54  
    55  type mgr struct {
    56  	c      *config
    57  	db     *sql.DB
    58  	client gatewayv1beta1.GatewayAPIClient
    59  }
    60  
    61  type config struct {
    62  	DBUsername string `mapstructure:"db_username"`
    63  	DBPassword string `mapstructure:"db_password"`
    64  	DBAddress  string `mapstructure:"db_address"`
    65  	DBName     string `mapstructure:"db_name"`
    66  	GatewaySvc string `mapstructure:"gatewaysvc"`
    67  }
    68  
    69  func (c *config) ApplyDefaults() {
    70  	c.GatewaySvc = sharedconf.GetGatewaySVC(c.GatewaySvc)
    71  }
    72  
    73  // New creates a sql repository for ocm tokens and users.
    74  func New(m map[string]interface{}) (invite.Repository, error) {
    75  	var c config
    76  	if err := cfg.Decode(m, &c); err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", c.DBUsername, c.DBPassword, c.DBAddress, c.DBName))
    81  	if err != nil {
    82  		return nil, errors.Wrap(err, "sql: error opening connection to mysql database")
    83  	}
    84  
    85  	gw, err := pool.GetGatewayServiceClient(c.GatewaySvc)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	mgr := mgr{
    91  		c:      &c,
    92  		db:     db,
    93  		client: gw,
    94  	}
    95  	return &mgr, nil
    96  }
    97  
    98  // AddToken stores the token in the repository.
    99  func (m *mgr) AddToken(ctx context.Context, token *invitepb.InviteToken) error {
   100  	query := "INSERT INTO ocm_tokens SET token=?,initiator=?,expiration=?,description=?"
   101  	_, err := m.db.ExecContext(ctx, query, token.Token, conversions.FormatUserID(token.UserId), timestampToTime(token.Expiration), token.Description)
   102  	return err
   103  }
   104  
   105  func timestampToTime(t *types.Timestamp) time.Time {
   106  	return time.Unix(int64(t.Seconds), int64(t.Nanos))
   107  }
   108  
   109  type dbToken struct {
   110  	Token       string
   111  	Initiator   string
   112  	Expiration  time.Time
   113  	Description string
   114  }
   115  
   116  // GetToken gets the token from the repository.
   117  func (m *mgr) GetToken(ctx context.Context, token string) (*invitepb.InviteToken, error) {
   118  	query := "SELECT token, initiator, expiration, description FROM ocm_tokens where token=?"
   119  
   120  	var tkn dbToken
   121  	if err := m.db.QueryRowContext(ctx, query, token).Scan(&tkn.Token, &tkn.Initiator, &tkn.Expiration, &tkn.Description); err != nil {
   122  		if errors.Is(err, sql.ErrNoRows) {
   123  			return nil, invite.ErrTokenNotFound
   124  		}
   125  		return nil, err
   126  	}
   127  	return m.convertToInviteToken(ctx, tkn)
   128  }
   129  
   130  func (m *mgr) convertToInviteToken(ctx context.Context, tkn dbToken) (*invitepb.InviteToken, error) {
   131  	user, err := conversions.ExtractUserID(ctx, m.client, tkn.Initiator)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	return &invitepb.InviteToken{
   136  		Token:  tkn.Token,
   137  		UserId: user,
   138  		Expiration: &types.Timestamp{
   139  			Seconds: uint64(tkn.Expiration.Unix()),
   140  		},
   141  		Description: tkn.Description,
   142  	}, nil
   143  }
   144  
   145  func (m *mgr) ListTokens(ctx context.Context, initiator *userpb.UserId) ([]*invitepb.InviteToken, error) {
   146  	query := "SELECT token, initiator, expiration, description FROM ocm_tokens WHERE initiator=? AND expiration > NOW()"
   147  
   148  	tokens := []*invitepb.InviteToken{}
   149  	rows, err := m.db.QueryContext(ctx, query, conversions.FormatUserID(initiator))
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	var tkn dbToken
   155  	for rows.Next() {
   156  		if err := rows.Scan(&tkn.Token, &tkn.Initiator, &tkn.Expiration, &tkn.Description); err != nil {
   157  			continue
   158  		}
   159  		token, err := m.convertToInviteToken(ctx, tkn)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		tokens = append(tokens, token)
   164  	}
   165  
   166  	return tokens, nil
   167  }
   168  
   169  // AddRemoteUser stores the remote user.
   170  func (m *mgr) AddRemoteUser(ctx context.Context, initiator *userpb.UserId, remoteUser *userpb.User) error {
   171  	query := "INSERT INTO ocm_remote_users SET initiator=?, opaque_user_id=?, idp=?, email=?, display_name=?"
   172  	if _, err := m.db.ExecContext(ctx, query, conversions.FormatUserID(initiator), conversions.FormatUserID(remoteUser.Id), remoteUser.Id.Idp, remoteUser.Mail, remoteUser.DisplayName); err != nil {
   173  		// check if the user already exist in the db
   174  		// https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html#error_er_dup_entry
   175  		var e *mysql.MySQLError
   176  		if errors.As(err, &e) && e.Number == 1062 {
   177  			return invite.ErrUserAlreadyAccepted
   178  		}
   179  		return err
   180  	}
   181  	return nil
   182  }
   183  
   184  type dbOCMUser struct {
   185  	OpaqueUserID string
   186  	Idp          string
   187  	Email        string
   188  	DisplayName  string
   189  }
   190  
   191  // GetRemoteUser retrieves details about a remote user who has accepted an invite to share.
   192  func (m *mgr) GetRemoteUser(ctx context.Context, initiator *userpb.UserId, remoteUserID *userpb.UserId) (*userpb.User, error) {
   193  	query := "SELECT opaque_user_id, idp, email, display_name FROM ocm_remote_users WHERE initiator=? AND opaque_user_id=? AND idp=?"
   194  
   195  	var user dbOCMUser
   196  	if err := m.db.QueryRowContext(ctx, query, conversions.FormatUserID(initiator), conversions.FormatUserID(remoteUserID), remoteUserID.Idp).
   197  		Scan(&user.OpaqueUserID, &user.Idp, &user.Email, &user.DisplayName); err != nil {
   198  		if errors.Is(err, sql.ErrNoRows) {
   199  			return nil, errtypes.NotFound(remoteUserID.OpaqueId)
   200  		}
   201  		return nil, err
   202  	}
   203  	return user.toCS3User(), nil
   204  }
   205  
   206  func (u *dbOCMUser) toCS3User() *userpb.User {
   207  	return &userpb.User{
   208  		Id: &userpb.UserId{
   209  			Idp:      u.Idp,
   210  			OpaqueId: u.OpaqueUserID,
   211  			Type:     userpb.UserType_USER_TYPE_FEDERATED,
   212  		},
   213  		Mail:        u.Email,
   214  		DisplayName: u.DisplayName,
   215  	}
   216  }
   217  
   218  // FindRemoteUsers finds remote users who have accepted invites based on their attributes.
   219  func (m *mgr) FindRemoteUsers(ctx context.Context, initiator *userpb.UserId, attr string) ([]*userpb.User, error) {
   220  	// TODO: (gdelmont) this query can get really slow in case the number of rows is too high.
   221  	// For the time being this is not expected, but if in future this happens, consider to add
   222  	// a fulltext index.
   223  	query := "SELECT opaque_user_id, idp, email, display_name FROM ocm_remote_users WHERE initiator=? AND (opaque_user_id LIKE ? OR idp LIKE ? OR email LIKE ? OR display_name LIKE ?)"
   224  	s := "%" + attr + "%"
   225  	params := []any{conversions.FormatUserID(initiator), s, s, s, s}
   226  
   227  	rows, err := m.db.QueryContext(ctx, query, params...)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	var u dbOCMUser
   233  	var users []*userpb.User
   234  	for rows.Next() {
   235  		if err := rows.Scan(&u.OpaqueUserID, &u.Idp, &u.Email, &u.DisplayName); err != nil {
   236  			continue
   237  		}
   238  		users = append(users, u.toCS3User())
   239  	}
   240  	if err := rows.Err(); err != nil {
   241  		return nil, err
   242  	}
   243  
   244  	return users, nil
   245  }
   246  
   247  func (m *mgr) DeleteRemoteUser(ctx context.Context, initiator *userpb.UserId, remoteUser *userpb.UserId) error {
   248  	query := "DELETE FROM ocm_remote_users WHERE initiator=? AND opaque_user_id=? AND idp=?"
   249  	_, err := m.db.ExecContext(ctx, query, conversions.FormatUserID(initiator), conversions.FormatUserID(remoteUser), remoteUser.Idp)
   250  	return err
   251  }