github.com/cs3org/reva/v2@v2.27.7/pkg/user/manager/owncloudsql/accounts/accounts.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package accounts
    20  
    21  import (
    22  	"context"
    23  	"database/sql"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/cs3org/reva/v2/pkg/appctx"
    28  	"github.com/pkg/errors"
    29  )
    30  
    31  // Accounts represents oc10-style Accounts
    32  type Accounts struct {
    33  	driver                                     string
    34  	db                                         *sql.DB
    35  	joinUsername, joinUUID, enableMedialSearch bool
    36  	selectSQL                                  string
    37  }
    38  
    39  // NewMysql returns a new Cache instance connecting to a MySQL database
    40  func NewMysql(dsn string, joinUsername, joinUUID, enableMedialSearch bool) (*Accounts, error) {
    41  	sqldb, err := sql.Open("mysql", dsn)
    42  	if err != nil {
    43  		return nil, errors.Wrap(err, "error connecting to the database")
    44  	}
    45  
    46  	// FIXME make configurable
    47  	sqldb.SetConnMaxLifetime(time.Minute * 3)
    48  	sqldb.SetConnMaxIdleTime(time.Second * 30)
    49  	sqldb.SetMaxOpenConns(100)
    50  	sqldb.SetMaxIdleConns(10)
    51  
    52  	err = sqldb.Ping()
    53  	if err != nil {
    54  		return nil, errors.Wrap(err, "error connecting to the database")
    55  	}
    56  
    57  	return New("mysql", sqldb, joinUsername, joinUUID, enableMedialSearch)
    58  }
    59  
    60  // New returns a new Cache instance connecting to the given sql.DB
    61  func New(driver string, sqldb *sql.DB, joinUsername, joinUUID, enableMedialSearch bool) (*Accounts, error) {
    62  
    63  	sel := "SELECT id, email, user_id, display_name, quota, last_login, backend, home, state"
    64  	from := `
    65  		FROM oc_accounts a
    66  		`
    67  	if joinUsername {
    68  		sel += ", p.configvalue AS username"
    69  		from += `LEFT JOIN oc_preferences p
    70  						ON a.user_id=p.userid
    71  						AND p.appid='core'
    72  						AND p.configkey='username'`
    73  	} else {
    74  		// fallback to user_id as username
    75  		sel += ", user_id AS username"
    76  	}
    77  	if joinUUID {
    78  		sel += ", p2.configvalue AS ownclouduuid"
    79  		from += `LEFT JOIN oc_preferences p2
    80  						ON a.user_id=p2.userid
    81  						AND p2.appid='core'
    82  						AND p2.configkey='ownclouduuid'`
    83  	} else {
    84  		// fallback to user_id as ownclouduuid
    85  		sel += ", user_id AS ownclouduuid"
    86  	}
    87  
    88  	return &Accounts{
    89  		driver:             driver,
    90  		db:                 sqldb,
    91  		joinUsername:       joinUsername,
    92  		joinUUID:           joinUUID,
    93  		enableMedialSearch: enableMedialSearch,
    94  		selectSQL:          sel + from,
    95  	}, nil
    96  }
    97  
