github.com/mweagle/Sparta@v1.15.0/aws/step/step.go (about) 1 package step 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "math/rand" 8 "strings" 9 "time" 10 11 "github.com/aws/aws-sdk-go/aws/session" 12 sparta "github.com/mweagle/Sparta" 13 spartaCF "github.com/mweagle/Sparta/aws/cloudformation" 14 spartaIAM "github.com/mweagle/Sparta/aws/iam" 15 gocf "github.com/mweagle/go-cloudformation" 16 "github.com/pkg/errors" 17 "github.com/sirupsen/logrus" 18 ) 19 20 // Types of state machines per 21 // https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-stepfunctions-statemachine.html 22 const ( 23 stateMachineStandard = "STANDARD" 24 stateMachineExpress = "EXPRESS" 25 ) 26 27 // StateError is the reserved type used for AWS Step function error names 28 // Ref: https://states-language.net/spec.html#appendix-a 29 type StateError string 30 31 const ( 32 // StatesAll is a wild-card which matches any Error Name. 33 StatesAll StateError = "States.ALL" 34 // StatesTimeout is a Task State either ran longer than the 35 // “TimeoutSeconds” value, or failed to heartbeat for a time 36 // longer than the “HeartbeatSeconds” value. 37 StatesTimeout StateError = "States.Timeout" 38 // StatesTaskFailed is a Task State failed during the execution 39 StatesTaskFailed StateError = "States.TaskFailed" 40 // StatesPermissions is a Task State failed because it had 41 // insufficient privileges to execute the specified code. 42 StatesPermissions StateError = "States.Permissions" 43 // StatesResultPathMatchFailure is a Task State’s “ResultPath” field 44 // cannot be applied to the input the state received 45 StatesResultPathMatchFailure StateError = "States.ResultPathMatchFailure" 46 // StatesBranchFailed is a branch of a Parallel state failed 47 StatesBranchFailed StateError = "States.BranchFailed" 48 // StatesNoChoiceMatched is a Choice state failed to find a match for the 49 // condition field extracted from its input 50 StatesNoChoiceMatched StateError = "States.NoChoiceMatched" 51 ) 52 53 // MachineState is the base state for all AWS Step function 54 type MachineState interface { 55 Name() string 56 nodeID() string 57 enableEndState(bool) 58 } 59 60 // TransitionState is the generic state according to 61 // https://states-language.net/spec.html#state-type-table 62 type TransitionState interface { 63 MachineState 64 Next(state MachineState) MachineState 65 // AdjacentStates returns all the MachineStates that are reachable from 66 // the current state 67 AdjacentStates() []MachineState 68 WithComment(string) TransitionState 69 WithInputPath(string) TransitionState 70 WithOutputPath(string) TransitionState 71 } 72 73 // Embedding struct for common properties 74 type baseInnerState struct { 75 name string 76 id int64 77 next MachineState 78 comment string 79 inputPath string 80 outputPath string 81 isEndStateInvalid bool 82 } 83 84 func (bis *baseInnerState) nodeID() string { 85 return fmt.Sprintf("%s-%d", bis.name, bis.id) 86 } 87 88 func (bis *baseInnerState) enableEndState(isEnabled bool) { 89 bis.isEndStateInvalid = !isEnabled 90 } 91 92 // marshalStateJSON for subclass marshalling of state information 93 func (bis *baseInnerState) marshalStateJSON(stateType string, 94 additionalData map[string]interface{}) ([]byte, error) { 95 if additionalData == nil { 96 additionalData = make(map[string]interface{}) 97 } 98 additionalData["Type"] = stateType 99 if bis.next != nil { 100 additionalData["Next"] = bis.next.Name() 101 } 102 if bis.comment != "" { 103 additionalData["Comment"] = bis.comment 104 } 105 if bis.inputPath != "" { 106 additionalData["InputPath"] = bis.inputPath 107 } 108 if bis.outputPath != "" { 109 additionalData["OutputPath"] = bis.outputPath 110 } 111 if !bis.isEndStateInvalid && bis.next == nil { 112 additionalData["End"] = true 113 } 114 // Output the pretty version 115 return json.Marshal(additionalData) 116 } 117 118 /******************************************************************************* 119 ___ _____ _ _____ ___ ___ 120 / __|_ _/_\_ _| __/ __| 121 \__ \ | |/ _ \| | | _|\__ \ 122 |___/ |_/_/ \_\_| |___|___/ 123 /******************************************************************************/ 124 125 //////////////////////////////////////////////////////////////////////////////// 126 // PassState 127 //////////////////////////////////////////////////////////////////////////////// 128 129 // PassState represents a NOP state 130 type PassState struct { 131 baseInnerState 132 ResultPath string 133 Result interface{} 134 } 135 136 // WithResultPath is the fluent builder for the result path 137 func (ps *PassState) WithResultPath(resultPath string) *PassState { 138 ps.ResultPath = resultPath 139 return ps 140 } 141 142 // WithResult is the fluent builder for the result data 143 func (ps *PassState) WithResult(result interface{}) *PassState { 144 ps.Result = result 145 return ps 146 } 147 148 // Next returns the next state 149 func (ps *PassState) Next(nextState MachineState) MachineState { 150 ps.next = nextState 151 return ps 152 } 153 154 // AdjacentStates returns nodes reachable from this node 155 func (ps *PassState) AdjacentStates() []MachineState { 156 if ps.next == nil { 157 return nil 158 } 159 return []MachineState{ps.next} 160 } 161 162 // Name returns the name of this Task state 163 func (ps *PassState) Name() string { 164 return ps.name 165 } 166 167 // WithComment returns the TaskState comment 168 func (ps *PassState) WithComment(comment string) TransitionState { 169 ps.comment = comment 170 return ps 171 } 172 173 // WithInputPath returns the TaskState input data selector 174 func (ps *PassState) WithInputPath(inputPath string) TransitionState { 175 ps.inputPath = inputPath 176 return ps 177 } 178 179 // WithOutputPath returns the TaskState output data selector 180 func (ps *PassState) WithOutputPath(outputPath string) TransitionState { 181 ps.outputPath = outputPath 182 return ps 183 } 184 185 // MarshalJSON for custom marshalling 186 func (ps *PassState) MarshalJSON() ([]byte, error) { 187 additionalParams := make(map[string]interface{}) 188 if ps.ResultPath != "" { 189 additionalParams["ResultPath"] = ps.ResultPath 190 } 191 if ps.Result != nil { 192 additionalParams["Result"] = ps.Result 193 } 194 return ps.marshalStateJSON("Pass", additionalParams) 195 } 196 197 // NewPassState returns a new PassState instance 198 func NewPassState(name string, resultData interface{}) *PassState { 199 return &PassState{ 200 baseInnerState: baseInnerState{ 201 name: name, 202 id: rand.Int63(), 203 }, 204 Result: resultData, 205 } 206 } 207 208 //////////////////////////////////////////////////////////////////////////////// 209 // ChoiceState 210 //////////////////////////////////////////////////////////////////////////////// 211 212 // ChoiceState is a synthetic state that executes a lot of independent 213 // branches in parallel 214 type ChoiceState struct { 215 baseInnerState 216 Choices []ChoiceBranch 217 Default TransitionState 218 } 219 220 // WithDefault is the fluent builder for the default state 221 func (cs *ChoiceState) WithDefault(defaultState TransitionState) *ChoiceState { 222 cs.Default = defaultState 223 return cs 224 } 225 226 // WithResultPath is the fluent builder for the result path 227 func (cs *ChoiceState) WithResultPath(resultPath string) *ChoiceState { 228 return cs 229 } 230 231 // Name returns the name of this Task state 232 func (cs *ChoiceState) Name() string { 233 return cs.name 234 } 235 236 // WithComment returns the TaskState comment 237 func (cs *ChoiceState) WithComment(comment string) *ChoiceState { 238 cs.comment = comment 239 return cs 240 } 241 242 // MarshalJSON for custom marshalling 243 func (cs *ChoiceState) MarshalJSON() ([]byte, error) { 244 /* 245 A state in a Parallel state branch “States” field MUST NOT have a “Next” field that targets a field outside of that “States” field. A state MUST NOT have a “Next” field which matches a state name inside a Parallel state branch’s “States” field unless it is also inside the same “States” field. 246 247 Put another way, states in a branch’s “States” field can transition only to each other, and no state outside of that “States” field can transition into it. 248 */ 249 additionalParams := make(map[string]interface{}) 250 additionalParams["Choices"] = cs.Choices 251 if cs.Default != nil { 252 additionalParams["Default"] = cs.Default.Name() 253 } 254 return cs.marshalStateJSON("Choice", additionalParams) 255 } 256 257 // NewChoiceState returns a "ChoiceState" with the supplied 258 // information 259 func NewChoiceState(choiceStateName string, choices ...ChoiceBranch) *ChoiceState { 260 return &ChoiceState{ 261 baseInnerState: baseInnerState{ 262 name: choiceStateName, 263 id: rand.Int63(), 264 isEndStateInvalid: true, 265 }, 266 Choices: append([]ChoiceBranch{}, choices...), 267 } 268 } 269 270 //////////////////////////////////////////////////////////////////////////////// 271 // TaskRetry 272 //////////////////////////////////////////////////////////////////////////////// 273 274 // TaskRetry is an action to perform in response to a Task failure 275 type TaskRetry struct { 276 ErrorEquals []StateError `json:",omitempty"` 277 //lint:ignore ST1011 we want to give a cue to the client of the units 278 IntervalSeconds time.Duration `json:",omitempty"` 279 MaxAttempts int `json:",omitempty"` 280 BackoffRate float32 `json:",omitempty"` 281 } 282 283 // WithErrors is the fluent builder 284 func (tr *TaskRetry) WithErrors(errors ...StateError) *TaskRetry { 285 if tr.ErrorEquals == nil { 286 tr.ErrorEquals = make([]StateError, 0) 287 } 288 tr.ErrorEquals = append(tr.ErrorEquals, errors...) 289 return tr 290 } 291 292 // WithInterval is the fluent builder 293 func (tr *TaskRetry) WithInterval(interval time.Duration) *TaskRetry { 294 tr.IntervalSeconds = interval 295 return tr 296 } 297 298 // WithMaxAttempts is the fluent builder 299 func (tr *TaskRetry) WithMaxAttempts(maxAttempts int) *TaskRetry { 300 tr.MaxAttempts = maxAttempts 301 return tr 302 } 303 304 // WithBackoffRate is the fluent builder 305 func (tr *TaskRetry) WithBackoffRate(backoffRate float32) *TaskRetry { 306 tr.BackoffRate = backoffRate 307 return tr 308 } 309 310 // NewTaskRetry returns a new TaskRetry instance 311 func NewTaskRetry() *TaskRetry { 312 return &TaskRetry{} 313 } 314 315 //////////////////////////////////////////////////////////////////////////////// 316 // TaskCatch 317 //////////////////////////////////////////////////////////////////////////////// 318 319 // TaskCatch is an action to handle a failing operation 320 type TaskCatch struct { 321 /* 322 The reserved name “States.ALL” appearing in a Retrier’s “ErrorEquals” field is a wild-card and matches any Error Name. Such a value MUST appear alone in the “ErrorEquals” array and MUST appear in the last Catcher in the “Catch” array. 323 */ 324 errorEquals []StateError 325 next TransitionState 326 } 327 328 // MarshalJSON to prevent inadvertent composition 329 func (tc *TaskCatch) MarshalJSON() ([]byte, error) { 330 catchJSON := map[string]interface{}{ 331 "ErrorEquals": tc.errorEquals, 332 "Next": tc.next, 333 } 334 return json.Marshal(catchJSON) 335 } 336 337 // NewTaskCatch returns a new TaskCatch instance 338 func NewTaskCatch(nextState TransitionState, errors ...StateError) *TaskCatch { 339 return &TaskCatch{ 340 errorEquals: errors, 341 next: nextState, 342 } 343 } 344 345 //////////////////////////////////////////////////////////////////////////////// 346 // BaseTask 347 //////////////////////////////////////////////////////////////////////////////// 348 349 // BaseTask represents the core BaseTask control flow options. 350 type BaseTask struct { 351 baseInnerState 352 ResultPath string 353 //lint:ignore ST1011 we want to give a cue to the client of the units 354 TimeoutSeconds time.Duration 355 //lint:ignore ST1011 we want to give a cue to the client of the units 356 HeartbeatSeconds time.Duration 357 LambdaDecorator sparta.TemplateDecorator 358 Retriers []*TaskRetry 359 Catchers []*TaskCatch 360 } 361 362 func (bt *BaseTask) marshalMergedParams(taskResourceType string, 363 taskParams interface{}) ([]byte, error) { 364 jsonBytes, jsonBytesErr := json.Marshal(taskParams) 365 if jsonBytesErr != nil { 366 return nil, errors.Wrapf(jsonBytesErr, "attempting to JSON marshal task params") 367 } 368 369 var unmarshaled interface{} 370 unmarshalErr := json.Unmarshal(jsonBytes, &unmarshaled) 371 if unmarshalErr != nil { 372 return nil, errors.Wrapf(unmarshalErr, "attempting to unmarshall params") 373 } 374 375 mapTyped, mapTypedErr := unmarshaled.(map[string]interface{}) 376 if !mapTypedErr { 377 return nil, errors.Errorf("attempting to type convert unmarshalled params to map[string]interface{}") 378 } 379 additionalParams := bt.additionalParams() 380 additionalParams["Resource"] = taskResourceType 381 additionalParams["Parameters"] = mapTyped 382 return bt.marshalStateJSON("Task", additionalParams) 383 } 384 385 // attributeMap returns the map of attributes necessary 386 // for JSON serialization 387 func (bt *BaseTask) additionalParams() map[string]interface{} { 388 additionalParams := make(map[string]interface{}) 389 390 if bt.TimeoutSeconds.Seconds() != 0 { 391 additionalParams["TimeoutSeconds"] = bt.TimeoutSeconds.Seconds() 392 } 393 if bt.HeartbeatSeconds.Seconds() != 0 { 394 additionalParams["HeartbeatSeconds"] = bt.HeartbeatSeconds.Seconds() 395 } 396 if bt.ResultPath != "" { 397 additionalParams["ResultPath"] = bt.ResultPath 398 } 399 if len(bt.Retriers) != 0 { 400 additionalParams["Retry"] = make([]map[string]interface{}, 0) 401 } 402 if bt.Catchers != nil { 403 catcherMap := make([]map[string]interface{}, len(bt.Catchers)) 404 for index, eachCatcher := range bt.Catchers { 405 catcherMap[index] = map[string]interface{}{ 406 "ErrorEquals": eachCatcher.errorEquals, 407 "Next": eachCatcher.next.Name(), 408 } 409 } 410 additionalParams["Catch"] = catcherMap 411 } 412 return additionalParams 413 } 414 415 // Next returns the next state 416 func (bt *BaseTask) Next(nextState MachineState) MachineState { 417 bt.next = nextState 418 return nextState 419 } 420 421 // AdjacentStates returns nodes reachable from this node 422 func (bt *BaseTask) AdjacentStates() []MachineState { 423 adjacent := []MachineState{} 424 if bt.next != nil { 425 adjacent = append(adjacent, bt.next) 426 } 427 for _, eachCatcher := range bt.Catchers { 428 adjacent = append(adjacent, eachCatcher.next) 429 } 430 return adjacent 431 } 432 433 // Name returns the name of this Task state 434 func (bt *BaseTask) Name() string { 435 return bt.name 436 } 437 438 // WithResultPath is the fluent builder for the result path 439 func (bt *BaseTask) WithResultPath(resultPath string) *BaseTask { 440 bt.ResultPath = resultPath 441 return bt 442 } 443 444 // WithTimeout is the fluent builder for BaseTask 445 func (bt *BaseTask) WithTimeout(timeout time.Duration) *BaseTask { 446 bt.TimeoutSeconds = timeout 447 return bt 448 } 449 450 // WithHeartbeat is the fluent builder for BaseTask 451 func (bt *BaseTask) WithHeartbeat(pulse time.Duration) *BaseTask { 452 bt.HeartbeatSeconds = pulse 453 return bt 454 } 455 456 // WithRetriers is the fluent builder for BaseTask 457 func (bt *BaseTask) WithRetriers(retries ...*TaskRetry) *BaseTask { 458 if bt.Retriers == nil { 459 bt.Retriers = make([]*TaskRetry, 0) 460 } 461 bt.Retriers = append(bt.Retriers, retries...) 462 return bt 463 } 464 465 // WithCatchers is the fluent builder for BaseTask 466 func (bt *BaseTask) WithCatchers(catch ...*TaskCatch) *BaseTask { 467 if bt.Catchers == nil { 468 bt.Catchers = make([]*TaskCatch, 0) 469 } 470 bt.Catchers = append(bt.Catchers, catch...) 471 return bt 472 } 473 474 // WithComment returns the BaseTask comment 475 func (bt *BaseTask) WithComment(comment string) TransitionState { 476 bt.comment = comment 477 return bt 478 } 479 480 // WithInputPath returns the BaseTask input data selector 481 func (bt *BaseTask) WithInputPath(inputPath string) TransitionState { 482 bt.inputPath = inputPath 483 return bt 484 } 485 486 // WithOutputPath returns the BaseTask output data selector 487 func (bt *BaseTask) WithOutputPath(outputPath string) TransitionState { 488 bt.outputPath = outputPath 489 return bt 490 } 491 492 // MarshalJSON to prevent inadvertent composition 493 func (bt *BaseTask) MarshalJSON() ([]byte, error) { 494 495 return nil, errors.Errorf("step.BaseTask doesn't support direct JSON serialization. Prefer using an embedding Task type (eg: TaskState, FargateTaskState)") 496 } 497 498 //////////////////////////////////////////////////////////////////////////////// 499 // LambdaTaskState 500 //////////////////////////////////////////////////////////////////////////////// 501 502 // LambdaTaskState is the core state, responsible for delegating to a Lambda function 503 type LambdaTaskState struct { 504 BaseTask 505 lambdaFn *sparta.LambdaAWSInfo 506 lambdaLogicalResourceName string 507 preexistingDecorator sparta.TemplateDecorator 508 } 509 510 // NewLambdaTaskState returns a LambdaTaskState instance properly initialized 511 func NewLambdaTaskState(stateName string, lambdaFn *sparta.LambdaAWSInfo) *LambdaTaskState { 512 ts := &LambdaTaskState{ 513 BaseTask: BaseTask{ 514 baseInnerState: baseInnerState{ 515 name: stateName, 516 id: rand.Int63(), 517 }, 518 }, 519 lambdaFn: lambdaFn, 520 } 521 ts.LambdaDecorator = func(serviceName string, 522 lambdaResourceName string, 523 lambdaResource gocf.LambdaFunction, 524 resourceMetadata map[string]interface{}, 525 S3Bucket string, 526 S3Key string, 527 buildID string, 528 cfTemplate *gocf.Template, 529 context map[string]interface{}, 530 logger *logrus.Logger) error { 531 if ts.preexistingDecorator != nil { 532 preexistingLambdaDecoratorErr := ts.preexistingDecorator( 533 serviceName, 534 lambdaResourceName, 535 lambdaResource, 536 resourceMetadata, 537 S3Bucket, 538 S3Key, 539 buildID, 540 cfTemplate, 541 context, 542 logger) 543 if preexistingLambdaDecoratorErr != nil { 544 return preexistingLambdaDecoratorErr 545 } 546 } 547 // Save the lambda name s.t. we can create the {"Ref"::"lambdaName"} entry... 548 ts.lambdaLogicalResourceName = lambdaResourceName 549 return nil 550 } 551 // Make sure this Lambda decorator is included in the list of existing decorators 552 553 // If there already is a decorator, then save it... 554 ts.preexistingDecorator = lambdaFn.Decorator 555 ts.lambdaFn.Decorators = append(ts.lambdaFn.Decorators, 556 sparta.TemplateDecoratorHookFunc(ts.LambdaDecorator)) 557 return ts 558 } 559 560 // MarshalJSON for custom marshalling, since this will be stringified and we need it 561 // to turn into a stringified Ref: 562 func (ts *LambdaTaskState) MarshalJSON() ([]byte, error) { 563 additionalParams := ts.BaseTask.additionalParams() 564 additionalParams["Resource"] = gocf.GetAtt(ts.lambdaLogicalResourceName, "Arn") 565 return ts.marshalStateJSON("Task", additionalParams) 566 } 567 568 //////////////////////////////////////////////////////////////////////////////// 569 // WaitDelay 570 //////////////////////////////////////////////////////////////////////////////// 571 572 // WaitDelay is a delay with an interval 573 type WaitDelay struct { 574 baseInnerState 575 delay time.Duration 576 } 577 578 // Name returns the WaitDelay name 579 func (wd *WaitDelay) Name() string { 580 return wd.name 581 } 582 583 // Next sets the step after the wait delay 584 func (wd *WaitDelay) Next(nextState MachineState) MachineState { 585 wd.next = nextState 586 return wd 587 } 588 589 // AdjacentStates returns nodes reachable from this node 590 func (wd *WaitDelay) AdjacentStates() []MachineState { 591 if wd.next == nil { 592 return nil 593 } 594 return []MachineState{wd.next} 595 } 596 597 // WithComment returns the WaitDelay comment 598 func (wd *WaitDelay) WithComment(comment string) TransitionState { 599 wd.comment = comment 600 return wd 601 } 602 603 // WithInputPath returns the TaskState input data selector 604 func (wd *WaitDelay) WithInputPath(inputPath string) TransitionState { 605 wd.inputPath = inputPath 606 return wd 607 } 608 609 // WithOutputPath returns the TaskState output data selector 610 func (wd *WaitDelay) WithOutputPath(outputPath string) TransitionState { 611 wd.outputPath = outputPath 612 return wd 613 } 614 615 // MarshalJSON for custom marshalling 616 func (wd *WaitDelay) MarshalJSON() ([]byte, error) { 617 additionalParams := make(map[string]interface{}) 618 additionalParams["Seconds"] = wd.delay.Seconds() 619 return wd.marshalStateJSON("Wait", additionalParams) 620 } 621 622 // NewWaitDelayState returns a new WaitDelay pointer instance 623 func NewWaitDelayState(stateName string, delay time.Duration) *WaitDelay { 624 return &WaitDelay{ 625 baseInnerState: baseInnerState{ 626 name: stateName, 627 id: rand.Int63(), 628 }, 629 delay: delay, 630 } 631 } 632 633 //////////////////////////////////////////////////////////////////////////////// 634 635 //////////////////////////////////////////////////////////////////////////////// 636 // WaitUntil 637 //////////////////////////////////////////////////////////////////////////////// 638 639 // WaitUntil is a delay with an absolute time gate 640 type WaitUntil struct { 641 baseInnerState 642 Timestamp time.Time 643 } 644 645 // Name returns the WaitDelay name 646 func (wu *WaitUntil) Name() string { 647 return wu.name 648 } 649 650 // Next sets the step after the wait delay 651 func (wu *WaitUntil) Next(nextState MachineState) MachineState { 652 wu.next = nextState 653 return wu 654 } 655 656 // AdjacentStates returns nodes reachable from this node 657 func (wu *WaitUntil) AdjacentStates() []MachineState { 658 if wu.next == nil { 659 return nil 660 } 661 return []MachineState{wu.next} 662 } 663 664 // WithComment returns the WaitDelay comment 665 func (wu *WaitUntil) WithComment(comment string) TransitionState { 666 wu.comment = comment 667 return wu 668 } 669 670 // WithInputPath returns the TaskState input data selector 671 func (wu *WaitUntil) WithInputPath(inputPath string) TransitionState { 672 wu.inputPath = inputPath 673 return wu 674 } 675 676 // WithOutputPath returns the TaskState output data selector 677 func (wu *WaitUntil) WithOutputPath(outputPath string) TransitionState { 678 wu.outputPath = outputPath 679 return wu 680 } 681 682 // MarshalJSON for custom marshalling 683 func (wu *WaitUntil) MarshalJSON() ([]byte, error) { 684 additionalParams := make(map[string]interface{}) 685 additionalParams["Timestamp"] = wu.Timestamp.Format(time.RFC3339) 686 return wu.marshalStateJSON("Wait", additionalParams) 687 } 688 689 // NewWaitUntilState returns a new WaitDelay pointer instance 690 func NewWaitUntilState(stateName string, waitUntil time.Time) *WaitUntil { 691 return &WaitUntil{ 692 baseInnerState: baseInnerState{ 693 name: stateName, 694 id: rand.Int63(), 695 }, 696 Timestamp: waitUntil, 697 } 698 } 699 700 //////////////////////////////////////////////////////////////////////////////// 701 702 // WaitDynamicUntil is a delay based on a previous response 703 type WaitDynamicUntil struct { 704 baseInnerState 705 TimestampPath string 706 SecondsPath string 707 } 708 709 // Name returns the WaitDelay name 710 func (wdu *WaitDynamicUntil) Name() string { 711 return wdu.name 712 } 713 714 // Next sets the step after the wait delay 715 func (wdu *WaitDynamicUntil) Next(nextState MachineState) MachineState { 716 wdu.next = nextState 717 return wdu 718 } 719 720 // AdjacentStates returns nodes reachable from this node 721 func (wdu *WaitDynamicUntil) AdjacentStates() []MachineState { 722 if wdu.next == nil { 723 return nil 724 } 725 return []MachineState{wdu.next} 726 } 727 728 // WithComment returns the WaitDelay comment 729 func (wdu *WaitDynamicUntil) WithComment(comment string) TransitionState { 730 wdu.comment = comment 731 return wdu 732 } 733 734 // WithInputPath returns the TaskState input data selector 735 func (wdu *WaitDynamicUntil) WithInputPath(inputPath string) TransitionState { 736 wdu.inputPath = inputPath 737 return wdu 738 } 739 740 // WithOutputPath returns the TaskState output data selector 741 func (wdu *WaitDynamicUntil) WithOutputPath(outputPath string) TransitionState { 742 wdu.outputPath = outputPath 743 return wdu 744 } 745 746 // MarshalJSON for custom marshalling 747 func (wdu *WaitDynamicUntil) MarshalJSON() ([]byte, error) { 748 additionalParams := make(map[string]interface{}) 749 if wdu.TimestampPath != "" { 750 additionalParams["TimestampPath"] = wdu.TimestampPath 751 } 752 if wdu.SecondsPath != "" { 753 additionalParams["SecondsPath"] = wdu.SecondsPath 754 } 755 return wdu.marshalStateJSON("Wait", additionalParams) 756 } 757 758 // NewWaitDynamicUntilState returns a new WaitDynamicUntil pointer instance 759 func NewWaitDynamicUntilState(stateName string, timestampPath string) *WaitDynamicUntil { 760 return &WaitDynamicUntil{ 761 baseInnerState: baseInnerState{ 762 name: stateName, 763 id: rand.Int63(), 764 }, 765 TimestampPath: timestampPath, 766 } 767 } 768 769 // NewDynamicWaitDurationState returns a new WaitDynamicUntil pointer instance 770 func NewDynamicWaitDurationState(stateName string, secondsPath string) *WaitDynamicUntil { 771 return &WaitDynamicUntil{ 772 baseInnerState: baseInnerState{ 773 name: stateName, 774 id: rand.Int63(), 775 }, 776 SecondsPath: secondsPath, 777 } 778 } 779 780 /* 781 Validate-All": { 782 "Type": "Map", 783 "InputPath": "$.detail", 784 "ItemsPath": "$.shipped", 785 "MaxConcurrency": 0, 786 "Parameters": { 787 "parcel.$": "$$.Map.Item.Value", 788 "courier.$": "$.delivery-partner" 789 }, 790 "Iterator": { 791 "StartAt": "Validate", 792 "States": { 793 "Validate": { 794 "Type": "Task", 795 "Resource": "arn:aws:lambda:us-east-1:123456789012:function:ship-val", 796 "End": true 797 } 798 } 799 }, 800 "ResultPath": "$.detail.shipped", 801 "End": true 802 */ 803 804 //////////////////////////////////////////////////////////////////////////////// 805 // StateMachine 806 //////////////////////////////////////////////////////////////////////////////// 807 808 // StateMachine is the top level item 809 type StateMachine struct { 810 name string 811 comment string 812 stateDefinitionError error 813 machineType string 814 loggingConfiguration *gocf.StepFunctionsStateMachineLoggingConfiguration 815 startAt MachineState 816 uniqueStates map[string]MachineState 817 roleArn gocf.Stringable 818 // internal flag to suppress the automatic "End" property 819 // from being serialized for Map states 820 disableEndState bool 821 } 822 823 //Comment sets the StateMachine comment 824 func (sm *StateMachine) Comment(comment string) *StateMachine { 825 sm.comment = comment 826 return sm 827 } 828 829 //WithRoleArn sets the state machine roleArn 830 func (sm *StateMachine) WithRoleArn(roleArn gocf.Stringable) *StateMachine { 831 sm.roleArn = roleArn 832 return sm 833 } 834 835 // validate performs any validation against the state machine 836 // prior to marshaling 837 func (sm *StateMachine) validate() []error { 838 validationErrors := make([]error, 0) 839 if sm.stateDefinitionError != nil { 840 validationErrors = append(validationErrors, sm.stateDefinitionError) 841 } 842 843 // TODO - add Catcher validator 844 /* 845 Each Catcher MUST contain a field named “ErrorEquals”, specified exactly as with the Retrier “ErrorEquals” field, and a field named “Next” whose value MUST be a string exactly matching a State Name. 846 847 When a state reports an error and either there is no Retry field, or retries have failed to resolve the error, the interpreter scans through the Catchers in array order, and when the Error Name appears in the value of a Catcher’s “ErrorEquals” field, transitions the machine to the state named in the value of the “Next” field. 848 849 The reserved name “States.ALL” appearing in a Retrier’s “ErrorEquals” field is a wild-card and matches any Error Name. Such a value MUST appear alone in the “ErrorEquals” array and MUST appear in the last Catcher in the “Catch” array. 850 */ 851 return validationErrors 852 } 853 854 // StateMachineDecorator is a decorator that returns a default 855 // CloudFormationResource named decorator 856 func (sm *StateMachine) StateMachineDecorator() sparta.ServiceDecoratorHookFunc { 857 cfName := sparta.CloudFormationResourceName("StateMachine", "StateMachine") 858 return sm.StateMachineNamedDecorator(cfName) 859 } 860 861 // StateMachineNamedDecorator is the hook exposed by the StateMachine 862 // to insert the AWS Step function into the CloudFormation template 863 func (sm *StateMachine) StateMachineNamedDecorator(stepFunctionResourceName string) sparta.ServiceDecoratorHookFunc { 864 return func(context map[string]interface{}, 865 serviceName string, 866 template *gocf.Template, 867 S3Bucket string, 868 S3Key string, 869 buildID string, 870 awsSession *session.Session, 871 noop bool, 872 logger *logrus.Logger) error { 873 874 machineErrors := sm.validate() 875 if len(machineErrors) != 0 { 876 errorText := make([]string, len(machineErrors)) 877 for index := range machineErrors { 878 errorText[index] = machineErrors[index].Error() 879 } 880 return errors.Errorf("Invalid state machine. Errors: %s", 881 strings.Join(errorText, ", ")) 882 } 883 884 lambdaFunctionResourceNames := []string{} 885 for _, eachState := range sm.uniqueStates { 886 switch taskState := eachState.(type) { 887 case *LambdaTaskState: 888 { 889 lambdaFunctionResourceNames = append(lambdaFunctionResourceNames, 890 taskState.lambdaLogicalResourceName) 891 } 892 case *MapState: 893 { 894 for _, eachUniqueState := range taskState.States.uniqueStates { 895 switch typedMapState := eachUniqueState.(type) { 896 case *LambdaTaskState: 897 { 898 lambdaFunctionResourceNames = append(lambdaFunctionResourceNames, 899 typedMapState.lambdaLogicalResourceName) 900 } 901 } 902 } 903 } 904 case *ParallelState: 905 { 906 for _, eachBranch := range taskState.Branches { 907 for _, eachUniqueState := range eachBranch.uniqueStates { 908 switch typedParallelState := eachUniqueState.(type) { 909 case *LambdaTaskState: 910 { 911 lambdaFunctionResourceNames = append(lambdaFunctionResourceNames, 912 typedParallelState.lambdaLogicalResourceName) 913 } 914 } 915 } 916 } 917 } 918 } 919 } 920 921 // Assume policy document 922 regionalPrincipal := gocf.Join(".", 923 gocf.String("states"), 924 gocf.Ref("AWS::Region"), 925 gocf.String("amazonaws.com")) 926 var AssumePolicyDocument = sparta.ArbitraryJSONObject{ 927 "Version": "2012-10-17", 928 "Statement": []sparta.ArbitraryJSONObject{ 929 { 930 "Effect": "Allow", 931 "Principal": sparta.ArbitraryJSONObject{ 932 "Service": regionalPrincipal, 933 }, 934 "Action": []string{"sts:AssumeRole"}, 935 }, 936 }, 937 } 938 var iamRoleResourceName string 939 if len(lambdaFunctionResourceNames) != 0 && sm.roleArn == nil { 940 statesIAMRole := &gocf.IAMRole{ 941 AssumeRolePolicyDocument: AssumePolicyDocument, 942 } 943 statements := make([]spartaIAM.PolicyStatement, 0) 944 for _, eachLambdaName := range lambdaFunctionResourceNames { 945 statements = append(statements, 946 spartaIAM.PolicyStatement{ 947 Effect: "Allow", 948 Action: []string{ 949 "lambda:InvokeFunction", 950 }, 951 Resource: gocf.GetAtt(eachLambdaName, "Arn").String(), 952 }, 953 ) 954 } 955 iamPolicies := gocf.IAMRolePolicyList{} 956 iamPolicies = append(iamPolicies, gocf.IAMRolePolicy{ 957 PolicyDocument: sparta.ArbitraryJSONObject{ 958 "Version": "2012-10-17", 959 "Statement": statements, 960 }, 961 PolicyName: gocf.String("StatesExecutionPolicy"), 962 }) 963 statesIAMRole.Policies = &iamPolicies 964 iamRoleResourceName = sparta.CloudFormationResourceName("StatesIAMRole", 965 "StatesIAMRole") 966 template.AddResource(iamRoleResourceName, statesIAMRole) 967 } 968 969 // Sweet - serialize it without indentation so that the 970 // ConvertToTemplateExpression can actually parse the inline `Ref` objects 971 jsonBytes, jsonBytesErr := json.Marshal(sm) 972 if jsonBytesErr != nil { 973 return errors.Errorf("Failed to marshal: %s", jsonBytesErr.Error()) 974 } 975 logger.WithFields(logrus.Fields{ 976 "StateMachine": string(jsonBytes), 977 }).Debug("State machine definition") 978 979 // Super, now parse this into an Fn::Join representation 980 // so that we can get inline expansion of the AWS pseudo params 981 smReader := bytes.NewReader(jsonBytes) 982 templateExpr, templateExprErr := spartaCF.ConvertToInlineJSONTemplateExpression(smReader, nil) 983 if nil != templateExprErr { 984 return errors.Errorf("Failed to parser: %s", templateExprErr.Error()) 985 } 986 987 // Awsome - add an AWS::StepFunction to the template with this info and roll with it... 988 stepFunctionResource := &gocf.StepFunctionsStateMachine{ 989 StateMachineName: gocf.String(sm.name), 990 DefinitionString: templateExpr, 991 LoggingConfiguration: sm.loggingConfiguration, 992 } 993 if iamRoleResourceName != "" { 994 stepFunctionResource.RoleArn = gocf.GetAtt(iamRoleResourceName, "Arn").String() 995 } else if sm.roleArn != nil { 996 stepFunctionResource.RoleArn = sm.roleArn.String() 997 } 998 if sm.machineType != "" { 999 stepFunctionResource.StateMachineType = gocf.String(sm.machineType) 1000 } 1001 template.AddResource(stepFunctionResourceName, stepFunctionResource) 1002 return nil 1003 } 1004 } 1005 1006 // MarshalJSON for custom marshalling 1007 func (sm *StateMachine) MarshalJSON() ([]byte, error) { 1008 1009 // If there aren't any states, then it's the end 1010 return json.Marshal(&struct { 1011 Comment string `json:",omitempty"` 1012 StartAt string `json:",omitempty"` 1013 States map[string]MachineState `json:",omitempty"` 1014 End bool `json:",omitempty"` 1015 }{ 1016 Comment: sm.comment, 1017 StartAt: sm.startAt.Name(), 1018 States: sm.uniqueStates, 1019 End: (len(sm.uniqueStates) == 1) && !sm.disableEndState, 1020 }) 1021 } 1022 1023 func createStateMachine(stateMachineName string, 1024 machineType string, 1025 startState MachineState) *StateMachine { 1026 uniqueStates := make(map[string]MachineState) 1027 pendingStates := []MachineState{startState} 1028 // Map of basename to nodeID to check for duplicates 1029 duplicateStateNames := make(map[string]bool) 1030 1031 // TODO - check duplicate names 1032 1033 nodeVisited := func(node MachineState) bool { 1034 if node == nil { 1035 return true 1036 } 1037 existingNode, visited := uniqueStates[node.Name()] 1038 if visited && existingNode.nodeID() != node.nodeID() { 1039 // Check for different nodeids. 1040 duplicateStateNames[node.Name()] = true 1041 } 1042 return visited 1043 } 1044 1045 for len(pendingStates) != 0 { 1046 headState, tailStates := pendingStates[0], pendingStates[1:] 1047 1048 // Does this already exist? 1049 headStateName := headState.Name() 1050 uniqueStates[headStateName] = headState 1051 1052 switch stateNode := headState.(type) { 1053 case *ChoiceState: 1054 for _, eachChoice := range stateNode.Choices { 1055 if !nodeVisited(eachChoice.nextState()) { 1056 tailStates = append(tailStates, eachChoice.nextState()) 1057 } 1058 } 1059 if !nodeVisited(stateNode.Default) { 1060 tailStates = append(tailStates, stateNode.Default) 1061 } 1062 1063 case TransitionState: 1064 for _, eachAdjacentState := range stateNode.AdjacentStates() { 1065 if !nodeVisited(eachAdjacentState) { 1066 tailStates = append(tailStates, eachAdjacentState) 1067 } 1068 } 1069 // Are there any Catchers in here? 1070 } 1071 pendingStates = tailStates 1072 } 1073 1074 // Walk all the states and assemble them into the states slice 1075 sm := &StateMachine{ 1076 name: stateMachineName, 1077 startAt: startState, 1078 uniqueStates: uniqueStates, 1079 } 1080 if machineType != "" { 1081 sm.machineType = machineType 1082 } 1083 // Store duplicate state names 1084 if len(duplicateStateNames) != 0 { 1085 duplicateNames := []string{} 1086 for eachKey := range duplicateStateNames { 1087 duplicateNames = append(duplicateNames, eachKey) 1088 } 1089 sm.stateDefinitionError = fmt.Errorf("duplicate state names: %s", 1090 strings.Join(duplicateNames, ",")) 1091 } 1092 return sm 1093 1094 } 1095 1096 // NewExpressStateMachine returns a new Express StateMachine instance. See 1097 // https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-stepfunctions-statemachine.html 1098 // for more information. 1099 func NewExpressStateMachine(stateMachineName string, 1100 loggingConfiguration *gocf.StepFunctionsStateMachineLoggingConfiguration, 1101 startState TransitionState) *StateMachine { 1102 1103 sm := createStateMachine(stateMachineName, 1104 stateMachineExpress, 1105 startState) 1106 sm.loggingConfiguration = loggingConfiguration 1107 return sm 1108 } 1109 1110 // NewStateMachine returns a new StateMachine instance 1111 func NewStateMachine(stateMachineName string, startState MachineState) *StateMachine { 1112 1113 return createStateMachine(stateMachineName, 1114 stateMachineStandard, 1115 startState) 1116 } 1117 1118 ////////////////////////////////////////////////////////////////////////////////