github.com/greenpau/go-authcrunch@v1.1.4/pkg/authn/transformer/transformer.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transformer
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"github.com/greenpau/go-authcrunch/pkg/acl"
    21  	cfgutil "github.com/greenpau/go-authcrunch/pkg/util/cfg"
    22  	"strings"
    23  )
    24  
    25  // Config represents a common set of configuration settings
    26  // applicable to the cookies issued by authn.Authenticator.
    27  type Config struct {
    28  	Matchers []string `json:"matchers,omitempty" xml:"matchers,omitempty" yaml:"matchers,omitempty"`
    29  	Actions  []string `json:"actions,omitempty" xml:"actions,omitempty" yaml:"actions,omitempty"`
    30  }
    31  
    32  type transform struct {
    33  	matcher *acl.AccessList
    34  	actions [][]string
    35  }
    36  
    37  // Factory holds configuration and associated finctions
    38  // for the cookies issued by authn.Authenticator.
    39  type Factory struct {
    40  	configs    []*Config
    41  	transforms []*transform
    42  }
    43  
    44  // NewFactory returns an instance of cookie factory.
    45  func NewFactory(cfgs []*Config) (*Factory, error) {
    46  	f := &Factory{}
    47  	if len(cfgs) == 0 {
    48  		return nil, fmt.Errorf("transformer has no config")
    49  	}
    50  	f.configs = cfgs
    51  
    52  	for _, cfg := range cfgs {
    53  		if len(cfg.Matchers) < 1 {
    54  			return nil, fmt.Errorf("transformer has no matchers: %v", cfg)
    55  		}
    56  		if len(cfg.Actions) < 1 {
    57  			return nil, fmt.Errorf("transformer has no actions: %v", cfg)
    58  		}
    59  
    60  		var actions [][]string
    61  		for _, encodedArgs := range cfg.Actions {
    62  			args, err := cfgutil.DecodeArgs(encodedArgs)
    63  			if err != nil {
    64  				return nil, fmt.Errorf("transformer for %q erred during arg decoding: %v", encodedArgs, err)
    65  			}
    66  			switch args[0] {
    67  			case "require":
    68  				actions = append(actions, args)
    69  			case "block", "deny":
    70  				actions = append(actions, args)
    71  			case "ui":
    72  				if len(args) < 4 {
    73  					return nil, fmt.Errorf("transformer for %q erred: ui config too short", encodedArgs)
    74  				}
    75  				switch args[1] {
    76  				case "link":
    77  					actions = append(actions, args[1:])
    78  				default:
    79  					return nil, fmt.Errorf("transformer for %q erred: invalid ui config", encodedArgs)
    80  				}
    81  			case "add", "overwrite", "drop":
    82  				if len(args) < 3 {
    83  					return nil, fmt.Errorf("transformer for %q erred: invalid add/overwrite config", encodedArgs)
    84  				}
    85  				actions = append(actions, args)
    86  			case "delete":
    87  				if len(args) < 2 {
    88  					return nil, fmt.Errorf("transformer for %q erred: invalid delete config", encodedArgs)
    89  				}
    90  				actions = append(actions, args)
    91  			case "action":
    92  				if len(args) < 3 {
    93  					return nil, fmt.Errorf("transformer for %q erred: action config too short", encodedArgs)
    94  				}
    95  				switch args[1] {
    96  				case "add", "overwrite", "delete", "drop":
    97  				default:
    98  					return nil, fmt.Errorf("transformer for %q erred: invalid action config", encodedArgs)
    99  				}
   100  				actions = append(actions, args[1:])
   101  			default:
   102  				return nil, fmt.Errorf("transformer has unsupported action: %v", args)
   103  			}
   104  		}
   105  		matcher := acl.NewAccessList()
   106  		matchRuleConfigs := []*acl.RuleConfiguration{
   107  			{
   108  				Conditions: cfg.Matchers,
   109  				Action:     "allow",
   110  			},
   111  		}
   112  		if err := matcher.AddRules(context.Background(), matchRuleConfigs); err != nil {
   113  			return nil, err
   114  		}
   115  		tr := &transform{
   116  			matcher: matcher,
   117  			actions: actions,
   118  		}
   119  		f.transforms = append(f.transforms, tr)
   120  	}
   121  	return f, nil
   122  }
   123  
   124  // Transform performs user data transformation.
   125  func (f *Factory) Transform(m map[string]interface{}) error {
   126  	var challenges, frontendLinks []string
   127  	if _, exists := m["mail"]; exists {
   128  		m["email"] = m["mail"].(string)
   129  		delete(m, "mail")
   130  	}
   131  	for _, transform := range f.transforms {
   132  		if matched := transform.matcher.Allow(context.Background(), m); !matched {
   133  			continue
   134  		}
   135  		for _, args := range transform.actions {
   136  			switch args[0] {
   137  			case "block", "deny":
   138  				return fmt.Errorf("transformer action is block/deny")
   139  			case "require":
   140  				challenges = append(challenges, cfgutil.EncodeArgs(args[1:]))
   141  			case "link":
   142  				frontendLinks = append(frontendLinks, cfgutil.EncodeArgs(args[1:]))
   143  			default:
   144  				if err := transformData(args, m, transform.matcher); err != nil {
   145  					return fmt.Errorf("transformer for %v erred: %v", args, err)
   146  				}
   147  			}
   148  		}
   149  	}
   150  	if len(challenges) > 0 {
   151  		m["challenges"] = challenges
   152  	}
   153  	if len(frontendLinks) > 0 {
   154  		m["frontend_links"] = frontendLinks
   155  	}
   156  
   157  	return nil
   158  }
   159  
   160  func transformData(args []string, m map[string]interface{}, matcher *acl.AccessList) error {
   161  	if len(args) < 3 {
   162  		return fmt.Errorf("too short")
   163  	}
   164  	switch args[0] {
   165  	case "add", "delete", "overwrite", "drop":
   166  	default:
   167  		return fmt.Errorf("unsupported action %v", args[0])
   168  	}
   169  
   170  	k, dt := acl.GetFieldDataType(args[1])
   171  	switch args[0] {
   172  	case "add":
   173  		switch dt {
   174  		case "list_str":
   175  			var entries, newEntries []string
   176  			switch val := m[k].(type) {
   177  			case string:
   178  				entries = strings.Split(val, " ")
   179  			case []string:
   180  				entries = val
   181  			case []interface{}:
   182  				for _, entry := range val {
   183  					switch e := entry.(type) {
   184  					case string:
   185  						entries = append(entries, e)
   186  					}
   187  				}
   188  			case nil:
   189  			default:
   190  				return fmt.Errorf("unsupported %q field type %T with value: %v in %v", k, val, val, args)
   191  			}
   192  			entries = append(entries, args[2:]...)
   193  			entryMap := make(map[string]bool)
   194  			for _, e := range entries {
   195  				e = strings.TrimSpace(e)
   196  				if e == "" {
   197  					continue
   198  				}
   199  				v, err := repl(m, e)
   200  				if err != nil {
   201  					return err
   202  				}
   203  				if _, exists := entryMap[v]; exists {
   204  					continue
   205  				}
   206  				entryMap[v] = true
   207  				newEntries = append(newEntries, v)
   208  			}
   209  			m[k] = newEntries
   210  		case "str":
   211  			var e string
   212  			switch val := m[k].(type) {
   213  			case string:
   214  				e = val + " " + strings.Join(args[2:], " ")
   215  			case nil:
   216  				e = strings.Join(args[2:], " ")
   217  			}
   218  
   219  			v, err := repl(m, e)
   220  			if err != nil {
   221  				return err
   222  			}
   223  			m[k] = v
   224  		default:
   225  			// Handle custom fields.
   226  			if args[1] == "nested" {
   227  				nestedKeys, nestedValues, err := parseCustomNestedFieldValues(args[2:])
   228  				if err != nil {
   229  					return fmt.Errorf("failed transforming %q field for %q action in %v: %v", k, args[0], args, err)
   230  				}
   231  
   232  				// Use pointers to create nested map.
   233  				var mp map[string]interface{}
   234  				mp = m
   235  				for i, v := range nestedKeys {
   236  					if i == len(nestedKeys)-1 {
   237  						// Handle last element.
   238  						mp[v] = nestedValues
   239  						continue
   240  					}
   241  					mv, exists := mp[v]
   242  					if !exists {
   243  						mp[v] = make(map[string]interface{})
   244  						mp = mp[v].(map[string]interface{})
   245  						continue
   246  					}
   247  					mp = mv.(map[string]interface{})
   248  				}
   249  				break
   250  			}
   251  			v, err := parseCustomFieldValues(m, args[2:])
   252  			if err != nil {
   253  				return fmt.Errorf("failed transforming %q field for %q action in %v: %v", k, args[0], args, err)
   254  			}
   255  			m[args[1]] = v
   256  		}
   257  	case "overwrite":
   258  		switch dt {
   259  		case "list_str":
   260  			m[k] = append([]string{}, args[2:]...)
   261  		case "str":
   262  			m[k] = strings.Join(args[2:], " ")
   263  		default:
   264  			return fmt.Errorf("unsupported %q field for %q action in %v", k, args[0], args)
   265  		}
   266  	case "drop":
   267  		if len(args) != 3 {
   268  			return fmt.Errorf("malformed %q action in %v", args[0], args)
   269  		}
   270  		if args[1] != "matched" || args[2] != "role" {
   271  			return fmt.Errorf("malformed %q action in %v", args[0], args)
   272  		}
   273  
   274  		if args[1] == "matched" && args[2] == "role" {
   275  			if _, exists := m["roles"]; exists {
   276  				var entries, newEntries []string
   277  				switch val := m["roles"].(type) {
   278  				case []string:
   279  					entries = val
   280  				case []interface{}:
   281  					for _, entry := range val {
   282  						switch e := entry.(type) {
   283  						case string:
   284  							entries = append(entries, e)
   285  						}
   286  						return fmt.Errorf("failed to %q action in %v due to unsupported data type inside the input data", args[0], args)
   287  					}
   288  				default:
   289  					return fmt.Errorf("failed to %q action in %v due to unsupported data type inside the input data", args[0], args)
   290  				}
   291  
   292  				for _, e := range entries {
   293  					em := map[string]interface{}{
   294  						"roles": []string{e},
   295  					}
   296  					if matched := matcher.Allow(context.Background(), em); matched {
   297  						continue
   298  					}
   299  					newEntries = append(newEntries, e)
   300  
   301  				}
   302  				m["roles"] = newEntries
   303  			}
   304  		}
   305  	default:
   306  		return fmt.Errorf("unsupported %q action in %v", args[0], args)
   307  	}
   308  	return nil
   309  }
   310  
   311  func parseCustomFieldValues(m map[string]interface{}, args []string) (interface{}, error) {
   312  	var x int
   313  	for i, arg := range args {
   314  		if arg == "as" {
   315  			x = i
   316  			break
   317  		}
   318  	}
   319  	if x == 0 {
   320  		return nil, fmt.Errorf("as type directive not found")
   321  	}
   322  	if len(args[x:]) < 2 {
   323  		return nil, fmt.Errorf("as type directive is too short")
   324  	}
   325  	dt := strings.Join(args[x+1:], "_")
   326  	switch dt {
   327  	case "string_list", "list":
   328  		values, err := replArr(m, args[:x])
   329  		if err != nil {
   330  			return nil, err
   331  		}
   332  		return values, nil
   333  	case "string":
   334  		value, err := repl(m, args[x-1])
   335  		if err != nil {
   336  			return nil, err
   337  		}
   338  		return value, nil
   339  	}
   340  	return nil, fmt.Errorf("unsupported %q data type", dt)
   341  }
   342  
   343  func parseCustomNestedFieldValues(args []string) ([]string, interface{}, error) {
   344  	var x, y int
   345  	for i, arg := range args {
   346  		if arg == "with" {
   347  			y = i
   348  		}
   349  		if arg == "as" {
   350  			x = i
   351  			break
   352  		}
   353  	}
   354  	if x == 0 {
   355  		return nil, nil, fmt.Errorf("as type directive not found")
   356  	}
   357  	if len(args[x:]) < 2 {
   358  		return nil, nil, fmt.Errorf("as type directive is too short")
   359  	}
   360  
   361  	dt := strings.Join(args[x+1:], "_")
   362  	args = args[:x]
   363  
   364  	if (dt != "map") && (y < 1) {
   365  		return nil, nil, fmt.Errorf("the with keyword not found")
   366  	}
   367  
   368  	switch dt {
   369  	case "string_list", "list":
   370  		return args[:y], args[y+1:], nil
   371  	case "string":
   372  		return args[:y], args[y+1], nil
   373  	case "map":
   374  		m := make(map[string]interface{})
   375  		return args, m, nil
   376  	}
   377  	return nil, nil, fmt.Errorf("unsupported %q data type", dt)
   378  }
   379  
   380  func hasReplPattern(s string) bool {
   381  	if strings.IndexRune(s, '{') < 0 {
   382  		return false
   383  	}
   384  	if strings.IndexRune(s, '}') < 0 {
   385  		return false
   386  	}
   387  	return true
   388  }
   389  
   390  func getReplPattern(s string) string {
   391  	i := strings.IndexRune(s, '{')
   392  	j := strings.IndexRune(s, '}')
   393  	return string(s[i : j+1])
   394  }
   395  
   396  func getReplKey(s string) string {
   397  	i := strings.IndexRune(s, '.')
   398  	return string(s[i+1 : len(s)-1])
   399  }
   400  
   401  func getReplValue(m map[string]interface{}, s string) (string, error) {
   402  	var value string
   403  	v, exists := m[s]
   404  	if !exists {
   405  		return value, fmt.Errorf("transform replace field %q not found", s)
   406  	}
   407  	switch val := v.(type) {
   408  	case string:
   409  		value = val
   410  	default:
   411  		return "", fmt.Errorf("transform replace field %q value type %T is unsupported", s, val)
   412  	}
   413  	return value, nil
   414  }
   415  
   416  func repl(m map[string]interface{}, s string) (string, error) {
   417  	for {
   418  		if !hasReplPattern(s) {
   419  			break
   420  		}
   421  		ptrn := getReplPattern(s)
   422  		if !strings.HasPrefix(ptrn, "{claims.") {
   423  			return "", fmt.Errorf("transform replace pattern %q is unsupported", ptrn)
   424  		}
   425  		v, err := getReplValue(m, getReplKey(ptrn))
   426  		if err != nil {
   427  			return "", err
   428  		}
   429  		s = strings.ReplaceAll(s, ptrn, v)
   430  	}
   431  	return s, nil
   432  }
   433  
   434  func replArr(m map[string]interface{}, arr []string) ([]string, error) {
   435  	var values []string
   436  	for _, s := range arr {
   437  		value, err := repl(m, s)
   438  		if err != nil {
   439  			return values, err
   440  		}
   441  		values = append(values, value)
   442  	}
   443  	return values, nil
   444  }