github.com/cs3org/reva/v2@v2.27.7/pkg/ocm/share/repository/json/json.go (about)

     1  // Copyright 2018-2023 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package json
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"io"
    25  	"os"
    26  	"path/filepath"
    27  	"sync"
    28  	"time"
    29  
    30  	userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
    31  	ocm "github.com/cs3org/go-cs3apis/cs3/sharing/ocm/v1beta1"
    32  	provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
    33  	typespb "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
    34  	"github.com/google/uuid"
    35  	"github.com/pkg/errors"
    36  	"google.golang.org/genproto/protobuf/field_mask"
    37  
    38  	"github.com/cs3org/reva/v2/pkg/errtypes"
    39  	"github.com/cs3org/reva/v2/pkg/ocm/share"
    40  	"github.com/cs3org/reva/v2/pkg/ocm/share/repository/registry"
    41  	"github.com/cs3org/reva/v2/pkg/utils"
    42  	"github.com/cs3org/reva/v2/pkg/utils/cfg"
    43  )
    44  
    45  func init() {
    46  	registry.Register("json", New)
    47  }
    48  
    49  // New returns a new authorizer object.
    50  func New(m map[string]interface{}) (share.Repository, error) {
    51  	var c config
    52  	if err := cfg.Decode(m, &c); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	// load or create file
    57  	model, err := loadOrCreate(c.File)
    58  	if err != nil {
    59  		err = errors.Wrap(err, "error loading the file containing the shares")
    60  		return nil, err
    61  	}
    62  
    63  	mgr := &mgr{
    64  		c:     &c,
    65  		model: model,
    66  	}
    67  
    68  	return mgr, nil
    69  }
    70  
    71  func loadOrCreate(file string) (*shareModel, error) {
    72  	_, err := os.Stat(file)
    73  	if os.IsNotExist(err) {
    74  		if err := os.MkdirAll(filepath.Dir(file), 0700); err != nil {
    75  			return nil, errors.Wrap(err, "error creating the base directory: "+filepath.Dir(file))
    76  		}
    77  		if err := os.WriteFile(file, []byte("{}"), 0700); err != nil {
    78  			return nil, errors.Wrap(err, "error creating the file: "+file)
    79  		}
    80  	}
    81  
    82  	f, err := os.OpenFile(file, os.O_RDONLY, 0644)
    83  	if err != nil {
    84  		err = errors.Wrap(err, "error opening the file: "+file)
    85  		return nil, err
    86  	}
    87  	defer f.Close()
    88  
    89  	var m shareModel
    90  	if err := json.NewDecoder(f).Decode(&m); err != nil {
    91  		if err != io.EOF {
    92  			return nil, errors.Wrap(err, "error decoding data to json")
    93  		}
    94  	}
    95  
    96  	if m.Shares == nil {
    97  		m.Shares = map[string]*ocm.Share{}
    98  	}
    99  	if m.ReceivedShares == nil {
   100  		m.ReceivedShares = map[string]*ocm.ReceivedShare{}
   101  	}
   102  
   103  	return &m, nil
   104  }
   105  
   106  type shareModel struct {
   107  	Shares         map[string]*ocm.Share         `json:"shares"`          // share_id -> share
   108  	ReceivedShares map[string]*ocm.ReceivedShare `json:"received_shares"` // share_id -> share
   109  }
   110  
   111  func (s *shareModel) UnmarshalJSON(d []byte) error {
   112  	m := struct {
   113  		Shares         map[string]json.RawMessage `json:"shares"`
   114  		ReceivedShares map[string]json.RawMessage `json:"received_shares"`
   115  	}{}
   116  
   117  	if err := json.Unmarshal(d, &m); err != nil {
   118  		return err
   119  	}
   120  
   121  	share := map[string]*ocm.Share{}
   122  	for k, v := range m.Shares {
   123  		var s ocm.Share
   124  		if err := utils.UnmarshalJSONToProtoV1(v, &s); err != nil {
   125  			return err
   126  		}
   127  		share[k] = &s
   128  	}
   129  
   130  	received := map[string]*ocm.ReceivedShare{}
   131  	for k, v := range m.ReceivedShares {
   132  		var s ocm.ReceivedShare
   133  		if err := utils.UnmarshalJSONToProtoV1(v, &s); err != nil {
   134  			return err
   135  		}
   136  		received[k] = &s
   137  	}
   138  
   139  	*s = shareModel{
   140  		Shares:         share,
   141  		ReceivedShares: received,
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func (s *shareModel) MarshalJSON() ([]byte, error) {
   148  	shares := map[string]json.RawMessage{}
   149  	for k, v := range s.Shares {
   150  		d, err := utils.MarshalProtoV1ToJSON(v)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  		shares[k] = d
   155  	}
   156  
   157  	received := map[string]json.RawMessage{}
   158  	for k, v := range s.ReceivedShares {
   159  		d, err := utils.MarshalProtoV1ToJSON(v)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		received[k] = d
   164  	}
   165  
   166  	return json.Marshal(map[string]any{
   167  		"shares":          shares,
   168  		"received_shares": received,
   169  	})
   170  }
   171  
   172  type config struct {
   173  	File string `mapstructure:"file"`
   174  }
   175  
   176  func (c *config) ApplyDefaults() {
   177  	if c.File == "" {
   178  		c.File = "/var/tmp/reva/ocm-shares.json"
   179  	}
   180  }
   181  
   182  type mgr struct {
   183  	c          *config
   184  	sync.Mutex // concurrent access to the file
   185  	model      *shareModel
   186  }
   187  
   188  func (m *mgr) save() error {
   189  	f, err := os.OpenFile(m.c.File, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   190  	if err != nil {
   191  		return errors.Wrap(err, "error opening file "+m.c.File)
   192  	}
   193  	defer f.Close()
   194  
   195  	if err := json.NewEncoder(f).Encode(m.model); err != nil {
   196  		return errors.Wrap(err, "error encoding to json")
   197  	}
   198  
   199  	return f.Sync()
   200  }
   201  
   202  func (m *mgr) load() error {
   203  	f, err := os.OpenFile(m.c.File, os.O_RDONLY, 0644)
   204  	if err != nil {
   205  		return errors.Wrap(err, "error opening file "+m.c.File)
   206  	}
   207  	defer f.Close()
   208  
   209  	d, err := io.ReadAll(f)
   210  	if err != nil {
   211  		return err
   212  	}
   213  
   214  	var model shareModel
   215  	if err := json.Unmarshal(d, &model); err != nil {
   216  		return err
   217  	}
   218  
   219  	m.model = &model
   220  	return nil
   221  }
   222  
   223  func genID() string {
   224  	return uuid.New().String()
   225  }
   226  
   227  func (m *mgr) StoreShare(ctx context.Context, ocmshare *ocm.Share) (*ocm.Share, error) {
   228  	m.Lock()
   229  	defer m.Unlock()
   230  
   231  	if err := m.load(); err != nil {
   232  		return nil, err
   233  	}
   234  
   235  	if _, err := m.getByKey(ctx, &ocm.ShareKey{
   236  		Owner:      ocmshare.Owner,
   237  		ResourceId: ocmshare.ResourceId,
   238  		Grantee:    ocmshare.Grantee,
   239  	}); err == nil {
   240  		return nil, share.ErrShareAlreadyExisting
   241  	}
   242  
   243  	ocmshare.Id = &ocm.ShareId{OpaqueId: genID()}
   244  	clone, err := cloneShare(ocmshare)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  	m.model.Shares[ocmshare.Id.OpaqueId] = clone
   249  
   250  	if err := m.save(); err != nil {
   251  		return nil, errors.Wrap(err, "error saving share")
   252  	}
   253  
   254  	return ocmshare, nil
   255  }
   256  
   257  func cloneShare(s *ocm.Share) (*ocm.Share, error) {
   258  	d, err := utils.MarshalProtoV1ToJSON(s)
   259  	if err != nil {
   260  		return nil, errtypes.InternalError("failed to marshal ocm share")
   261  	}
   262  	var cloned ocm.Share
   263  	if err := utils.UnmarshalJSONToProtoV1(d, &cloned); err != nil {
   264  		return nil, errtypes.InternalError("failed to unmarshal ocm share")
   265  	}
   266  	return &cloned, nil
   267  }
   268  
   269  func cloneReceivedShare(s *ocm.ReceivedShare) (*ocm.ReceivedShare, error) {
   270  	d, err := utils.MarshalProtoV1ToJSON(s)
   271  	if err != nil {
   272  		return nil, errtypes.InternalError("failed to marshal ocm received share")
   273  	}
   274  	var cloned ocm.ReceivedShare
   275  	if err := utils.UnmarshalJSONToProtoV1(d, &cloned); err != nil {
   276  		return nil, errtypes.InternalError("failed to unmarshal ocm received share")
   277  	}
   278  	return &cloned, nil
   279  }
   280  
   281  func (m *mgr) GetShare(ctx context.Context, user *userpb.User, ref *ocm.ShareReference) (*ocm.Share, error) {
   282  	m.Lock()
   283  	defer m.Unlock()
   284  
   285  	var (
   286  		s   *ocm.Share
   287  		err error
   288  	)
   289  
   290  	if err := m.load(); err != nil {
   291  		return nil, err
   292  	}
   293  
   294  	switch {
   295  	case ref.GetId() != nil:
   296  		s, err = m.getByID(ctx, ref.GetId())
   297  	case ref.GetKey() != nil:
   298  		s, err = m.getByKey(ctx, ref.GetKey())
   299  	case ref.GetToken() != "":
   300  		return m.getByToken(ctx, ref.GetToken())
   301  	default:
   302  		err = errtypes.NotFound(ref.String())
   303  	}
   304  
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	// check if we are the owner
   310  	if utils.UserEqual(user.Id, s.Owner) || utils.UserEqual(user.Id, s.Creator) {
   311  		return s, nil
   312  	}
   313  
   314  	return nil, share.ErrShareNotFound
   315  }
   316  
   317  func (m *mgr) getByToken(ctx context.Context, token string) (*ocm.Share, error) {
   318  	for _, share := range m.model.Shares {
   319  		if share.Token == token {
   320  			return share, nil
   321  		}
   322  	}
   323  	return nil, errtypes.NotFound(token)
   324  }
   325  
   326  func (m *mgr) getByID(ctx context.Context, id *ocm.ShareId) (*ocm.Share, error) {
   327  	if share, ok := m.model.Shares[id.OpaqueId]; ok {
   328  		return share, nil
   329  	}
   330  	return nil, errtypes.NotFound(id.String())
   331  }
   332  
   333  func (m *mgr) getByKey(ctx context.Context, key *ocm.ShareKey) (*ocm.Share, error) {
   334  	for _, share := range m.model.Shares {
   335  		if (utils.UserEqual(key.Owner, share.Owner) || utils.UserEqual(key.Owner, share.Creator)) &&
   336  			utils.ResourceIDEqual(key.ResourceId, share.ResourceId) && utils.GranteeEqual(key.Grantee, share.Grantee) {
   337  			return share, nil
   338  		}
   339  	}
   340  	return nil, share.ErrShareNotFound
   341  }
   342  
   343  func (m *mgr) DeleteShare(ctx context.Context, user *userpb.User, ref *ocm.ShareReference) error {
   344  	m.Lock()
   345  	defer m.Unlock()
   346  
   347  	if err := m.load(); err != nil {
   348  		return err
   349  	}
   350  
   351  	for id, share := range m.model.Shares {
   352  		if sharesEqual(ref, share) {
   353  			if utils.UserEqual(user.Id, share.Owner) || utils.UserEqual(user.Id, share.Creator) {
   354  				delete(m.model.Shares, id)
   355  				return m.save()
   356  			}
   357  		}
   358  	}
   359  	return errtypes.NotFound(ref.String())
   360  }
   361  
   362  func sharesEqual(ref *ocm.ShareReference, s *ocm.Share) bool {
   363  	if ref.GetId() != nil && s.Id != nil {
   364  		if ref.GetId().OpaqueId == s.Id.OpaqueId {
   365  			return true
   366  		}
   367  	} else if ref.GetKey() != nil {
   368  		if (utils.UserEqual(ref.GetKey().Owner, s.Owner) || utils.UserEqual(ref.GetKey().Owner, s.Creator)) &&
   369  			utils.ResourceIDEqual(ref.GetKey().ResourceId, s.ResourceId) && utils.GranteeEqual(ref.GetKey().Grantee, s.Grantee) {
   370  			return true
   371  		}
   372  	}
   373  	return false
   374  }
   375  
   376  func receivedShareEqual(ref *ocm.ShareReference, s *ocm.ReceivedShare) bool {
   377  	if ref.GetId() != nil && s.Id != nil {
   378  		if ref.GetId().OpaqueId == s.Id.OpaqueId {
   379  			return true
   380  		}
   381  	}
   382  	// Match the reserved share by the remote share id
   383  	if ref.GetId() != nil && s.RemoteShareId != "" {
   384  		if ref.GetId().GetOpaqueId() == s.RemoteShareId {
   385  			return true
   386  		}
   387  	}
   388  	return false
   389  }
   390  
   391  // UpdateShare updates the share with the given fields.
   392  func (m *mgr) UpdateShare(ctx context.Context, user *userpb.User, ref *ocm.ShareReference, fields ...*ocm.UpdateOCMShareRequest_UpdateField) (*ocm.Share, error) {
   393  	m.Lock()
   394  	defer m.Unlock()
   395  	if err := m.load(); err != nil {
   396  		return nil, err
   397  	}
   398  	for _, s := range m.model.Shares {
   399  		if sharesEqual(ref, s) {
   400  			if utils.UserEqual(user.Id, s.Owner) || utils.UserEqual(user.Id, s.Creator) {
   401  
   402  				for _, f := range fields {
   403  					if exp := f.GetExpiration(); exp != nil {
   404  						s.Expiration = exp
   405  					}
   406  					if am := f.GetAccessMethods(); am != nil {
   407  						var (
   408  							webdavOptions   *ocm.WebDAVAccessMethod
   409  							webappOptions   *ocm.WebappAccessMethod
   410  							transferOptions *ocm.TransferAccessMethod
   411  							// TODO: *AccessMethod_GenericOptions
   412  
   413  							newWebdavOptions   *ocm.WebDAVAccessMethod
   414  							newWebappOptions   *ocm.WebappAccessMethod
   415  							newTransferOptions *ocm.TransferAccessMethod
   416  							// TODO: *AccessMethod_GenericOptions
   417  						)
   418  
   419  						for _, sm := range s.GetAccessMethods() {
   420  							webdavOptions = sm.GetWebdavOptions()
   421  							webappOptions = sm.GetWebappOptions()
   422  							transferOptions = sm.GetTransferOptions()
   423  						}
   424  
   425  						newWebdavOptions = am.GetWebdavOptions()
   426  						newWebappOptions = am.GetWebappOptions()
   427  						newTransferOptions = am.GetTransferOptions()
   428  
   429  						newAccesMethods := []*ocm.AccessMethod{}
   430  
   431  						if newWebdavOptions != nil {
   432  							newAccesMethods = append(newAccesMethods, &ocm.AccessMethod{
   433  								Term: &ocm.AccessMethod_WebdavOptions{
   434  									WebdavOptions: newWebdavOptions,
   435  								},
   436  							})
   437  						} else if webdavOptions != nil {
   438  							newAccesMethods = append(newAccesMethods, &ocm.AccessMethod{
   439  								Term: &ocm.AccessMethod_WebdavOptions{
   440  									WebdavOptions: webdavOptions,
   441  								},
   442  							})
   443  						}
   444  
   445  						if newWebappOptions != nil {
   446  							newAccesMethods = append(newAccesMethods, &ocm.AccessMethod{
   447  								Term: &ocm.AccessMethod_WebappOptions{
   448  									WebappOptions: newWebappOptions,
   449  								},
   450  							})
   451  						} else if webappOptions != nil {
   452  							newAccesMethods = append(newAccesMethods, &ocm.AccessMethod{
   453  								Term: &ocm.AccessMethod_WebappOptions{
   454  									WebappOptions: webappOptions,
   455  								},
   456  							})
   457  						}
   458  
   459  						if newTransferOptions != nil {
   460  							newAccesMethods = append(newAccesMethods, &ocm.AccessMethod{
   461  								Term: &ocm.AccessMethod_TransferOptions{
   462  									TransferOptions: newTransferOptions,
   463  								},
   464  							})
   465  						} else if transferOptions != nil {
   466  							newAccesMethods = append(newAccesMethods, &ocm.AccessMethod{
   467  								Term: &ocm.AccessMethod_TransferOptions{
   468  									TransferOptions: transferOptions,
   469  								},
   470  							})
   471  						}
   472  						s.AccessMethods = newAccesMethods
   473  					}
   474  				}
   475  
   476  				clone, err := cloneShare(s)
   477  				if err != nil {
   478  					return nil, err
   479  				}
   480  				m.model.Shares[s.Id.OpaqueId] = clone
   481  
   482  				if err := m.save(); err != nil {
   483  					return nil, errors.Wrap(err, "error saving share")
   484  				}
   485  
   486  				return clone, nil
   487  			}
   488  		}
   489  	}
   490  
   491  	return nil, errtypes.NotFound(ref.String())
   492  }
   493  
   494  func (m *mgr) ListShares(ctx context.Context, user *userpb.User, filters []*ocm.ListOCMSharesRequest_Filter) ([]*ocm.Share, error) {
   495  	var ss []*ocm.Share
   496  
   497  	m.Lock()
   498  	defer m.Unlock()
   499  
   500  	if err := m.load(); err != nil {
   501  		return nil, err
   502  	}
   503  
   504  	for _, share := range m.model.Shares {
   505  		if utils.UserEqual(user.Id, share.Owner) || utils.UserEqual(user.Id, share.Creator) || utils.UserEqual(user.Id, share.GetGrantee().GetUserId()) {
   506  			// no filter we return earlier
   507  			if len(filters) == 0 {
   508  				ss = append(ss, share)
   509  			} else {
   510  				// check filters
   511  				// TODO(labkode): add the rest of filters.
   512  				for _, f := range filters {
   513  					if f.Type == ocm.ListOCMSharesRequest_Filter_TYPE_RESOURCE_ID {
   514  						if utils.ResourceIDEqual(share.ResourceId, f.GetResourceId()) {
   515  							ss = append(ss, share)
   516  						}
   517  					}
   518  				}
   519  			}
   520  		}
   521  	}
   522  	return ss, nil
   523  }
   524  
   525  func (m *mgr) StoreReceivedShare(ctx context.Context, share *ocm.ReceivedShare) (*ocm.ReceivedShare, error) {
   526  	m.Lock()
   527  	defer m.Unlock()
   528  
   529  	if err := m.load(); err != nil {
   530  		return nil, err
   531  	}
   532  
   533  	now := time.Now().UnixNano()
   534  	ts := &typespb.Timestamp{
   535  		Seconds: uint64(now / 1000000000),
   536  		Nanos:   uint32(now % 1000000000),
   537  	}
   538  
   539  	share.Id = &ocm.ShareId{
   540  		OpaqueId: genID(),
   541  	}
   542  	share.Ctime = ts
   543  	share.Mtime = ts
   544  
   545  	clone, err := cloneReceivedShare(share)
   546  	if err != nil {
   547  		return nil, err
   548  	}
   549  
   550  	m.model.ReceivedShares[share.Id.OpaqueId] = clone
   551  	if err := m.save(); err != nil {
   552  		return nil, err
   553  	}
   554  
   555  	return share, nil
   556  }
   557  
   558  func (m *mgr) ListReceivedShares(ctx context.Context, user *userpb.User) ([]*ocm.ReceivedShare, error) {
   559  	var rss []*ocm.ReceivedShare
   560  	m.Lock()
   561  	defer m.Unlock()
   562  
   563  	if err := m.load(); err != nil {
   564  		return nil, err
   565  	}
   566  
   567  	for _, share := range m.model.ReceivedShares {
   568  		if utils.UserEqual(user.Id, share.Owner) || utils.UserEqual(user.Id, share.Creator) {
   569  			// omit shares created by me
   570  			continue
   571  		}
   572  
   573  		if share.Grantee.Type == provider.GranteeType_GRANTEE_TYPE_USER && utils.UserEqual(user.Id, share.Grantee.GetUserId()) {
   574  			rss = append(rss, share)
   575  		}
   576  	}
   577  	return rss, nil
   578  }
   579  
   580  func (m *mgr) GetReceivedShare(ctx context.Context, user *userpb.User, ref *ocm.ShareReference) (*ocm.ReceivedShare, error) {
   581  	m.Lock()
   582  	defer m.Unlock()
   583  
   584  	if err := m.load(); err != nil {
   585  		return nil, err
   586  	}
   587  
   588  	for _, share := range m.model.ReceivedShares {
   589  		if receivedShareEqual(ref, share) {
   590  			if share.Grantee.Type == provider.GranteeType_GRANTEE_TYPE_USER && utils.UserEqual(user.Id, share.Grantee.GetUserId()) {
   591  				return share, nil
   592  			}
   593  		}
   594  	}
   595  	return nil, errtypes.NotFound(ref.String())
   596  }
   597  
   598  func (m *mgr) DeleteReceivedShare(ctx context.Context, user *userpb.User, ref *ocm.ShareReference) error {
   599  	m.Lock()
   600  	defer m.Unlock()
   601  
   602  	if err := m.load(); err != nil {
   603  		return err
   604  	}
   605  
   606  	for id, share := range m.model.ReceivedShares {
   607  		if receivedShareEqual(ref, share) && utils.UserEqual(user.Id, share.GetGrantee().GetUserId()) {
   608  			delete(m.model.ReceivedShares, id)
   609  			return m.save()
   610  		}
   611  	}
   612  	return errtypes.NotFound(ref.String())
   613  }
   614  
   615  func (m *mgr) UpdateReceivedShare(ctx context.Context, user *userpb.User, share *ocm.ReceivedShare, fieldMask *field_mask.FieldMask) (*ocm.ReceivedShare, error) {
   616  	rs, err := m.GetReceivedShare(ctx, user, &ocm.ShareReference{Spec: &ocm.ShareReference_Id{Id: share.Id}})
   617  	if err != nil {
   618  		return nil, err
   619  	}
   620  
   621  	m.Lock()
   622  	defer m.Unlock()
   623  
   624  	if err := m.load(); err != nil {
   625  		return nil, err
   626  	}
   627  
   628  	for _, mask := range fieldMask.Paths {
   629  		switch mask {
   630  		case "state":
   631  			rs.State = share.State
   632  			m.model.ReceivedShares[rs.Id.OpaqueId].State = share.State
   633  		case "protocols":
   634  			m.model.ReceivedShares[rs.Id.OpaqueId].Protocols = share.Protocols
   635  		// TODO case "mount_point":
   636  		default:
   637  			return nil, errtypes.NotSupported("updating " + mask + " is not supported")
   638  		}
   639  	}
   640  	m.model.ReceivedShares[rs.Id.OpaqueId].Mtime = &typespb.Timestamp{Seconds: uint64(time.Now().Second())}
   641  
   642  	if err := m.save(); err != nil {
   643  		return nil, errors.Wrap(err, "error saving model")
   644  	}
   645  
   646  	return rs, nil
   647  }