github.com/cs3org/reva/v2@v2.27.7/pkg/ocm/provider/authorizer/mentix/mentix.go (about)

     1  // Copyright 2018-2023 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package mentix
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"net/url"
    28  	"strings"
    29  	"sync"
    30  	"time"
    31  
    32  	ocmprovider "github.com/cs3org/go-cs3apis/cs3/ocm/provider/v1beta1"
    33  	"github.com/cs3org/reva/v2/pkg/errtypes"
    34  	"github.com/cs3org/reva/v2/pkg/ocm/provider"
    35  	"github.com/cs3org/reva/v2/pkg/ocm/provider/authorizer/registry"
    36  	"github.com/cs3org/reva/v2/pkg/rhttp"
    37  	"github.com/cs3org/reva/v2/pkg/utils/cfg"
    38  	"github.com/pkg/errors"
    39  )
    40  
    41  func init() {
    42  	registry.Register("mentix", New)
    43  }
    44  
    45  // Client is a Mentix API client.
    46  type Client struct {
    47  	BaseURL    string
    48  	HTTPClient *http.Client
    49  }
    50  
    51  // New returns a new authorizer object.
    52  func New(m map[string]interface{}) (provider.Authorizer, error) {
    53  	var c config
    54  	if err := cfg.Decode(m, &c); err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	client := &Client{
    59  		BaseURL: c.URL,
    60  		HTTPClient: rhttp.GetHTTPClient(
    61  			rhttp.Context(context.Background()),
    62  			rhttp.Timeout(time.Duration(c.Timeout*int64(time.Second))),
    63  			rhttp.Insecure(c.Insecure),
    64  		),
    65  	}
    66  
    67  	return &authorizer{
    68  		client:      client,
    69  		providerIPs: sync.Map{},
    70  		conf:        &c,
    71  	}, nil
    72  }
    73  
    74  type config struct {
    75  	URL                   string `mapstructure:"url"`
    76  	Timeout               int64  `mapstructure:"timeout"`
    77  	RefreshInterval       int64  `mapstructure:"refresh"`
    78  	VerifyRequestHostname bool   `mapstructure:"verify_request_hostname"`
    79  	Insecure              bool   `mapstructure:"insecure" docs:"false;Whether to skip certificate checks when sending requests."`
    80  }
    81  
    82  func (c *config) ApplyDefaults() {
    83  	if c.URL == "" {
    84  		c.URL = "http://localhost:9600/mentix/cs3"
    85  	}
    86  }
    87  
    88  type authorizer struct {
    89  	providers           []*ocmprovider.ProviderInfo
    90  	providersExpiration int64
    91  	client              *Client
    92  	providerIPs         sync.Map
    93  	conf                *config
    94  }
    95  
    96  func normalizeDomain(d string) (string, error) {
    97  	var urlString string
    98  	if strings.Contains(d, "://") {
    99  		urlString = d
   100  	} else {
   101  		urlString = "https://" + d
   102  	}
   103  
   104  	u, err := url.Parse(urlString)
   105  	if err != nil {
   106  		return "", err
   107  	}
   108  
   109  	return u.Hostname(), nil
   110  }
   111  
   112  func (a *authorizer) fetchProviders() ([]*ocmprovider.ProviderInfo, error) {
   113  	if (a.providers != nil) && (time.Now().Unix() < a.providersExpiration) {
   114  		return a.providers, nil
   115  	}
   116  
   117  	req, err := http.NewRequest(http.MethodGet, a.client.BaseURL, nil)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	req.Header.Set("Accept", "application/json; charset=utf-8")
   122  	req.Header.Set("Content-Type", "application/json; charset=utf-8")
   123  
   124  	res, err := a.client.HTTPClient.Do(req)
   125  	if err != nil {
   126  		err = errors.Wrap(err,
   127  			fmt.Sprintf("mentix: error fetching provider list from: %s", a.client.BaseURL))
   128  		return nil, err
   129  	}
   130  
   131  	defer res.Body.Close()
   132  
   133  	providers := make([]*ocmprovider.ProviderInfo, 0)
   134  	if err = json.NewDecoder(res.Body).Decode(&providers); err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	a.providers = a.getOCMProviders(providers)
   139  	if a.conf.RefreshInterval > 0 {
   140  		a.providersExpiration = time.Now().Unix() + a.conf.RefreshInterval
   141  	}
   142  	return a.providers, nil
   143  }
   144  
   145  func (a *authorizer) GetInfoByDomain(ctx context.Context, domain string) (*ocmprovider.ProviderInfo, error) {
   146  	normalizedDomain, err := normalizeDomain(domain)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	providers, err := a.fetchProviders()
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	for _, p := range providers {
   156  		if strings.Contains(p.Domain, normalizedDomain) {
   157  			return p, nil
   158  		}
   159  	}
   160  	return nil, errtypes.NotFound(domain)
   161  }
   162  
   163  func (a *authorizer) IsProviderAllowed(ctx context.Context, pi *ocmprovider.ProviderInfo) error {
   164  	providers, err := a.fetchProviders()
   165  	if err != nil {
   166  		return err
   167  	}
   168  	normalizedDomain, err := normalizeDomain(pi.Domain)
   169  	if err != nil {
   170  		return err
   171  	}
   172  
   173  	var providerAuthorized bool
   174  	if normalizedDomain != "" {
   175  		for _, p := range providers {
   176  			if p.Domain == normalizedDomain {
   177  				providerAuthorized = true
   178  				break
   179  			}
   180  		}
   181  	} else {
   182  		providerAuthorized = true
   183  	}
   184  
   185  	switch {
   186  	case !providerAuthorized:
   187  		return errtypes.NotFound(pi.GetDomain())
   188  	case !a.conf.VerifyRequestHostname:
   189  		return nil
   190  	case len(pi.Services) == 0:
   191  		return errtypes.NotSupported(
   192  			fmt.Sprintf("mentix: provider %s has no supported services", pi.GetDomain()))
   193  	}
   194  
   195  	var ocmHost string
   196  	for _, p := range providers {
   197  		if p.Domain == normalizedDomain {
   198  			ocmHost, err = a.getOCMHost(p)
   199  			if err != nil {
   200  				return err
   201  			}
   202  			break
   203  		}
   204  	}
   205  	if ocmHost == "" {
   206  		return errtypes.NotSupported(
   207  			fmt.Sprintf("mentix: provider %s is missing OCM endpoint", pi.GetDomain()))
   208  	}
   209  
   210  	providerAuthorized = false
   211  	var ipList []string
   212  	if hostIPs, ok := a.providerIPs.Load(ocmHost); ok {
   213  		ipList = hostIPs.([]string)
   214  	} else {
   215  		addr, err := net.LookupIP(ocmHost)
   216  		if err != nil {
   217  			return errors.Wrap(err,
   218  				fmt.Sprintf("mentix: error looking up IPs for OCM endpoint %s", ocmHost))
   219  		}
   220  		for _, a := range addr {
   221  			ipList = append(ipList, a.String())
   222  		}
   223  		a.providerIPs.Store(ocmHost, ipList)
   224  	}
   225  
   226  	for _, ip := range ipList {
   227  		if ip == pi.Services[0].Host {
   228  			providerAuthorized = true
   229  			break
   230  		}
   231  	}
   232  	if !providerAuthorized {
   233  		return errtypes.BadRequest(
   234  			fmt.Sprintf(
   235  				"Invalid requesting OCM endpoint IP %s of provider %s",
   236  				pi.Services[0].Host, pi.GetDomain()))
   237  	}
   238  
   239  	return nil
   240  }
   241  
   242  func (a *authorizer) ListAllProviders(ctx context.Context) ([]*ocmprovider.ProviderInfo, error) {
   243  	providers, err := a.fetchProviders()
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	return providers, nil
   248  }
   249  
   250  func (a *authorizer) getOCMProviders(providers []*ocmprovider.ProviderInfo) (po []*ocmprovider.ProviderInfo) {
   251  	for _, p := range providers {
   252  		_, err := a.getOCMHost(p)
   253  		if err == nil {
   254  			po = append(po, p)
   255  		}
   256  	}
   257  	return
   258  }
   259  
   260  func (a *authorizer) getOCMHost(provider *ocmprovider.ProviderInfo) (string, error) {
   261  	for _, s := range provider.Services {
   262  		if s.Endpoint.Type.Name == "OCM" {
   263  			return s.Host, nil
   264  		}
   265  	}
   266  	return "", errtypes.NotFound("OCM Host")
   267  }