github.com/prebid/prebid-server@v0.275.0/hooks/hookexecution/execution.go (about)

     1  package hookexecution
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/prebid/prebid-server/hooks"
    11  	"github.com/prebid/prebid-server/hooks/hookstage"
    12  	"github.com/prebid/prebid-server/metrics"
    13  )
    14  
// hookResponse carries the outcome of a single hook invocation from the
// goroutine that ran the hook back to the code that aggregates group results.
type hookResponse[T any] struct {
	// Err is the error returned by the hook handler, or a TimeoutError when
	// the hook did not answer within the group timeout.
	Err error
	// ExecutionTime is the wall-clock time measured from the start of the
	// hook execution until its response (or timeout) was observed.
	ExecutionTime time.Duration
	// HookID identifies the module and hook implementation that produced
	// this response.
	HookID HookID
	// Result is the value returned by the hook handler; it is the zero value
	// when the hook timed out.
	Result hookstage.HookResult[T]
}
    21  
// hookHandler is the function used to invoke a hook of type H with a payload
// of type P. It receives the execution context, the invocation context of the
// module owning the hook, the hook itself and the payload, and returns the
// hook result together with any execution error.
type hookHandler[H any, P any] func(
	context.Context,
	hookstage.ModuleInvocationContext,
	H,
	P,
) (hookstage.HookResult[P], error)
    28  
    29  func executeStage[H any, P any](
    30  	executionCtx executionContext,
    31  	plan hooks.Plan[H],
    32  	payload P,
    33  	hookHandler hookHandler[H, P],
    34  	metricEngine metrics.MetricsEngine,
    35  ) (StageOutcome, P, stageModuleContext, *RejectError) {
    36  	stageOutcome := StageOutcome{}
    37  	stageOutcome.Groups = make([]GroupOutcome, 0, len(plan))
    38  	stageModuleCtx := stageModuleContext{}
    39  	stageModuleCtx.groupCtx = make([]groupModuleContext, 0, len(plan))
    40  
    41  	for _, group := range plan {
    42  		groupOutcome, newPayload, moduleContexts, rejectErr := executeGroup(executionCtx, group, payload, hookHandler, metricEngine)
    43  		stageOutcome.ExecutionTimeMillis += groupOutcome.ExecutionTimeMillis
    44  		stageOutcome.Groups = append(stageOutcome.Groups, groupOutcome)
    45  		stageModuleCtx.groupCtx = append(stageModuleCtx.groupCtx, moduleContexts)
    46  		if rejectErr != nil {
    47  			return stageOutcome, payload, stageModuleCtx, rejectErr
    48  		}
    49  
    50  		payload = newPayload
    51  	}
    52  
    53  	return stageOutcome, payload, stageModuleCtx, nil
    54  }
    55  
    56  func executeGroup[H any, P any](
    57  	executionCtx executionContext,
    58  	group hooks.Group[H],
    59  	payload P,
    60  	hookHandler hookHandler[H, P],
    61  	metricEngine metrics.MetricsEngine,
    62  ) (GroupOutcome, P, groupModuleContext, *RejectError) {
    63  	var wg sync.WaitGroup
    64  	rejected := make(chan struct{})
    65  	resp := make(chan hookResponse[P])
    66  
    67  	for _, hook := range group.Hooks {
    68  		mCtx := executionCtx.getModuleContext(hook.Module)
    69  		wg.Add(1)
    70  		go func(hw hooks.HookWrapper[H], moduleCtx hookstage.ModuleInvocationContext) {
    71  			defer wg.Done()
    72  			executeHook(moduleCtx, hw, payload, hookHandler, group.Timeout, resp, rejected)
    73  		}(hook, mCtx)
    74  	}
    75  
    76  	go func() {
    77  		wg.Wait()
    78  		close(resp)
    79  	}()
    80  
    81  	hookResponses := collectHookResponses(resp, rejected)
    82  
    83  	return handleHookResponses(executionCtx, hookResponses, payload, metricEngine)
    84  }
    85  
    86  func executeHook[H any, P any](
    87  	moduleCtx hookstage.ModuleInvocationContext,
    88  	hw hooks.HookWrapper[H],
    89  	payload P,
    90  	hookHandler hookHandler[H, P],
    91  	timeout time.Duration,
    92  	resp chan<- hookResponse[P],
    93  	rejected <-chan struct{},
    94  ) {
    95  	hookRespCh := make(chan hookResponse[P], 1)
    96  	startTime := time.Now()
    97  	hookId := HookID{ModuleCode: hw.Module, HookImplCode: hw.Code}
    98  
    99  	go func() {
   100  		ctx, cancel := context.WithTimeout(context.Background(), timeout)
   101  		defer cancel()
   102  		result, err := hookHandler(ctx, moduleCtx, hw.Hook, payload)
   103  		hookRespCh <- hookResponse[P]{
   104  			Result: result,
   105  			Err:    err,
   106  		}
   107  	}()
   108  
   109  	select {
   110  	case res := <-hookRespCh:
   111  		res.HookID = hookId
   112  		res.ExecutionTime = time.Since(startTime)
   113  		resp <- res
   114  	case <-time.After(timeout):
   115  		resp <- hookResponse[P]{
   116  			Err:           TimeoutError{},
   117  			ExecutionTime: time.Since(startTime),
   118  			HookID:        hookId,
   119  			Result:        hookstage.HookResult[P]{},
   120  		}
   121  	case <-rejected:
   122  		return
   123  	}
   124  }
   125  
   126  func collectHookResponses[P any](resp <-chan hookResponse[P], rejected chan<- struct{}) []hookResponse[P] {
   127  	hookResponses := make([]hookResponse[P], 0)
   128  	for r := range resp {
   129  		hookResponses = append(hookResponses, r)
   130  		if r.Result.Reject {
   131  			close(rejected)
   132  			break
   133  		}
   134  	}
   135  
   136  	return hookResponses
   137  }
   138  
   139  func handleHookResponses[P any](
   140  	executionCtx executionContext,
   141  	hookResponses []hookResponse[P],
   142  	payload P,
   143  	metricEngine metrics.MetricsEngine,
   144  ) (GroupOutcome, P, groupModuleContext, *RejectError) {
   145  	groupOutcome := GroupOutcome{}
   146  	groupOutcome.InvocationResults = make([]HookOutcome, 0, len(hookResponses))
   147  	groupModuleCtx := make(groupModuleContext, len(hookResponses))
   148  
   149  	for _, r := range hookResponses {
   150  		groupModuleCtx[r.HookID.ModuleCode] = r.Result.ModuleContext
   151  		if r.ExecutionTime > groupOutcome.ExecutionTimeMillis {
   152  			groupOutcome.ExecutionTimeMillis = r.ExecutionTime
   153  		}
   154  
   155  		updatedPayload, hookOutcome, rejectErr := handleHookResponse(executionCtx, payload, r, metricEngine)
   156  		groupOutcome.InvocationResults = append(groupOutcome.InvocationResults, hookOutcome)
   157  		payload = updatedPayload
   158  
   159  		if rejectErr != nil {
   160  			return groupOutcome, payload, groupModuleCtx, rejectErr
   161  		}
   162  	}
   163  
   164  	return groupOutcome, payload, groupModuleCtx, nil
   165  }
   166  
