go.uber.org/yarpc@v1.72.1/api/x/restriction/restricter.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  // Package restriction is an experimental package for preventing unwanted
    22  // transport-encoding pairs.
    23  //
    24  // This package is under `x/` and subject to change. See README for details on
    25  // 'x' packages.
    26  package restriction
    27  
    28  import (
    29  	"errors"
    30  	"fmt"
    31  	"strings"
    32  
    33  	"go.uber.org/yarpc/api/transport"
    34  )
    35  
// Checker is used by encoding clients, for example Protobuf and Thrift, to
// prevent unwanted transport-encoding combinations.
//
// Check receives the encoding and the name of the transport it is about to be
// used with. A nil result means the combination is allowed. The implementation
// returned by NewChecker reports a rejected combination with an error whose
// message lists the whitelisted combinations.
type Checker interface {
	Check(encoding transport.Encoding, transportName string) error
}
    43  
// Tuple defines a transport/encoding combination to whitelist.
//
// Transport is the transport name and Encoding is the encoding permitted on
// that transport. Both fields must be non-empty for the tuple to pass
// Validate. Tuple values are used directly as map keys by the checker in this
// package, so two tuples match only when both fields are equal.
type Tuple struct {
	Transport string
	Encoding  transport.Encoding
}
    49  
    50  // Validate verifes that a tuple has all fields set.
    51  func (t Tuple) Validate() error {
    52  	if t.Transport == "" || t.Encoding == "" {
    53  		return errors.New("tuple missing must have all fields set")
    54  	}
    55  	return nil
    56  }
    57  
    58  // String implements fmt.Stringer.
    59  func (t Tuple) String() string {
    60  	return fmt.Sprintf("%s/%s", t.Transport, t.Encoding)
    61  }
    62  
// checker is the Checker implementation returned by NewChecker.
//
// tuples is the set of whitelisted combinations consulted by Check.
// availableMsg is the comma-joined rendering of the tuples passed to
// NewChecker; it is built once at construction time and embedded in the error
// returned by Check for a rejected combination.
type checker struct {
	availableMsg string
	tuples       map[Tuple]struct{}
}
    67  
    68  // NewChecker creates a Checker with a whitelist tuple combinations.
    69  func NewChecker(tuples ...Tuple) (Checker, error) {
    70  	if len(tuples) == 0 {
    71  		return nil, errors.New("NewChecker requires at least one whitelisted tuple")
    72  	}
    73  
    74  	m := make(map[Tuple]struct{}, len(tuples))
    75  	for _, t := range tuples {
    76  		if err := t.Validate(); err != nil {
    77  			return nil, err
    78  		}
    79  		m[t] = struct{}{}
    80  	}
    81  
    82  	elements := make([]string, 0, len(tuples))
    83  	for _, t := range tuples {
    84  		elements = append(elements, t.String())
    85  	}
    86  
    87  	return &checker{
    88  		tuples:       m,
    89  		availableMsg: strings.Join(elements, ","),
    90  	}, nil
    91  }
    92  
    93  // Check returns nil for supported transport/encoding combinations and errors
    94  // for unsupported combinations. Errors indicate whitelisted combinations.
    95  //
    96  // Nil Checker will alwas return nil.
    97  func (r *checker) Check(encoding transport.Encoding, transportName string) error {
    98  	t := Tuple{
    99  		Transport: transportName,
   100  		Encoding:  encoding,
   101  	}
   102  
   103  	if _, ok := r.tuples[t]; ok {
   104  		return nil
   105  	}
   106  
   107  	return fmt.Errorf("%q is not a whitelisted combination, available: %q", t.String(), r.availableMsg)
   108  }