github.com/cs3org/reva/v2@v2.27.7/pkg/ocm/provider/authorizer/json/json.go (about)

     1  // Copyright 2018-2023 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package json
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"net"
    25  	"net/url"
    26  	"os"
    27  	"regexp"
    28  	"strings"
    29  	"sync"
    30  
    31  	ocmprovider "github.com/cs3org/go-cs3apis/cs3/ocm/provider/v1beta1"
    32  	"github.com/pkg/errors"
    33  
    34  	"github.com/cs3org/reva/v2/pkg/appctx"
    35  	"github.com/cs3org/reva/v2/pkg/errtypes"
    36  	"github.com/cs3org/reva/v2/pkg/ocm/provider"
    37  	"github.com/cs3org/reva/v2/pkg/ocm/provider/authorizer/registry"
    38  	"github.com/cs3org/reva/v2/pkg/utils/cfg"
    39  )
    40  
    41  func init() {
    42  	registry.Register("json", New)
    43  }
    44  
    45  var (
    46  	ErrNoIP = errtypes.NotSupported("No IP provided")
    47  )
    48  
    49  // New returns a new authorizer object.
    50  func New(m map[string]interface{}) (provider.Authorizer, error) {
    51  	var c config
    52  	if err := cfg.Decode(m, &c); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	providers := []*ocmprovider.ProviderInfo{}
    57  	f, err := os.ReadFile(c.Providers)
    58  	if err != nil {
    59  		if !os.IsNotExist(err) {
    60  			return nil, err
    61  		}
    62  	} else {
    63  		err = json.Unmarshal(f, &providers)
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  	}
    68  
    69  	a := &authorizer{
    70  		providerIPs: sync.Map{},
    71  		conf:        &c,
    72  	}
    73  	a.providers = a.getOCMProviders(providers)
    74  
    75  	return a, nil
    76  }
    77  
    78  type config struct {
    79  	Providers             string `mapstructure:"providers"`
    80  	VerifyRequestHostname bool   `mapstructure:"verify_request_hostname"`
    81  }
    82  
    83  func (c *config) ApplyTemplates() {
    84  	if c.Providers == "" {
    85  		c.Providers = "/etc/revad/ocm-providers.json"
    86  	}
    87  }
    88  
    89  type authorizer struct {
    90  	providers   []*ocmprovider.ProviderInfo
    91  	providerIPs sync.Map
    92  	conf        *config
    93  }
    94  
    95  func normalizeDomain(d string) (string, error) {
    96  	var urlString string
    97  	if strings.Contains(d, "://") {
    98  		urlString = d
    99  	} else {
   100  		urlString = "https://" + d
   101  	}
   102  
   103  	u, err := url.Parse(urlString)
   104  	if err != nil {
   105  		return "", err
   106  	}
   107  
   108  	return u.Host, nil
   109  }
   110  
   111  func (a *authorizer) GetInfoByDomain(_ context.Context, domain string) (*ocmprovider.ProviderInfo, error) {
   112  	normalizedDomain, err := normalizeDomain(domain)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  	for _, p := range a.providers {
   117  		// we can exit early if this an exact match
   118  		if strings.Contains(p.Domain, normalizedDomain) {
   119  			return p, nil
   120  		}
   121  
   122  		// check if the domain matches a regex
   123  		if ok, err := regexp.MatchString(p.Domain, normalizedDomain); ok && err == nil {
   124  			// overwrite wildcards with the actual domain
   125  			var services []*ocmprovider.Service
   126  			for _, s := range p.Services {
   127  				services = append(services, &ocmprovider.Service{
   128  					Host: strings.ReplaceAll(s.Host, p.Domain, normalizedDomain),
   129  					Endpoint: &ocmprovider.ServiceEndpoint{
   130  						Type:        s.Endpoint.Type,
   131  						Name:        s.Endpoint.Name,
   132  						Path:        strings.ReplaceAll(s.Endpoint.Path, p.Domain, normalizedDomain),
   133  						IsMonitored: s.Endpoint.IsMonitored,
   134  						Properties:  s.Endpoint.Properties,
   135  					},
   136  					ApiVersion:          s.ApiVersion,
   137  					AdditionalEndpoints: s.AdditionalEndpoints,
   138  				})
   139  			}
   140  			return &ocmprovider.ProviderInfo{
   141  				Name:         p.Name,
   142  				FullName:     p.FullName,
   143  				Description:  p.Description,
   144  				Organization: p.Organization,
   145  				Domain:       normalizedDomain,
   146  				Homepage:     p.Homepage,
   147  				Email:        p.Email,
   148  				Services:     services,
   149  				Properties:   p.Properties,
   150  			}, nil
   151  		}
   152  	}
   153  	return nil, errtypes.NotFound(domain)
   154  }
   155  
   156  func (a *authorizer) IsProviderAllowed(ctx context.Context, pi *ocmprovider.ProviderInfo) error {
   157  	log := appctx.GetLogger(ctx)
   158  	var err error
   159  	normalizedDomain, err := normalizeDomain(pi.Domain)
   160  	if err != nil {
   161  		return err
   162  	}
   163  	var providerAuthorized bool
   164  	if normalizedDomain != "" {
   165  		for _, p := range a.providers {
   166  			if ok, err := regexp.MatchString(p.Domain, normalizedDomain); ok && err == nil {
   167  				providerAuthorized = true
   168  				break
   169  			}
   170  		}
   171  	} else {
   172  		providerAuthorized = true
   173  	}
   174  
   175  	switch {
   176  	case !providerAuthorized:
   177  		return errtypes.NotFound(pi.GetDomain())
   178  	case !a.conf.VerifyRequestHostname:
   179  		return nil
   180  	case len(pi.Services) == 0:
   181  		return ErrNoIP
   182  	}
   183  
   184  	var ocmHost string
   185  	for _, p := range a.providers {
   186  		log.Debug().Msgf("Comparing '%s' to '%s'", p.Domain, normalizedDomain)
   187  		if p.Domain == normalizedDomain {
   188  			ocmHost, err = a.getOCMHost(p)
   189  			if err != nil {
   190  				return err
   191  			}
   192  			break
   193  		}
   194  	}
   195  	if ocmHost == "" {
   196  		return errtypes.InternalError("json: ocm host not specified for mesh provider")
   197  	}
   198  
   199  	providerAuthorized = false
   200  	var ipList []string
   201  	if hostIPs, ok := a.providerIPs.Load(ocmHost); ok {
   202  		ipList = hostIPs.([]string)
   203  	} else {
   204  		host, _, err := net.SplitHostPort(ocmHost)
   205  		if err != nil {
   206  			return errors.Wrap(err, "json: error looking up client IP")
   207  		}
   208  		addr, err := net.LookupIP(host)
   209  		if err != nil {
   210  			return errors.Wrap(err, "json: error looking up client IP")
   211  		}
   212  		for _, a := range addr {
   213  			ipList = append(ipList, a.String())
   214  		}
   215  		a.providerIPs.Store(ocmHost, ipList)
   216  	}
   217  
   218  	for _, ip := range ipList {
   219  		if ip == pi.Services[0].Host {
   220  			providerAuthorized = true
   221  			break
   222  		}
   223  	}
   224  	if !providerAuthorized {
   225  		return errtypes.NotFound("OCM Host")
   226  	}
   227  
   228  	return nil
   229  }
   230  
   231  func (a *authorizer) ListAllProviders(ctx context.Context) ([]*ocmprovider.ProviderInfo, error) {
   232  	return a.providers, nil
   233  }
   234  
   235  func (a *authorizer) getOCMProviders(providers []*ocmprovider.ProviderInfo) (po []*ocmprovider.ProviderInfo) {
   236  	for _, p := range providers {
   237  		_, err := a.getOCMHost(p)
   238  		if err == nil {
   239  			po = append(po, p)
   240  		}
   241  	}
   242  	return
   243  }
   244  
   245  func (a *authorizer) getOCMHost(pi *ocmprovider.ProviderInfo) (string, error) {
   246  	for _, s := range pi.Services {
   247  		if s.Endpoint.Type.Name == "OCM" {
   248  			return s.Host, nil
   249  		}
   250  	}
   251  	return "", errtypes.NotFound("OCM Host")
   252  }