yunion.io/x/cloudmux@v0.3.10-0-alpha.1/pkg/multicloud/aws/waf.go (about)

     1  // Copyright 2019 Yunion
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aws
    16  
    17  import (
    18  	"strconv"
    19  	"strings"
    20  
    21  	"github.com/aws/aws-sdk-go/service/wafv2"
    22  
    23  	"yunion.io/x/jsonutils"
    24  	"yunion.io/x/pkg/errors"
    25  
    26  	api "yunion.io/x/cloudmux/pkg/apis/compute"
    27  	"yunion.io/x/cloudmux/pkg/cloudprovider"
    28  	"yunion.io/x/cloudmux/pkg/multicloud"
    29  )
    30  
    31  const (
    32  	SCOPE_REGIONAL   = "REGIONAL"
    33  	SCOPE_CLOUDFRONT = "CLOUDFRONT"
    34  )
    35  
    36  var (
    37  	WAF_SCOPES = []string{
    38  		SCOPE_REGIONAL,
    39  		SCOPE_CLOUDFRONT,
    40  	}
    41  )
    42  
    43  type SWafRule struct {
    44  	Action struct {
    45  		Block struct {
    46  		} `json:"Block"`
    47  	} `json:"Action"`
    48  	Name string `json:"Name"`
    49  }
    50  
// SVisibilityConfig is a plain struct whose field names parallel the SDK's
// wafv2.VisibilityConfig (the same three settings CreateWebAcl configures via
// SetCloudWatchMetricsEnabled / SetMetricName / SetSampledRequestsEnabled).
// NOTE(review): not referenced by the code in this file — confirm its users
// elsewhere in the package before changing it.
type SVisibilityConfig struct {
	CloudWatchMetricsEnabled bool
	MetricName               string
	SampledRequestsEnabled   bool
}
    56  
