github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/adapters/cloud/aws/lambda/adapt_test.go (about)

     1  package lambda
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/khulnasoft-lab/defsec/pkg/providers/aws/lambda"
     8  
     9  	"github.com/aws/aws-sdk-go-v2/service/lambda/types"
    10  
    11  	lambdaapi "github.com/aws/aws-sdk-go-v2/service/lambda"
    12  	"github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws"
    13  	"github.com/khulnasoft-lab/defsec/internal/adapters/cloud/aws/test"
    14  	"github.com/khulnasoft-lab/defsec/pkg/state"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  type functionDetails struct {
    20  	name        string
    21  	permissions []permissionDetails
    22  	tracing     string
    23  }
    24  
    25  type permissionDetails struct {
    26  	action    string
    27  	principal string
    28  	source    string
    29  }
    30  
    31  func Test_Lambda(t *testing.T) {
    32  
    33  	tests := []struct {
    34  		name    string
    35  		details functionDetails
    36  	}{
    37  		{
    38  			name: "defaults",
    39  			details: functionDetails{
    40  				name: "myfunction",
    41  			},
    42  		},
    43  		{
    44  			name: "pass-through tracing",
    45  			details: functionDetails{
    46  				name:    "myfunction",
    47  				tracing: "PassThrough",
    48  			},
    49  		},
    50  		{
    51  			name: "active tracing",
    52  			details: functionDetails{
    53  				name:    "myfunction",
    54  				tracing: "Active",
    55  			},
    56  		},
    57  		{
    58  			name: "with permissions",
    59  			details: functionDetails{
    60  				name:    "myfunction",
    61  				tracing: "Active",
    62  				permissions: []permissionDetails{
    63  					{
    64  						action:    "lambda:InvokeFunction",
    65  						principal: "1234567890",
    66  						source:    "*",
    67  					},
    68  				},
    69  			},
    70  		},
    71  	}
    72  
    73  	ra, stack, err := test.CreateLocalstackAdapter(t)
    74  	defer func() { _ = stack.Stop() }()
    75  	require.NoError(t, err)
    76  
    77  	for _, tt := range tests {
    78  		t.Run(tt.name, func(t *testing.T) {
    79  			funcARN := bootstrapFunction(t, ra, tt.details)
    80  			defer removeFunction(t, ra, funcARN)
    81  
    82  			testState := &state.State{}
    83  			lambdaAdapter := &adapter{}
    84  			err = lambdaAdapter.Adapt(ra, testState)
    85  			require.NoError(t, err)
    86  
    87  			require.Len(t, testState.AWS.Lambda.Functions, 1)
    88  			got := testState.AWS.Lambda.Functions[0]
    89  
    90  			if tt.details.tracing == "" {
    91  				tt.details.tracing = lambda.TracingModePassThrough
    92  			}
    93  			assert.Equal(t, tt.details.tracing, got.Tracing.Mode.Value())
    94  
    95  			assert.Equal(t, len(tt.details.permissions), len(got.Permissions))
    96  
    97  			for _, expectedPermission := range tt.details.permissions {
    98  				var found bool
    99  				for _, actualPermission := range got.Permissions {
   100  					if actualPermission.Principal.Value() == expectedPermission.principal && actualPermission.SourceARN.Value() == expectedPermission.source {
   101  						found = true
   102  						break
   103  					}
   104  				}
   105  				assert.True(t, found)
   106  			}
   107  		})
   108  	}
   109  }
   110  
   111  func bootstrapFunction(t *testing.T, ra *aws.RootAdapter, spec functionDetails) string {
   112  
   113  	api := lambdaapi.NewFromConfig(ra.SessionConfig())
   114  
   115  	output, err := api.CreateFunction(ra.Context(), &lambdaapi.CreateFunctionInput{
   116  		Code: &types.FunctionCode{
   117  			ZipFile: []byte{80, 75, 05, 06, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
   118  		},
   119  		FunctionName: &spec.name,
   120  		Role:         &spec.name,
   121  		TracingConfig: &types.TracingConfig{
   122  			Mode: types.TracingMode(spec.tracing),
   123  		},
   124  	})
   125  	require.NoError(t, err)
   126  
   127  	for i, permission := range spec.permissions {
   128  		statementID := fmt.Sprintf("%d", i)
   129  		_, err = api.AddPermission(ra.Context(), &lambdaapi.AddPermissionInput{
   130  			Action:       &permission.action,
   131  			FunctionName: &spec.name,
   132  			Qualifier:    output.Version,
   133  			Principal:    &permission.principal,
   134  			StatementId:  &statementID,
   135  			SourceArn:    &permission.source,
   136  		})
   137  		require.NoError(t, err)
   138  	}
   139  
   140  	return *output.FunctionArn
   141  }
   142  
   143  func removeFunction(t *testing.T, ra *aws.RootAdapter, arn string) {
   144  	api := lambdaapi.NewFromConfig(ra.SessionConfig())
   145  	_, err := api.DeleteFunction(ra.Context(), &lambdaapi.DeleteFunctionInput{
   146  		FunctionName: &arn,
   147  	})
   148  	require.NoError(t, err)
   149  }