// moduleReplacer rewrites a module code so that it complies with metric
// naming requirements: every "." and "-" is replaced with "_".
var moduleReplacer = strings.NewReplacer(".", "_", "-", "_")
   169  
   170  // handleHookResponse is a strategy function that selects and applies
   171  // one of the available algorithms to handle hook response.
   172  func handleHookResponse[P any](
   173  	ctx executionContext,
   174  	payload P,
   175  	hr hookResponse[P],
   176  	metricEngine metrics.MetricsEngine,
   177  ) (P, HookOutcome, *RejectError) {
   178  	var rejectErr *RejectError
   179  	labels := metrics.ModuleLabels{Module: moduleReplacer.Replace(hr.HookID.ModuleCode), Stage: ctx.stage, AccountID: ctx.accountId}
   180  	metricEngine.RecordModuleCalled(labels, hr.ExecutionTime)
   181  
   182  	hookOutcome := HookOutcome{
   183  		Status:        StatusSuccess,
   184  		HookID:        hr.HookID,
   185  		Message:       hr.Result.Message,
   186  		Errors:        hr.Result.Errors,
   187  		Warnings:      hr.Result.Warnings,
   188  		DebugMessages: hr.Result.DebugMessages,
   189  		AnalyticsTags: hr.Result.AnalyticsTags,
   190  		ExecutionTime: ExecutionTime{ExecutionTimeMillis: hr.ExecutionTime},
   191  	}
   192  
   193  	if hr.Err != nil || hr.Result.Reject {
   194  		handleHookError(hr, &hookOutcome, metricEngine, labels)
   195  		rejectErr = handleHookReject(ctx, hr, &hookOutcome, metricEngine, labels)
   196  	} else {
   197  		payload = handleHookMutations(payload, hr, &hookOutcome, metricEngine, labels)
   198  	}
   199  
   200  	return payload, hookOutcome, rejectErr
   201  }
   202  
   203  // handleHookError sets an appropriate status to HookOutcome depending on the type of hook execution error.
   204  func handleHookError[P any](
   205  	hr hookResponse[P],
   206  	hookOutcome *HookOutcome,
   207  	metricEngine metrics.MetricsEngine,
   208  	labels metrics.ModuleLabels,
   209  ) {
   210  	if hr.Err == nil {
   211  		return
   212  	}
   213  
   214  	hookOutcome.Errors = append(hookOutcome.Errors, hr.Err.Error())
   215  	switch hr.Err.(type) {
   216  	case TimeoutError:
   217  		metricEngine.RecordModuleTimeout(labels)
   218  		hookOutcome.Status = StatusTimeout
   219  	case FailureError:
   220  		metricEngine.RecordModuleFailed(labels)
   221  		hookOutcome.Status = StatusFailure
   222  	default:
   223  		metricEngine.RecordModuleExecutionError(labels)
   224  		hookOutcome.Status = StatusExecutionFailure
   225  	}
   226  }
   227  
   228  // handleHookReject rejects execution at the current stage.
   229  // In case the stage does not support rejection, hook execution marked as failed.
   230  func handleHookReject[P any](
   231  	ctx executionContext,
   232  	hr hookResponse[P],
   233  	hookOutcome *HookOutcome,
   234  	metricEngine metrics.MetricsEngine,
   235  	labels metrics.ModuleLabels,
   236  ) *RejectError {
   237  	if !hr.Result.Reject {
   238  		return nil
   239  	}
   240  
   241  	stage := hooks.Stage(ctx.stage)
   242  	if !stage.IsRejectable() {
   243  		metricEngine.RecordModuleExecutionError(labels)
   244  		hookOutcome.Status = StatusExecutionFailure
   245  		hookOutcome.Errors = append(
   246  			hookOutcome.Errors,
   247  			fmt.Sprintf(
   248  				"Module (name: %s, hook code: %s) tried to reject request on the %s stage that does not support rejection",
   249  				hr.HookID.ModuleCode,
   250  				hr.HookID.HookImplCode,
   251  				ctx.stage,
   252  			),
   253  		)
   254  		return nil
   255  	}
   256  
   257  	rejectErr := &RejectError{NBR: hr.Result.NbrCode, Hook: hr.HookID, Stage: ctx.stage}
   258  	hookOutcome.Action = ActionReject
   259  	hookOutcome.Errors = append(hookOutcome.Errors, rejectErr.Error())
   260  	metricEngine.RecordModuleSuccessRejected(labels)
   261  
   262  	return rejectErr
   263  }
   264  
   265  // handleHookMutations applies mutations returned by hook to provided payload.
   266  func handleHookMutations[P any](
   267  	payload P,
   268  	hr hookResponse[P],
   269  	hookOutcome *HookOutcome,
   270  	metricEngine metrics.MetricsEngine,
   271  	labels metrics.ModuleLabels,
   272  ) P {
   273  	if len(hr.Result.ChangeSet.Mutations()) == 0 {
   274  		metricEngine.RecordModuleSuccessNooped(labels)
   275  		hookOutcome.Action = ActionNone
   276  		return payload
   277  	}
   278  
   279  	hookOutcome.Action = ActionUpdate
   280  	successfulMutations := 0
   281  	for _, mut := range hr.Result.ChangeSet.Mutations() {
   282  		p, err := mut.Apply(payload)
   283  		if err != nil {
   284  			hookOutcome.Warnings = append(
   285  				hookOutcome.Warnings,
   286  				fmt.Sprintf("failed to apply hook mutation: %s", err),
   287  			)
   288  			continue
   289  		}
   290  
   291  		payload = p
   292  		hookOutcome.DebugMessages = append(
   293  			hookOutcome.DebugMessages,
   294  			fmt.Sprintf(
   295  				"Hook mutation successfully applied, affected key: %s, mutation type: %s",
   296  				strings.Join(mut.Key(), "."),
   297  				mut.Type(),
   298  			),
   299  		)
   300  		successfulMutations++
   301  	}
   302  
   303  	// if at least one mutation from a given module was successfully applied
   304  	// we consider that the module was processed successfully
   305  	if successfulMutations > 0 {
   306  		metricEngine.RecordModuleSuccessUpdated(labels)
   307  	} else {
   308  		hookOutcome.Status = StatusExecutionFailure
   309  		metricEngine.RecordModuleExecutionError(labels)
   310  	}
   311  
   312  	return payload
   313  }