// SWebAcl wraps a WAFv2 web ACL of one region. The SDK's *wafv2.WebACL is
// embedded, so its fields (Id, Name, ARN, Description, DefaultAction,
// VisibilityConfig, Rules, ...) are promoted onto SWebAcl.
//
// Instances built by ListWebACLs are decoded from the list response entries
// (summaries, per the WAFv2 API), so fields such as DefaultAction may be nil
// until Refresh loads the full ACL; GetDefaultAction refreshes in that case.
type SWebAcl struct {
	multicloud.SResourceBase
	AwsTags
	region *SRegion
	*wafv2.WebACL

	// scope is SCOPE_REGIONAL or SCOPE_CLOUDFRONT. It is not part of the SDK
	// object, so constructors (GetWebAcl, GetICloudWafInstances) set it.
	scope     string
	// LockToken is the WAFv2 optimistic-lock token passed to DeleteWebACL
	// and UpdateWebACL.
	LockToken string
}
    66  
    67  func (self *SRegion) ListWebACLs(scope string) ([]SWebAcl, error) {
    68  	if scope == SCOPE_CLOUDFRONT && self.RegionId != "us-east-1" {
    69  		return []SWebAcl{}, nil
    70  	}
    71  	client, err := self.getWafClient()
    72  	if err != nil {
    73  		return nil, errors.Wrapf(err, "getWafClient")
    74  	}
    75  	ret := []SWebAcl{}
    76  	input := wafv2.ListWebACLsInput{}
    77  	input.SetScope(scope)
    78  	for {
    79  		resp, err := client.ListWebACLs(&input)
    80  		if err != nil {
    81  			return nil, errors.Wrapf(err, "ListWebACLs")
    82  		}
    83  		part := []SWebAcl{}
    84  		jsonutils.Update(&part, resp.WebACLs)
    85  		ret = append(ret, part...)
    86  		if resp.NextMarker == nil || len(*resp.NextMarker) == 0 {
    87  			break
    88  		}
    89  		input.SetNextMarker(*resp.NextMarker)
    90  	}
    91  	return ret, nil
    92  }
    93  
    94  func (self *SRegion) GetWebAcl(id, name, scope string) (*SWebAcl, error) {
    95  	client, err := self.getWafClient()
    96  	if err != nil {
    97  		return nil, errors.Wrapf(err, "getWafClient")
    98  	}
    99  	input := wafv2.GetWebACLInput{}
   100  	input.SetId(id)
   101  	input.SetName(name)
   102  	input.SetScope(scope)
   103  	resp, err := client.GetWebACL(&input)
   104  	if err != nil {
   105  		if _, ok := err.(*wafv2.WAFNonexistentItemException); ok {
   106  			return nil, errors.Wrapf(cloudprovider.ErrNotFound, err.Error())
   107  		}
   108  		return nil, errors.Wrapf(err, "GetWebAcl")
   109  	}
   110  	ret := &SWebAcl{region: self, scope: scope, WebACL: resp.WebACL, LockToken: *resp.LockToken}
   111  	return ret, nil
   112  }
   113  
   114  func (self *SRegion) DeleteWebAcl(id, name, scope, lockToken string) error {
   115  	client, err := self.getWafClient()
   116  	if err != nil {
   117  		return errors.Wrapf(err, "getWafClient")
   118  	}
   119  	input := wafv2.DeleteWebACLInput{}
   120  	input.SetId(id)
   121  	input.SetName(name)
   122  	input.SetScope(scope)
   123  	input.SetLockToken(lockToken)
   124  	_, err = client.DeleteWebACL(&input)
   125  	return errors.Wrapf(err, "DeleteWebACL")
   126  }
   127  
   128  func (self *SRegion) ListResourcesForWebACL(resType, arn string) ([]string, error) {
   129  	client, err := self.getWafClient()
   130  	if err != nil {
   131  		return nil, errors.Wrapf(err, "getWafClient")
   132  	}
   133  	input := wafv2.ListResourcesForWebACLInput{}
   134  	input.SetResourceType(resType)
   135  	input.SetWebACLArn(arn)
   136  	resp, err := client.ListResourcesForWebACL(&input)
   137  	if err != nil {
   138  		return nil, errors.Wrapf(err, "ListResourcesForWebACL")
   139  	}
   140  	ret := []string{}
   141  	for _, id := range resp.ResourceArns {
   142  		ret = append(ret, *id)
   143  	}
   144  	return ret, nil
   145  }
   146  
   147  func (self *SRegion) GetICloudWafInstanceById(id string) (cloudprovider.ICloudWafInstance, error) {
   148  	idInfo := strings.Split(id, "/")
   149  	if len(idInfo) != 4 {
   150  		return nil, errors.Wrapf(cloudprovider.ErrNotFound, "invalid arn %s", id)
   151  	}
   152  	scope := SCOPE_CLOUDFRONT
   153  	if strings.HasSuffix(idInfo[0], "regional") {
   154  		scope = SCOPE_REGIONAL
   155  	}
   156  	ins, err := self.GetWebAcl(idInfo[3], idInfo[2], scope)
   157  	if err != nil {
   158  		return nil, errors.Wrapf(err, "GetWebAcl(%s, %s, %s)", idInfo[3], idInfo[2], scope)
   159  	}
   160  	return ins, nil
   161  }
   162  
   163  func (self *SRegion) GetICloudWafInstances() ([]cloudprovider.ICloudWafInstance, error) {
   164  	ret := []cloudprovider.ICloudWafInstance{}
   165  	for _, scope := range WAF_SCOPES {
   166  		ins, err := self.ListWebACLs(scope)
   167  		if err != nil {
   168  			return nil, errors.Wrapf(err, "ListWebACLs")
   169  		}
   170  		for i := range ins {
   171  			ins[i].region = self
   172  			ins[i].scope = scope
   173  			ret = append(ret, &ins[i])
   174  		}
   175  	}
   176  	return ret, nil
   177  }
   178  
   179  func (self *SWebAcl) GetEnabled() bool {
   180  	return true
   181  }
   182  
   183  func (self *SWebAcl) GetGlobalId() string {
   184  	return *self.ARN
   185  }
   186  
   187  func (self *SWebAcl) GetName() string {
   188  	return *self.Name
   189  }
   190  
   191  func (self *SWebAcl) GetId() string {
   192  	return *self.ARN
   193  }
   194  
   195  func (self *SWebAcl) GetWafType() cloudprovider.TWafType {
   196  	if self.scope == SCOPE_CLOUDFRONT {
   197  		return cloudprovider.WafTypeCloudFront
   198  	}
   199  	return cloudprovider.WafTypeRegional
   200  }
   201  
   202  func (self *SWebAcl) GetStatus() string {
   203  	return api.WAF_STATUS_AVAILABLE
   204  }
   205  
   206  func (self *SWebAcl) GetDefaultAction() *cloudprovider.DefaultAction {
   207  	ret := &cloudprovider.DefaultAction{}
   208  	if self.WebACL.DefaultAction == nil {
   209  		self.Refresh()
   210  	}
   211  	if self.WebACL.DefaultAction != nil {
   212  		action := self.WebACL.DefaultAction
   213  		if action.Allow != nil {
   214  			ret.Action = cloudprovider.WafActionAllow
   215  		} else if action.Block != nil {
   216  			ret.Action = cloudprovider.WafActionBlock
   217  		}
   218  	}
   219  	return ret
   220  }
   221  
   222  func (self *SWebAcl) Refresh() error {
   223  	acl, err := self.region.GetWebAcl(*self.Id, *self.Name, self.scope)
   224  	if err != nil {
   225  		return errors.Wrapf(err, "GetWebAcl")
   226  	}
   227  	self.WebACL = acl.WebACL
   228  	return jsonutils.Update(self, acl)
   229  }
   230  
   231  func (self *SWebAcl) Delete() error {
   232  	return self.region.DeleteWebAcl(*self.Id, *self.Name, self.scope, self.LockToken)
   233  }
   234  
   235  func (self *SRegion) CreateICloudWafInstance(opts *cloudprovider.WafCreateOptions) (cloudprovider.ICloudWafInstance, error) {
   236  	waf, err := self.CreateWebAcl(opts.Name, opts.Desc, opts.Type, opts.DefaultAction)
   237  	if err != nil {
   238  		return nil, errors.Wrapf(err, "CreateWebAcl")
   239  	}
   240  	return waf, nil
   241  }
   242  
   243  func (self *SRegion) CreateWebAcl(name, desc string, wafType cloudprovider.TWafType, action *cloudprovider.DefaultAction) (*SWebAcl, error) {
   244  	input := wafv2.CreateWebACLInput{}
   245  	input.SetName(name)
   246  	if len(desc) > 0 {
   247  		input.SetDescription(desc)
   248  	}
   249  	switch wafType {
   250  	case cloudprovider.WafTypeRegional, cloudprovider.WafTypeCloudFront:
   251  		input.SetScope(strings.ToUpper(string(wafType)))
   252  	default:
   253  		return nil, errors.Errorf("invalid waf type %s", wafType)
   254  	}
   255  	if action != nil {
   256  		defaultAction := wafv2.DefaultAction{}
   257  		switch action.Action {
   258  		case cloudprovider.WafActionAllow:
   259  			defaultAction.Allow = &wafv2.AllowAction{}
   260  		case cloudprovider.WafActionBlock:
   261  			defaultAction.Block = &wafv2.BlockAction{}
   262  		}
   263  		input.SetDefaultAction(&defaultAction)
   264  	}
   265  	visib := &wafv2.VisibilityConfig{}
   266  	visib.SetSampledRequestsEnabled(true)
   267  	visib.SetCloudWatchMetricsEnabled(true)
   268  	visib.SetMetricName(name)
   269  	input.SetVisibilityConfig(visib)
   270  	client, err := self.getWafClient()
   271  	if err != nil {
   272  		return nil, errors.Wrapf(err, "getWafClient")
   273  	}
   274  	output, err := client.CreateWebACL(&input)
   275  	if err != nil {
   276  		return nil, errors.Wrapf(err, "CreateWebAcl")
   277  	}
   278  	return self.GetWebAcl(*output.Summary.Id, name, *input.Scope)
   279  }
   280  
   281  func reverseConvertField(opts cloudprovider.SWafStatement) *wafv2.FieldToMatch {
   282  	ret := &wafv2.FieldToMatch{}
   283  	switch opts.MatchField {
   284  	case cloudprovider.WafMatchFieldBody:
   285  		body := &wafv2.Body{}
   286  		ret.SetBody(body)
   287  	case cloudprovider.WafMatchFieldJsonBody:
   288  	case cloudprovider.WafMatchFieldMethod:
   289  		method := &wafv2.Method{}
   290  		ret.SetMethod(method)
   291  	case cloudprovider.WafMatchFieldQuery:
   292  		switch opts.MatchFieldKey {
   293  		case "SingleArgument":
   294  			query := &wafv2.SingleQueryArgument{}
   295  			ret.SetSingleQueryArgument(query)
   296  		case "AllArguments":
   297  			query := &wafv2.AllQueryArguments{}
   298  			ret.SetAllQueryArguments(query)
   299  		default:
   300  			query := &wafv2.QueryString{}
   301  			ret.SetQueryString(query)
   302  		}
   303  	case cloudprovider.WafMatchFiledHeader:
   304  		head := &wafv2.SingleHeader{}
   305  		head.SetName(opts.MatchFieldKey)
   306  		ret.SetSingleHeader(head)
   307  	case cloudprovider.WafMatchFiledUriPath:
   308  		uri := &wafv2.UriPath{}
   309  		ret.SetUriPath(uri)
   310  	}
   311  	return ret
   312  }
   313  
   314  func reverseConvertStatement(statement cloudprovider.SWafStatement) *wafv2.Statement {
   315  	ret := &wafv2.Statement{}
   316  	trans := []*wafv2.TextTransformation{}
   317  	if statement.Transformations != nil {
   318  		for i, tran := range *statement.Transformations {
   319  			t := &wafv2.TextTransformation{}
   320  			switch tran {
   321  			case cloudprovider.WafTextTransformationNone:
   322  				t.SetType(wafv2.TextTransformationTypeNone)
   323  			case cloudprovider.WafTextTransformationLowercase:
   324  				t.SetType(wafv2.TextTransformationTypeLowercase)
   325  			case cloudprovider.WafTextTransformationCmdLine:
   326  				t.SetType(wafv2.TextTransformationTypeCmdLine)
   327  			case cloudprovider.WafTextTransformationUrlDecode:
   328  				t.SetType(wafv2.TextTransformationTypeUrlDecode)
   329  			case cloudprovider.WafTextTransformationHtmlEntityDecode:
   330  				t.SetType(wafv2.TextTransformationTypeHtmlEntityDecode)
   331  			case cloudprovider.WafTextTransformationCompressWithSpace:
   332  				t.SetType(wafv2.TextTransformationTypeCompressWhiteSpace)
   333  			}
   334  			t.SetPriority(int64(i))
   335  			trans = append(trans, t)
   336  		}
   337  	}
   338  	rules := []*wafv2.ExcludedRule{}
   339  	if statement.ExcludeRules != nil {
   340  		for _, r := range *statement.ExcludeRules {
   341  			name := r.Name
   342  			rules = append(rules, &wafv2.ExcludedRule{
   343  				Name: &name,
   344  			})
   345  		}
   346  	}
   347  	field := reverseConvertField(statement)
   348  	switch statement.Type {
   349  	case cloudprovider.WafStatementTypeRate:
   350  		rate := &wafv2.RateBasedStatement{}
   351  		limit := int(0)
   352  		if statement.MatchFieldValues != nil && len(*statement.MatchFieldValues) == 1 {
   353  			limit, _ = strconv.Atoi((*statement.MatchFieldValues)[0])
   354  		}
   355  		rate.SetLimit(int64(limit))
   356  		fd := &wafv2.ForwardedIPConfig{}
   357  		if len(statement.ForwardedIPHeader) > 0 {
   358  			fd.SetHeaderName(statement.ForwardedIPHeader)
   359  			rate.SetForwardedIPConfig(fd)
   360  		}
   361  		ret.SetRateBasedStatement(rate)
   362  	case cloudprovider.WafStatementTypeIPSet:
   363  		ipset := &wafv2.IPSetReferenceStatement{}
   364  		ipset.SetARN(statement.IPSetId)
   365  		fd := &wafv2.IPSetForwardedIPConfig{}
   366  		if len(statement.ForwardedIPHeader) > 0 {
   367  			fd.SetHeaderName(statement.ForwardedIPHeader)
   368  			ipset.SetIPSetForwardedIPConfig(fd)
   369  		}
   370  		ret.SetIPSetReferenceStatement(ipset)
   371  	case cloudprovider.WafStatementTypeXssMatch:
   372  		xss := &wafv2.XssMatchStatement{}
   373  		if len(trans) > 0 {
   374  			xss.SetTextTransformations(trans)
   375  		}
   376  		field := &wafv2.FieldToMatch{}
   377  		xss.SetFieldToMatch(field)
   378  		xss.SetTextTransformations(trans)
   379  		ret.SetXssMatchStatement(xss)
   380  	case cloudprovider.WafStatementTypeSize:
   381  		size := &wafv2.SizeConstraintStatement{}
   382  		size.SetFieldToMatch(field)
   383  		value := int(0)
   384  		if statement.MatchFieldValues != nil && len(*statement.MatchFieldValues) == 1 {
   385  			value, _ = strconv.Atoi((*statement.MatchFieldValues)[0])
   386  		}
   387  		size.SetSize(int64(value))
   388  		ret.SetSizeConstraintStatement(size)
   389  	case cloudprovider.WafStatementTypeGeoMatch:
   390  		geo := &wafv2.GeoMatchStatement{}
   391  		values := []*string{}
   392  		if statement.MatchFieldValues != nil {
   393  			for i := range *statement.MatchFieldValues {
   394  				v := (*statement.MatchFieldValues)[i]
   395  				values = append(values, &v)
   396  			}
   397  			geo.SetCountryCodes(values)
   398  		}
   399  		fd := &wafv2.ForwardedIPConfig{}
   400  		if len(statement.ForwardedIPHeader) > 0 {
   401  			fd.SetHeaderName(statement.ForwardedIPHeader)
   402  			geo.SetForwardedIPConfig(fd)
   403  		}
   404  		ret.SetGeoMatchStatement(geo)
   405  	case cloudprovider.WafStatementTypeRegexSet:
   406  		regex := &wafv2.RegexPatternSetReferenceStatement{}
   407  		regex.SetARN(statement.RegexSetId)
   408  		if len(trans) > 0 {
   409  			regex.SetTextTransformations(trans)
   410  		}
   411  		regex.SetFieldToMatch(field)
   412  		ret.SetRegexPatternSetReferenceStatement(regex)
   413  	case cloudprovider.WafStatementTypeByteMatch:
   414  		bm := &wafv2.ByteMatchStatement{}
   415  		if len(trans) > 0 {
   416  			bm.SetTextTransformations(trans)
   417  		}
   418  		bm.SetSearchString([]byte(statement.SearchString))
   419  		if len(statement.Operator) > 0 {
   420  			bm.SetPositionalConstraint(string(statement.Operator))
   421  		}
   422  		bm.SetFieldToMatch(field)
   423  		ret.SetByteMatchStatement(bm)
   424  	case cloudprovider.WafStatementTypeRuleGroup:
   425  		rg := &wafv2.RuleGroupReferenceStatement{}
   426  		rg.SetARN(statement.RuleGroupId)
   427  		if len(rules) > 0 {
   428  			rg.SetExcludedRules(rules)
   429  		}
   430  		ret.SetRuleGroupReferenceStatement(rg)
   431  	case cloudprovider.WafStatementTypeSqliMatch:
   432  		sqli := &wafv2.SqliMatchStatement{}
   433  		if len(trans) > 0 {
   434  			sqli.SetTextTransformations(trans)
   435  		}
   436  		sqli.SetFieldToMatch(field)
   437  		ret.SetSqliMatchStatement(sqli)
   438  	case cloudprovider.WafStatementTypeLabelMatch:
   439  	case cloudprovider.WafStatementTypeManagedRuleGroup:
   440  		rg := &wafv2.ManagedRuleGroupStatement{}
   441  		rg.SetName(statement.ManagedRuleGroupName)
   442  		rg.SetVendorName("aws")
   443  		if len(rules) > 0 {
   444  			rg.SetExcludedRules(rules)
   445  		}
   446  		ret.SetManagedRuleGroupStatement(rg)
   447  	}
   448  	return ret
   449  }
   450  
   451  func (self *SWebAcl) AddRule(opts *cloudprovider.SWafRule) (cloudprovider.ICloudWafRule, error) {
   452  	input := &wafv2.UpdateWebACLInput{}
   453  	input.SetLockToken(self.LockToken)
   454  	input.SetId(*self.Id)
   455  	input.SetName(*self.Name)
   456  	input.SetScope(self.scope)
   457  	input.SetDescription(*self.Description)
   458  	input.SetDefaultAction(self.DefaultAction)
   459  	input.SetVisibilityConfig(self.WebACL.VisibilityConfig)
   460  	rules := self.Rules
   461  	rule := &wafv2.Rule{}
   462  	rule.SetName(opts.Name)
   463  	rule.SetPriority(int64(opts.Priority))
   464  	action := &wafv2.RuleAction{}
   465  	if opts.Action != nil {
   466  		switch opts.Action.Action {
   467  		case cloudprovider.WafActionAllow:
   468  			allow := &wafv2.AllowAction{}
   469  			action.SetAllow(allow)
   470  		case cloudprovider.WafActionBlock:
   471  			block := &wafv2.BlockAction{}
   472  			action.SetBlock(block)
   473  		case cloudprovider.WafActionCount:
   474  			count := &wafv2.CountAction{}
   475  			action.SetCount(count)
   476  		}
   477  	}
   478  	rule.SetAction(action)
   479  	visib := &wafv2.VisibilityConfig{}
   480  	visib.SetSampledRequestsEnabled(false)
   481  	visib.SetCloudWatchMetricsEnabled(true)
   482  	visib.SetMetricName(opts.Name)
   483  	rule.SetVisibilityConfig(visib)
   484  	statement := &wafv2.Statement{}
   485  	switch opts.StatementCondition {
   486  	case cloudprovider.WafStatementConditionOr:
   487  		ss := &wafv2.OrStatement{}
   488  		for _, s := range opts.Statements {
   489  			ss.Statements = append(ss.Statements, reverseConvertStatement(s))
   490  		}
   491  		statement.SetOrStatement(ss)
   492  	case cloudprovider.WafStatementConditionAnd:
   493  		ss := &wafv2.AndStatement{}
   494  		for _, s := range opts.Statements {
   495  			ss.Statements = append(ss.Statements, reverseConvertStatement(s))
   496  		}
   497  		statement.SetAndStatement(ss)
   498  	case cloudprovider.WafStatementConditionNot:
   499  		ss := &wafv2.NotStatement{}
   500  		for _, s := range opts.Statements {
   501  			ss.SetStatement(reverseConvertStatement(s))
   502  			break
   503  		}
   504  		statement.SetNotStatement(ss)
   505  	case cloudprovider.WafStatementConditionNone:
   506  		for _, s := range opts.Statements {
   507  			statement = reverseConvertStatement(s)
   508  			break
   509  		}
   510  	}
   511  	rule.SetStatement(statement)
   512  	rules = append(rules, rule)
   513  	input.SetRules(rules)
   514  	client, err := self.region.getWafClient()
   515  	if err != nil {
   516  		return nil, errors.Wrapf(err, "getWafClient")
   517  	}
   518  	_, err = client.UpdateWebACL(input)
   519  	if err != nil {
   520  		return nil, errors.Wrapf(err, "UpdateWebACL")
   521  	}
   522  	ret := &sWafRule{waf: self, Rule: rule}
   523  	return ret, nil
   524  }
   525  
   526  func (self *SWebAcl) GetCloudResources() ([]cloudprovider.SCloudResource, error) {
   527  	ret := []cloudprovider.SCloudResource{}
   528  	if self.scope != SCOPE_REGIONAL {
   529  		return ret, nil
   530  	}
   531  	for _, resType := range []string{"APPLICATION_LOAD_BALANCER", "API_GATEWAY", "APPSYNC"} {
   532  		resIds, err := self.region.ListResourcesForWebACL(resType, *self.ARN)
   533  		if err != nil {
   534  			return nil, errors.Wrapf(err, "ListResourcesForWebACL(%s, %s)", resType, *self.ARN)
   535  		}
   536  		for _, resId := range resIds {
   537  			ret = append(ret, cloudprovider.SCloudResource{
   538  				Id:   resId,
   539  				Name: resId,
   540  				Type: resType,
   541  			})
   542  		}
   543  	}
   544  	return ret, nil
   545  }