github.com/cs3org/reva/v2@v2.27.7/pkg/cbox/user/rest/rest.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package rest
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"os"
    25  	"os/signal"
    26  	"strings"
    27  	"syscall"
    28  	"time"
    29  
    30  	userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
    31  	"github.com/cs3org/reva/v2/pkg/appctx"
    32  	utils "github.com/cs3org/reva/v2/pkg/cbox/utils"
    33  	"github.com/cs3org/reva/v2/pkg/user"
    34  	"github.com/cs3org/reva/v2/pkg/user/manager/registry"
    35  	"github.com/gomodule/redigo/redis"
    36  	"github.com/mitchellh/mapstructure"
    37  	"github.com/pkg/errors"
    38  	"github.com/rs/zerolog/log"
    39  )
    40  
    41  func init() {
    42  	registry.Register("rest", New)
    43  }
    44  
    45  type manager struct {
    46  	conf            *config
    47  	redisPool       *redis.Pool
    48  	apiTokenManager *utils.APITokenManager
    49  }
    50  
    51  type config struct {
    52  	// The address at which the redis server is running
    53  	RedisAddress string `mapstructure:"redis_address" docs:"localhost:6379"`
    54  	// The username for connecting to the redis server
    55  	RedisUsername string `mapstructure:"redis_username" docs:""`
    56  	// The password for connecting to the redis server
    57  	RedisPassword string `mapstructure:"redis_password" docs:""`
    58  	// The time in minutes for which the groups to which a user belongs would be cached
    59  	UserGroupsCacheExpiration int `mapstructure:"user_groups_cache_expiration" docs:"5"`
    60  	// The OIDC Provider
    61  	IDProvider string `mapstructure:"id_provider" docs:"http://cernbox.cern.ch"`
    62  	// Base API Endpoint
    63  	APIBaseURL string `mapstructure:"api_base_url" docs:"https://authorization-service-api-dev.web.cern.ch"`
    64  	// Client ID needed to authenticate
    65  	ClientID string `mapstructure:"client_id" docs:"-"`
    66  	// Client Secret
    67  	ClientSecret string `mapstructure:"client_secret" docs:"-"`
    68  
    69  	// Endpoint to generate token to access the API
    70  	OIDCTokenEndpoint string `mapstructure:"oidc_token_endpoint" docs:"https://keycloak-dev.cern.ch/auth/realms/cern/api-access/token"`
    71  	// The target application for which token needs to be generated
    72  	TargetAPI string `mapstructure:"target_api" docs:"authorization-service-api"`
    73  	// The time in seconds between bulk fetch of user accounts
    74  	UserFetchInterval int `mapstructure:"user_fetch_interval" docs:"3600"`
    75  }
    76  
    77  func (c *config) init() {
    78  	if c.UserGroupsCacheExpiration == 0 {
    79  		c.UserGroupsCacheExpiration = 5
    80  	}
    81  	if c.RedisAddress == "" {
    82  		c.RedisAddress = ":6379"
    83  	}
    84  	if c.APIBaseURL == "" {
    85  		c.APIBaseURL = "https://authorization-service-api-dev.web.cern.ch"
    86  	}
    87  	if c.TargetAPI == "" {
    88  		c.TargetAPI = "authorization-service-api"
    89  	}
    90  	if c.OIDCTokenEndpoint == "" {
    91  		c.OIDCTokenEndpoint = "https://keycloak-dev.cern.ch/auth/realms/cern/api-access/token"
    92  	}
    93  	if c.IDProvider == "" {
    94  		c.IDProvider = "http://cernbox.cern.ch"
    95  	}
    96  	if c.UserFetchInterval == 0 {
    97  		c.UserFetchInterval = 3600
    98  	}
    99  }
   100  
   101  func parseConfig(m map[string]interface{}) (*config, error) {
   102  	c := &config{}
   103  	if err := mapstructure.Decode(m, c); err != nil {
   104  		return nil, err
   105  	}
   106  	return c, nil
   107  }
   108  
   109  // New returns a user manager implementation that makes calls to the GRAPPA API.
   110  func New(m map[string]interface{}) (user.Manager, error) {
   111  	mgr := &manager{}
   112  	err := mgr.Configure(m)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  	return mgr, err
   117  }
   118  
   119  func (m *manager) Configure(ml map[string]interface{}) error {
   120  	c, err := parseConfig(ml)
   121  	if err != nil {
   122  		return err
   123  	}
   124  	c.init()
   125  	redisPool := initRedisPool(c.RedisAddress, c.RedisUsername, c.RedisPassword)
   126  	apiTokenManager := utils.InitAPITokenManager(c.TargetAPI, c.OIDCTokenEndpoint, c.ClientID, c.ClientSecret)
   127  	m.conf = c
   128  	m.redisPool = redisPool
   129  	m.apiTokenManager = apiTokenManager
   130  
   131  	// Since we're starting a subroutine which would take some time to execute,
   132  	// we can't wait to see if it works before returning the user.Manager object
   133  	// TODO: return err if the fetch fails
   134  	go m.fetchAllUsers()
   135  	return nil
   136  }
   137  
   138  func (m *manager) fetchAllUsers() {
   139  	_ = m.fetchAllUserAccounts()
   140  	ticker := time.NewTicker(time.Duration(m.conf.UserFetchInterval) * time.Second)
   141  	work := make(chan os.Signal, 1)
   142  	signal.Notify(work, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT)
   143  
   144  	for {
   145  		select {
   146  		case <-work:
   147  			return
   148  		case <-ticker.C:
   149  			_ = m.fetchAllUserAccounts()
   150  		}
   151  	}
   152  }
   153  
   154  func (m *manager) fetchAllUserAccounts() error {
   155  	ctx := context.Background()
   156  	url := fmt.Sprintf("%s/api/v1.0/Identity?field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", m.conf.APIBaseURL)
   157  
   158  	for url != "" {
   159  		result, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
   160  		if err != nil {
   161  			return err
   162  		}
   163  
   164  		responseData, ok := result["data"].([]interface{})
   165  		if !ok {
   166  			return errors.New("rest: error in type assertion")
   167  		}
   168  		for _, usr := range responseData {
   169  			userData, ok := usr.(map[string]interface{})
   170  			if !ok {
   171  				continue
   172  			}
   173  
   174  			_, err = m.parseAndCacheUser(ctx, userData)
   175  			if err != nil {
   176  				continue
   177  			}
   178  		}
   179  
   180  		url = ""
   181  		if pagination, ok := result["pagination"].(map[string]interface{}); ok {
   182  			if links, ok := pagination["links"].(map[string]interface{}); ok {
   183  				if next, ok := links["next"].(string); ok {
   184  					url = fmt.Sprintf("%s%s", m.conf.APIBaseURL, next)
   185  				}
   186  			}
   187  		}
   188  	}
   189  
   190  	return nil
   191  }
   192  
   193  func (m *manager) parseAndCacheUser(ctx context.Context, userData map[string]interface{}) (*userpb.User, error) {
   194  	upn, ok := userData["upn"].(string)
   195  	if !ok {
   196  		return nil, errors.New("rest: missing upn in user data")
   197  	}
   198  	mail, _ := userData["primaryAccountEmail"].(string)
   199  	name, _ := userData["displayName"].(string)
   200  	uidNumber, _ := userData["uid"].(float64)
   201  	gidNumber, _ := userData["gid"].(float64)
   202  	t, _ := userData["type"].(string)
   203  	userType := getUserType(t, upn)
   204  
   205  	userID := &userpb.UserId{
   206  		OpaqueId: upn,
   207  		Idp:      m.conf.IDProvider,
   208  		Type:     userType,
   209  	}
   210  	u := &userpb.User{
   211  		Id:          userID,
   212  		Username:    upn,
   213  		Mail:        mail,
   214  		DisplayName: name,
   215  		UidNumber:   int64(uidNumber),
   216  		GidNumber:   int64(gidNumber),
   217  	}
   218  
   219  	if err := m.cacheUserDetails(u); err != nil {
   220  		log.Error().Err(err).Msg("rest: error caching user details")
   221  	}
   222  	return u, nil
   223  }
   224  
   225  func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId, skipFetchingGroups bool) (*userpb.User, error) {
   226  	u, err := m.fetchCachedUserDetails(uid)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	if !skipFetchingGroups {
   232  		userGroups, err := m.GetUserGroups(ctx, uid)
   233  		if err != nil {
   234  			return nil, err
   235  		}
   236  		u.Groups = userGroups
   237  	}
   238  
   239  	return u, nil
   240  }
   241  
   242  func (m *manager) GetUserByClaim(ctx context.Context, claim, value string, skipFetchingGroups bool) (*userpb.User, error) {
   243  	u, err := m.fetchCachedUserByParam(claim, value)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	if !skipFetchingGroups {
   249  		userGroups, err := m.GetUserGroups(ctx, u.Id)
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  		u.Groups = userGroups
   254  	}
   255  
   256  	return u, nil
   257  }
   258  
   259  func (m *manager) FindUsers(ctx context.Context, query string, skipFetchingGroups bool) ([]*userpb.User, error) {
   260  
   261  	// Look at namespaces filters. If the query starts with:
   262  	// "a" => look into primary/secondary/service accounts
   263  	// "l" => look into lightweight/federated accounts
   264  	// none => look into primary
   265  
   266  	parts := strings.SplitN(query, ":", 2)
   267  
   268  	var namespace string
   269  	if len(parts) == 2 {
   270  		// the query contains a namespace filter
   271  		namespace, query = parts[0], parts[1]
   272  	}
   273  
   274  	users, err := m.findCachedUsers(query)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	userSlice := []*userpb.User{}
   280  
   281  	var accountsFilters []userpb.UserType
   282  	switch namespace {
   283  	case "":
   284  		accountsFilters = []userpb.UserType{userpb.UserType_USER_TYPE_PRIMARY}
   285  	case "a":
   286  		accountsFilters = []userpb.UserType{userpb.UserType_USER_TYPE_PRIMARY, userpb.UserType_USER_TYPE_SECONDARY, userpb.UserType_USER_TYPE_SERVICE}
   287  	case "l":
   288  		accountsFilters = []userpb.UserType{userpb.UserType_USER_TYPE_LIGHTWEIGHT, userpb.UserType_USER_TYPE_FEDERATED}
   289  	}
   290  
   291  	for _, u := range users {
   292  		if isUserAnyType(u, accountsFilters) {
   293  			userSlice = append(userSlice, u)
   294  		}
   295  	}
   296  
   297  	return userSlice, nil
   298  }
   299  
   300  // isUserAnyType returns true if the user's type is one of types list
   301  func isUserAnyType(user *userpb.User, types []userpb.UserType) bool {
   302  	for _, t := range types {
   303  		if user.GetId().Type == t {
   304  			return true
   305  		}
   306  	}
   307  	return false
   308  }
   309  
   310  func (m *manager) GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]string, error) {
   311  	groups, err := m.fetchCachedUserGroups(uid)
   312  	if err == nil {
   313  		return groups, nil
   314  	}
   315  
   316  	url := fmt.Sprintf("%s/api/v1.0/Identity/%s/groups?recursive=true", m.conf.APIBaseURL, uid.OpaqueId)
   317  	result, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	groupData := result["data"].([]interface{})
   323  	groups = []string{}
   324  
   325  	for _, g := range groupData {
   326  		groupInfo, ok := g.(map[string]interface{})
   327  		if !ok {
   328  			return nil, errors.New("rest: error in type assertion")
   329  		}
   330  		name, ok := groupInfo["displayName"].(string)
   331  		if ok {
   332  			groups = append(groups, name)
   333  		}
   334  	}
   335  
   336  	if err = m.cacheUserGroups(uid, groups); err != nil {
   337  		log := appctx.GetLogger(ctx)
   338  		log.Error().Err(err).Msg("rest: error caching user groups")
   339  	}
   340  
   341  	return groups, nil
   342  }
   343  
   344  func (m *manager) IsInGroup(ctx context.Context, uid *userpb.UserId, group string) (bool, error) {
   345  	userGroups, err := m.GetUserGroups(ctx, uid)
   346  	if err != nil {
   347  		return false, err
   348  	}
   349  
   350  	for _, g := range userGroups {
   351  		if group == g {
   352  			return true, nil
   353  		}
   354  	}
   355  	return false, nil
   356  }
   357  
   358  func getUserType(userType, upn string) userpb.UserType {
   359  	var t userpb.UserType
   360  	switch userType {
   361  	case "Application":
   362  		t = userpb.UserType_USER_TYPE_APPLICATION
   363  	case "Service":
   364  		t = userpb.UserType_USER_TYPE_SERVICE
   365  	case "Secondary":
   366  		t = userpb.UserType_USER_TYPE_SECONDARY
   367  	case "Person":
   368  		switch {
   369  		case strings.HasPrefix(upn, "guest"):
   370  			t = userpb.UserType_USER_TYPE_LIGHTWEIGHT
   371  		case strings.Contains(upn, "@"):
   372  			t = userpb.UserType_USER_TYPE_FEDERATED
   373  		default:
   374  			t = userpb.UserType_USER_TYPE_PRIMARY
   375  		}
   376  	default:
   377  		t = userpb.UserType_USER_TYPE_INVALID
   378  	}
   379  	return t
   380  
   381  }