github.com/argoproj/argo-events@v1.9.1/eventsources/sources/awssns/start.go (about)

     1  /*
     2  Copyright 2018 BlackRock, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  	http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package awssns
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/x509"
    23  	"encoding/base64"
    24  	"encoding/json"
    25  	"encoding/pem"
    26  	"fmt"
    27  	"io"
    28  	"net/http"
    29  	"net/url"
    30  	"reflect"
    31  	"regexp"
    32  	"time"
    33  
    34  	"github.com/aws/aws-sdk-go/aws"
    35  	snslib "github.com/aws/aws-sdk-go/service/sns"
    36  	"github.com/ghodss/yaml"
    37  	"go.uber.org/zap"
    38  
    39  	"github.com/argoproj/argo-events/common"
    40  	"github.com/argoproj/argo-events/common/logging"
    41  	eventsourcecommon "github.com/argoproj/argo-events/eventsources/common"
    42  	commonaws "github.com/argoproj/argo-events/eventsources/common/aws"
    43  	"github.com/argoproj/argo-events/eventsources/common/webhook"
    44  	"github.com/argoproj/argo-events/eventsources/sources"
    45  	metrics "github.com/argoproj/argo-events/metrics"
    46  	apicommon "github.com/argoproj/argo-events/pkg/apis/common"
    47  	"github.com/argoproj/argo-events/pkg/apis/events"
    48  	"github.com/argoproj/argo-events/pkg/apis/eventsource/v1alpha1"
    49  )
    50  
    51  var (
    52  	// controller controls the webhook operations
    53  	controller = webhook.NewController()
    54  
    55  	// used for SNS verification
    56  	snsSigKeys      = map[string][]string{}
    57  	snsKeyRealNames = map[string]string{
    58  		"MessageID": "MessageId",
    59  		"TopicARN":  "TopicArn",
    60  	}
    61  )
    62  
    63  // set up route activation and deactivation channels
    64  func init() {
    65  	go webhook.ProcessRouteStatus(controller)
    66  
    67  	snsSigKeys[messageTypeNotification] = []string{
    68  		"Message",
    69  		"MessageID",
    70  		"Subject",
    71  		"Timestamp",
    72  		"TopicARN",
    73  		"Type",
    74  	}
    75  	snsSigKeys[messageTypeSubscriptionConfirmation] = []string{
    76  		"Message",
    77  		"MessageID",
    78  		"SubscribeURL",
    79  		"Timestamp",
    80  		"Token",
    81  		"TopicARN",
    82  		"Type",
    83  	}
    84  }
    85  
    86  // Implement Router
    87  // 1. GetRoute
    88  // 2. HandleRoute
    89  // 3. PostActivate
    90  // 4. PostInactivate
    91  
    92  // GetRoute returns the route
    93  func (router *Router) GetRoute() *webhook.Route {
    94  	return router.Route
    95  }
    96  
    97  // HandleRoute handles new routes
    98  func (router *Router) HandleRoute(writer http.ResponseWriter, request *http.Request) {
    99  	route := router.Route
   100  
   101  	logger := route.Logger.With(
   102  		logging.LabelEndpoint, route.Context.Endpoint,
   103  		logging.LabelPort, route.Context.Port,
   104  		logging.LabelHTTPMethod, route.Context.Method,
   105  	)
   106  
   107  	logger.Info("request received from event source")
   108  
   109  	if !route.Active {
   110  		logger.Info("endpoint is not active, won't process the request")
   111  		common.SendErrorResponse(writer, "inactive endpoint")
   112  		return
   113  	}
   114  
   115  	defer func(start time.Time) {
   116  		route.Metrics.EventProcessingDuration(route.EventSourceName, route.EventName, float64(time.Since(start)/time.Millisecond))
   117  	}(time.Now())
   118  
   119  	request.Body = http.MaxBytesReader(writer, request.Body, route.Context.GetMaxPayloadSize())
   120  	body, err := io.ReadAll(request.Body)
   121  	if err != nil {
   122  		logger.Errorw("failed to parse the request body", zap.Error(err))
   123  		common.SendErrorResponse(writer, err.Error())
   124  		route.Metrics.EventProcessingFailed(route.EventSourceName, route.EventName)
   125  		return
   126  	}
   127  
   128  	var notification *httpNotification
   129  	err = yaml.Unmarshal(body, &notification)
   130  	if err != nil {
   131  		logger.Errorw("failed to convert request payload into sns notification", zap.Error(err))
   132  		common.SendErrorResponse(writer, err.Error())
   133  		route.Metrics.EventProcessingFailed(route.EventSourceName, route.EventName)
   134  		return
   135  	}
   136  
   137  	if notification == nil {
   138  		common.SendErrorResponse(writer, "bad request, not a valid SNS notification")
   139  		return
   140  	}
   141  
   142  	// SNS Signature Verification
   143  	if router.eventSource.ValidateSignature {
   144  		err = notification.verify()
   145  		if err != nil {
   146  			logger.Errorw("failed to verify sns message", zap.Error(err))
   147  			common.SendErrorResponse(writer, err.Error())
   148  			route.Metrics.EventProcessingFailed(route.EventSourceName, route.EventName)
   149  			return
   150  		}
   151  	}
   152  
   153  	switch notification.Type {
   154  	case messageTypeSubscriptionConfirmation:
   155  		awsSession := router.session
   156  
   157  		response, err := awsSession.ConfirmSubscription(&snslib.ConfirmSubscriptionInput{
   158  			TopicArn: &router.eventSource.TopicArn,
   159  			Token:    &notification.Token,
   160  		})
   161  		if err != nil {
   162  			logger.Errorw("failed to send confirmation response to aws sns", zap.Error(err))
   163  			common.SendErrorResponse(writer, err.Error())
   164  			route.Metrics.EventProcessingFailed(route.EventSourceName, route.EventName)
   165  			return
   166  		}
   167  
   168  		logger.Info("subscription successfully confirmed to aws sns")
   169  		router.subscriptionArn = response.SubscriptionArn
   170  
   171  	case messageTypeNotification:
   172  		logger.Info("dispatching notification on route's data channel")
   173  
   174  		eventData := &events.SNSEventData{
   175  			Header:   request.Header,
   176  			Body:     (*json.RawMessage)(&body),
   177  			Metadata: router.eventSource.Metadata,
   178  		}
   179  
   180  		eventBytes, err := json.Marshal(eventData)
   181  		if err != nil {
   182  			logger.Errorw("failed to marshal the event data", zap.Error(err))
   183  			common.SendErrorResponse(writer, err.Error())
   184  			route.Metrics.EventProcessingFailed(route.EventSourceName, route.EventName)
   185  			return
   186  		}
   187  		route.DataCh <- eventBytes
   188  	}
   189  
   190  	logger.Info("request has been successfully processed")
   191  }
   192  
   193  // PostActivate refers to operations performed after a route is successfully activated
   194  func (router *Router) PostActivate() error {
   195  	route := router.Route
   196  
   197  	logger := route.Logger.With(
   198  		logging.LabelEndpoint, route.Context.Endpoint,
   199  		logging.LabelPort, route.Context.Port,
   200  		logging.LabelHTTPMethod, route.Context.Method,
   201  		"topic-arn", router.eventSource.TopicArn,
   202  	)
   203  
   204  	// In order to successfully subscribe to sns topic,
   205  	// 1. Fetch credentials if configured explicitly. Users can use something like https://github.com/jtblin/kube2iam
   206  	//    which will help not configure creds explicitly.
   207  	// 2. Get AWS session
   208  	// 3. Subscribe to a topic
   209  
   210  	logger.Info("subscribing to sns topic...")
   211  
   212  	snsEventSource := router.eventSource
   213  
   214  	awsSession, err := commonaws.CreateAWSSessionWithCredsInVolume(snsEventSource.Region, snsEventSource.RoleARN, snsEventSource.AccessKey, snsEventSource.SecretKey, nil)
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	if snsEventSource.Endpoint == "" {
   220  		router.session = snslib.New(awsSession)
   221  	} else {
   222  		router.session = snslib.New(awsSession, &aws.Config{Endpoint: &snsEventSource.Endpoint, Region: &snsEventSource.Region})
   223  	}
   224  
   225  	formattedURL := common.FormattedURL(snsEventSource.Webhook.URL, snsEventSource.Webhook.Endpoint)
   226  	if _, err := router.session.Subscribe(&snslib.SubscribeInput{
   227  		Endpoint: &formattedURL,
   228  		Protocol: func(endpoint string) *string {
   229  			Protocol := "http"
   230  			if matched, _ := regexp.MatchString(`https://.*`, endpoint); matched {
   231  				Protocol = "https"
   232  				return &Protocol
   233  			}
   234  			return &Protocol
   235  		}(formattedURL),
   236  		TopicArn: &snsEventSource.TopicArn,
   237  	}); err != nil {
   238  		return err
   239  	}
   240  
   241  	return nil
   242  }
   243  
   244  // PostInactivate refers to operations performed after a route is successfully inactivated
   245  func (router *Router) PostInactivate() error {
   246  	// After event source is removed, the subscription is cancelled.
   247  	if _, err := router.session.Unsubscribe(&snslib.UnsubscribeInput{
   248  		SubscriptionArn: router.subscriptionArn,
   249  	}); err != nil {
   250  		return err
   251  	}
   252  	return nil
   253  }
   254  
   255  // EventListener implements Eventing for aws sns event source
   256  type EventListener struct {
   257  	EventSourceName string
   258  	EventName       string
   259  	SNSEventSource  v1alpha1.SNSEventSource
   260  	Metrics         *metrics.Metrics
   261  }
   262  
   263  // GetEventSourceName returns name of event source
   264  func (el *EventListener) GetEventSourceName() string {
   265  	return el.EventSourceName
   266  }
   267  
   268  // GetEventName returns name of event
   269  func (el *EventListener) GetEventName() string {
   270  	return el.EventName
   271  }
   272  
   273  // GetEventSourceType return type of event server
   274  func (el *EventListener) GetEventSourceType() apicommon.EventSourceType {
   275  	return apicommon.SNSEvent
   276  }
   277  
   278  // StartListening starts an SNS event source
   279  func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byte, ...eventsourcecommon.Option) error) error {
   280  	logger := logging.FromContext(ctx).
   281  		With(logging.LabelEventSourceType, el.GetEventSourceType(), logging.LabelEventName, el.GetEventName())
   282  
   283  	defer sources.Recover(el.GetEventName())
   284  
   285  	logger.Info("started processing the AWS SNS event source...")
   286  
   287  	route := webhook.NewRoute(el.SNSEventSource.Webhook, logger, el.GetEventSourceName(), el.GetEventName(), el.Metrics)
   288  
   289  	logger.Info("operating on the route...")
   290  	return webhook.ManageRoute(ctx, &Router{
   291  		Route:       route,
   292  		eventSource: &el.SNSEventSource,
   293  	}, controller, dispatch)
   294  }
   295  
   296  func (m *httpNotification) verifySigningCertUrl() error {
   297  	regexSigningCertHost := `^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$`
   298  	regex := regexp.MustCompile(regexSigningCertHost)
   299  	url, err := url.Parse(m.SigningCertURL)
   300  	if err != nil {
   301  		return fmt.Errorf("SigningCertURL is not a valid URL, %w", err)
   302  	}
   303  	if !regex.MatchString(url.Hostname()) {
   304  		return fmt.Errorf("SigningCertURL hostname `%s` does not match `%s`", url.Hostname(), regexSigningCertHost)
   305  	}
   306  	if url.Scheme != "https" {
   307  		return fmt.Errorf("SigningCertURL is not using https")
   308  	}
   309  	return nil
   310  }
   311  
   312  func (m *httpNotification) verify() error {
   313  	msgSig, err := base64.StdEncoding.DecodeString(m.Signature)
   314  	if err != nil {
   315  		return fmt.Errorf("failed to base64 decode signature, %w", err)
   316  	}
   317  
   318  	if err := m.verifySigningCertUrl(); err != nil {
   319  		return fmt.Errorf("failed to verify SigningCertURL, %w", err)
   320  	}
   321  
   322  	res, err := http.Get(m.SigningCertURL)
   323  	if err != nil {
   324  		return fmt.Errorf("failed to fetch signing cert, %w", err)
   325  	}
   326  	defer res.Body.Close()
   327  
   328  	body, err := io.ReadAll(io.LimitReader(res.Body, 65*1024))
   329  	if err != nil {
   330  		return fmt.Errorf("failed to read signing cert body, %w", err)
   331  	}
   332  
   333  	p, _ := pem.Decode(body)
   334  	if p == nil {
   335  		return fmt.Errorf("nothing found in pem encoded bytes")
   336  	}
   337  
   338  	cert, err := x509.ParseCertificate(p.Bytes)
   339  	if err != nil {
   340  		return fmt.Errorf("failed to parse signing cert, %w", err)
   341  	}
   342  
   343  	err = cert.CheckSignature(x509.SHA1WithRSA, m.sigSerialized(), msgSig)
   344  	if err != nil {
   345  		return fmt.Errorf("message signature check error, %w", err)
   346  	}
   347  
   348  	return nil
   349  }
   350  
   351  func (m *httpNotification) sigSerialized() []byte {
   352  	buf := &bytes.Buffer{}
   353  	v := reflect.ValueOf(m)
   354  
   355  	for _, key := range snsSigKeys[m.Type] {
   356  		field := reflect.Indirect(v).FieldByName(key)
   357  		val := field.String()
   358  		if !field.IsValid() || val == "" {
   359  			continue
   360  		}
   361  		if rn, ok := snsKeyRealNames[key]; ok {
   362  			key = rn
   363  		}
   364  		buf.WriteString(key + "\n")
   365  		buf.WriteString(val + "\n")
   366  	}
   367  
   368  	return buf.Bytes()
   369  }