github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/context_response.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/vektah/gqlparser/v2/gqlerror"
     9  )
    10  
    11  type responseContext struct {
    12  	errorPresenter ErrorPresenterFunc
    13  	recover        RecoverFunc
    14  
    15  	errors   gqlerror.List
    16  	errorsMu sync.Mutex
    17  
    18  	extensions   map[string]interface{}
    19  	extensionsMu sync.Mutex
    20  }
    21  
    22  const resultCtx key = "result_context"
    23  
    24  func getResponseContext(ctx context.Context) *responseContext {
    25  	val, ok := ctx.Value(resultCtx).(*responseContext)
    26  	if !ok {
    27  		panic("missing response context")
    28  	}
    29  	return val
    30  }
    31  
    32  func WithResponseContext(ctx context.Context, presenterFunc ErrorPresenterFunc, recoverFunc RecoverFunc) context.Context {
    33  	return context.WithValue(ctx, resultCtx, &responseContext{
    34  		errorPresenter: presenterFunc,
    35  		recover:        recoverFunc,
    36  	})
    37  }
    38  
    39  // AddErrorf writes a formatted error to the client, first passing it through the error presenter.
    40  func AddErrorf(ctx context.Context, format string, args ...interface{}) {
    41  	AddError(ctx, fmt.Errorf(format, args...))
    42  }
    43  
    44  // AddError sends an error to the client, first passing it through the error presenter.
    45  func AddError(ctx context.Context, err error) {
    46  	c := getResponseContext(ctx)
    47  
    48  	presentedError := c.errorPresenter(ctx, ErrorOnPath(ctx, err))
    49  
    50  	c.errorsMu.Lock()
    51  	defer c.errorsMu.Unlock()
    52  	c.errors = append(c.errors, presentedError)
    53  }
    54  
    55  func Recover(ctx context.Context, err interface{}) (userMessage error) {
    56  	c := getResponseContext(ctx)
    57  	return ErrorOnPath(ctx, c.recover(ctx, err))
    58  }
    59  
    60  // HasFieldError returns true if the given field has already errored
    61  func HasFieldError(ctx context.Context, rctx *FieldContext) bool {
    62  	c := getResponseContext(ctx)
    63  
    64  	c.errorsMu.Lock()
    65  	defer c.errorsMu.Unlock()
    66  
    67  	if len(c.errors) == 0 {
    68  		return false
    69  	}
    70  
    71  	path := rctx.Path()
    72  	for _, err := range c.errors {
    73  		if equalPath(err.Path, path) {
    74  			return true
    75  		}
    76  	}
    77  	return false
    78  }
    79  
    80  // GetFieldErrors returns a list of errors that occurred in the given field
    81  func GetFieldErrors(ctx context.Context, rctx *FieldContext) gqlerror.List {
    82  	c := getResponseContext(ctx)
    83  
    84  	c.errorsMu.Lock()
    85  	defer c.errorsMu.Unlock()
    86  
    87  	if len(c.errors) == 0 {
    88  		return nil
    89  	}
    90  
    91  	path := rctx.Path()
    92  	var errs gqlerror.List
    93  	for _, err := range c.errors {
    94  		if equalPath(err.Path, path) {
    95  			errs = append(errs, err)
    96  		}
    97  	}
    98  	return errs
    99  }
   100  
   101  func GetErrors(ctx context.Context) gqlerror.List {
   102  	resCtx := getResponseContext(ctx)
   103  	resCtx.errorsMu.Lock()
   104  	defer resCtx.errorsMu.Unlock()
   105  
   106  	if len(resCtx.errors) == 0 {
   107  		return nil
   108  	}
   109  
   110  	errs := resCtx.errors
   111  	cpy := make(gqlerror.List, len(errs))
   112  	for i := range errs {
   113  		errCpy := *errs[i]
   114  		cpy[i] = &errCpy
   115  	}
   116  	return cpy
   117  }
   118  
   119  // RegisterExtension allows you to add a new extension into the graphql response
   120  func RegisterExtension(ctx context.Context, key string, value interface{}) {
   121  	c := getResponseContext(ctx)
   122  	c.extensionsMu.Lock()
   123  	defer c.extensionsMu.Unlock()
   124  
   125  	if c.extensions == nil {
   126  		c.extensions = make(map[string]interface{})
   127  	}
   128  
   129  	if _, ok := c.extensions[key]; ok {
   130  		panic(fmt.Errorf("extension already registered for key %s", key))
   131  	}
   132  
   133  	c.extensions[key] = value
   134  }
   135  
   136  // GetExtensions returns any extensions registered in the current result context
   137  func GetExtensions(ctx context.Context) map[string]interface{} {
   138  	ext := getResponseContext(ctx).extensions
   139  	if ext == nil {
   140  		return map[string]interface{}{}
   141  	}
   142  
   143  	return ext
   144  }
   145  
   146  func GetExtension(ctx context.Context, name string) interface{} {
   147  	ext := getResponseContext(ctx).extensions
   148  	if ext == nil {
   149  		return nil
   150  	}
   151  
   152  	return ext[name]
   153  }