github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/apiserver/observer/recorder.go (about)

     1  // Copyright 2017 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package observer
     5  
     6  import (
     7  	"encoding/json"
     8  	"reflect"
     9  
    10  	"github.com/juju/errors"
    11  
    12  	"github.com/juju/juju/apiserver/params"
    13  	"github.com/juju/juju/core/auditlog"
    14  	"github.com/juju/juju/rpc"
    15  )
    16  
    17  const (
    18  	// CaptureArgs means we'll serialize the API arguments and store
    19  	// them in the audit log.
    20  	CaptureArgs = true
    21  
    22  	// NoCaptureArgs means don't do that.
    23  	NoCaptureArgs = false
    24  )
    25  
    26  // NewRecorderFactory makes a new rpc.RecorderFactory to make
    27  // recorders that that will update the observer and the auditlog
    28  // recorder when it records a request or reply. The auditlog recorder
    29  // can be nil.
    30  func NewRecorderFactory(
    31  	observerFactory rpc.ObserverFactory,
    32  	recorder *auditlog.Recorder,
    33  	captureArgs bool,
    34  ) rpc.RecorderFactory {
    35  	return func() rpc.Recorder {
    36  		return &combinedRecorder{
    37  			observer:    observerFactory.RPCObserver(),
    38  			recorder:    recorder,
    39  			captureArgs: captureArgs,
    40  		}
    41  	}
    42  }
    43  
    44  // combinedRecorder wraps an observer (which might be a multiplexer)
    45  // up with an auditlog recorder into an rpc.Recorder.
    46  type combinedRecorder struct {
    47  	observer    rpc.Observer
    48  	recorder    *auditlog.Recorder
    49  	captureArgs bool
    50  }
    51  
    52  // HandleRequest implements rpc.Recorder.
    53  func (cr *combinedRecorder) HandleRequest(hdr *rpc.Header, body interface{}) error {
    54  	cr.observer.ServerRequest(hdr, body)
    55  	if cr.recorder == nil {
    56  		return nil
    57  	}
    58  	var args string
    59  	if cr.captureArgs {
    60  		jsonArgs, err := json.Marshal(body)
    61  		if err != nil {
    62  			return errors.Trace(err)
    63  		}
    64  		args = string(jsonArgs)
    65  	}
    66  	return errors.Trace(cr.recorder.AddRequest(auditlog.RequestArgs{
    67  		RequestID: hdr.RequestId,
    68  		Facade:    hdr.Request.Type,
    69  		Method:    hdr.Request.Action,
    70  		Version:   hdr.Request.Version,
    71  		Args:      args,
    72  	}))
    73  }
    74  
    75  // HandleReply implements rpc.Recorder.
    76  func (cr *combinedRecorder) HandleReply(req rpc.Request, replyHdr *rpc.Header, body interface{}) error {
    77  	cr.observer.ServerReply(req, replyHdr, body)
    78  	if cr.recorder == nil {
    79  		return nil
    80  	}
    81  	var responseErrors []*auditlog.Error
    82  	if replyHdr.Error == "" {
    83  		responseErrors = extractErrors(body)
    84  	} else {
    85  		responseErrors = []*auditlog.Error{{
    86  			Message: replyHdr.Error,
    87  			Code:    replyHdr.ErrorCode,
    88  		}}
    89  	}
    90  	return errors.Trace(cr.recorder.AddResponse(auditlog.ResponseErrorsArgs{
    91  		RequestID: replyHdr.RequestId,
    92  		Errors:    responseErrors,
    93  	}))
    94  }
    95  
    96  func extractErrors(body interface{}) []*auditlog.Error {
    97  	// To find errors in the API responses, we look for a struct where
    98  	// there is an attribute that is:
    99  	// * a slice of structs that have an attribute that is *Error or
   100  	// * a plain old *Error
   101  	// I thought we'd need to handle a []*Error as well, but it turns
   102  	// out we don't use that in the API.
   103  	value := reflect.ValueOf(body)
   104  	if value.Kind() != reflect.Struct {
   105  		return nil
   106  	}
   107  
   108  	// Prefer a slice of structs with Errors.
   109  	for i := 0; i < value.NumField(); i++ {
   110  		attr := value.Field(i)
   111  		if errors, ok := tryStructSliceErrors(attr); ok {
   112  			return convertErrors(errors)
   113  		}
   114  	}
   115  
   116  	for i := 0; i < value.NumField(); i++ {
   117  		attr := value.Field(i)
   118  		if err, ok := tryErrorPointer(attr); ok {
   119  			return convertErrors([]*params.Error{err})
   120  		}
   121  	}
   122  	return nil
   123  }
   124  
   125  func tryErrorPointer(value reflect.Value) (*params.Error, bool) {
   126  	if !value.CanInterface() {
   127  		return nil, false
   128  	}
   129  	err, ok := value.Interface().(*params.Error)
   130  	return err, ok
   131  }
   132  
   133  func tryStructSliceErrors(value reflect.Value) ([]*params.Error, bool) {
   134  	if value.Kind() != reflect.Slice {
   135  		return nil, false
   136  	}
   137  	itemType := value.Type().Elem()
   138  	if itemType.Kind() != reflect.Struct {
   139  		return nil, false
   140  	}
   141  	errorField, found := findErrorField(itemType)
   142  	if !found {
   143  		return nil, false
   144  	}
   145  
   146  	result := make([]*params.Error, value.Len())
   147  	for i := 0; i < value.Len(); i++ {
   148  		item := value.Index(i)
   149  		// We know item's a struct.
   150  		errorValue := item.Field(errorField)
   151  		// This will assign nil if we couldn't extract the field (it
   152  		// wasn't exported for example), but that's OK.
   153  		result[i], _ = tryErrorPointer(errorValue)
   154  	}
   155  	return result, true
   156  }
   157  
   158  var errorType = reflect.TypeOf(params.Error{})
   159  
   160  func findErrorField(itemType reflect.Type) (int, bool) {
   161  	for i := 0; i < itemType.NumField(); i++ {
   162  		field := itemType.Field(i)
   163  		if field.Type.Kind() != reflect.Ptr {
   164  			continue
   165  		}
   166  		if field.Type.Elem() == errorType {
   167  			return i, true
   168  		}
   169  	}
   170  	return 0, false
   171  }
   172  
   173  func convertErrors(errors []*params.Error) []*auditlog.Error {
   174  	result := make([]*auditlog.Error, len(errors))
   175  	for i, err := range errors {
   176  		if err == nil {
   177  			continue
   178  		}
   179  		result[i] = &auditlog.Error{
   180  			Message: err.Message,
   181  			Code:    err.Code,
   182  		}
   183  	}
   184  	return result
   185  }