// Account stores information about accounts: one row of oc_accounts as
// selected by Accounts, plus the username and ownclouduuid columns that are
// either joined from oc_preferences or fall back to user_id.
// NOTE(review): keep the field order stable; it mirrors the SELECT column order
// and may be relied on by positional composite literals elsewhere.
type Account struct {
	ID           uint64         // oc_accounts.id
	Email        sql.NullString // oc_accounts.email; may be NULL
	UserID       string         // oc_accounts.user_id
	DisplayName  sql.NullString // oc_accounts.display_name; may be NULL
	Quota        sql.NullString // oc_accounts.quota; may be NULL
	LastLogin    int            // oc_accounts.last_login
	Backend      string         // oc_accounts.backend
	Home         string         // oc_accounts.home
	State        int8           // oc_accounts.state
	Username     sql.NullString // optional comes from the oc_preferences (NULL if the LEFT JOIN finds no row); else user_id
	OwnCloudUUID sql.NullString // optional comes from the oc_preferences (NULL if the LEFT JOIN finds no row); else user_id
}
   112  
   113  func (as *Accounts) rowToAccount(ctx context.Context, row Scannable) (*Account, error) {
   114  	a := Account{}
   115  	if err := row.Scan(&a.ID, &a.Email, &a.UserID, &a.DisplayName, &a.Quota, &a.LastLogin, &a.Backend, &a.Home, &a.State, &a.Username, &a.OwnCloudUUID); err != nil {
   116  		appctx.GetLogger(ctx).Error().Err(err).Msg("could not scan row, skipping")
   117  		return nil, err
   118  	}
   119  
   120  	return &a, nil
   121  }
   122  
   123  // Scannable describes the interface providing a Scan method
   124  type Scannable interface {
   125  	Scan(...interface{}) error
   126  }
   127  
   128  // GetAccountByClaim fetches an account by mail, username or userid
   129  func (as *Accounts) GetAccountByClaim(ctx context.Context, claim, value string) (*Account, error) {
   130  	// TODO align supported claims with rest driver and the others, maybe refactor into common mapping
   131  	var row *sql.Row
   132  	var where string
   133  	switch claim {
   134  	case "mail":
   135  		where = "WHERE a.email=?"
   136  	// case "uid":
   137  	//	claim = m.c.Schema.UIDNumber
   138  	// case "gid":
   139  	//	claim = m.c.Schema.GIDNumber
   140  	case "username":
   141  		if as.joinUsername {
   142  			where = "WHERE p.configvalue=?"
   143  		} else {
   144  			// use user_id as username
   145  			where = "WHERE a.user_id=?"
   146  		}
   147  	case "userid":
   148  		if as.joinUUID {
   149  			where = "WHERE p2.configvalue=?"
   150  		} else {
   151  			// use user_id as uuid
   152  			where = "WHERE a.user_id=?"
   153  		}
   154  	default:
   155  		return nil, errors.New("owncloudsql: invalid field " + claim)
   156  	}
   157  
   158  	row = as.db.QueryRowContext(ctx, as.selectSQL+where, value)
   159  
   160  	return as.rowToAccount(ctx, row)
   161  }
   162  
   163  func sanitizeWildcards(q string) string {
   164  	return strings.ReplaceAll(strings.ReplaceAll(q, "%", `\%`), "_", `\_`)
   165  }
   166  
   167  // FindAccounts searches userid, displayname and email using the given query. The Wildcard caracters % and _ are escaped.
   168  func (as *Accounts) FindAccounts(ctx context.Context, query string) ([]Account, error) {
   169  	if as.enableMedialSearch {
   170  		query = "%" + sanitizeWildcards(query) + "%"
   171  	}
   172  	// TODO join oc_account_terms
   173  	where := "WHERE a.user_id LIKE ? OR a.display_name LIKE ? OR a.email LIKE ?"
   174  	args := []interface{}{query, query, query}
   175  
   176  	if as.joinUsername {
   177  		where += " OR p.configvalue LIKE ?"
   178  		args = append(args, query)
   179  	}
   180  	if as.joinUUID {
   181  		where += " OR p2.configvalue LIKE ?"
   182  		args = append(args, query)
   183  	}
   184  
   185  	rows, err := as.db.QueryContext(ctx, as.selectSQL+where, args...)
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	defer rows.Close()
   190  
   191  	accounts := []Account{}
   192  	for rows.Next() {
   193  		a := Account{}
   194  		if err := rows.Scan(&a.ID, &a.Email, &a.UserID, &a.DisplayName, &a.Quota, &a.LastLogin, &a.Backend, &a.Home, &a.State, &a.Username, &a.OwnCloudUUID); err != nil {
   195  			appctx.GetLogger(ctx).Error().Err(err).Msg("could not scan row, skipping")
   196  			continue
   197  		}
   198  		accounts = append(accounts, a)
   199  	}
   200  	if err = rows.Err(); err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	return accounts, nil
   205  }
   206  
   207  // GetAccountGroups lasts the groups for an account
   208  func (as *Accounts) GetAccountGroups(ctx context.Context, uid string) ([]string, error) {
   209  	rows, err := as.db.QueryContext(ctx, "SELECT gid FROM oc_group_user WHERE uid=?", uid)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  	defer rows.Close()
   214  
   215  	groups := []string{}
   216  	for rows.Next() {
   217  		var group string
   218  		if err := rows.Scan(&group); err != nil {
   219  			appctx.GetLogger(ctx).Error().Err(err).Msg("could not scan row, skipping")
   220  			continue
   221  		}
   222  		groups = append(groups, group)
   223  	}
   224  	if err = rows.Err(); err != nil {
   225  		return nil, err
   226  	}
   227  	return groups, nil
   228  }