github.com/bartle-stripe/trillian@v1.2.1/testonly/hammer/replay.go (about)

     1  // Copyright 2018 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package hammer
    16  
    17  import (
    18  	"context"
    19  	"encoding/binary"
    20  	"fmt"
    21  	"io"
    22  	"os"
    23  	"reflect"
    24  	"sync"
    25  
    26  	"github.com/golang/glog"
    27  	"github.com/golang/protobuf/proto"
    28  	"github.com/golang/protobuf/ptypes"
    29  	"github.com/golang/protobuf/ptypes/any"
    30  	"github.com/google/trillian"
    31  	"google.golang.org/grpc"
    32  )
    33  
    34  type recordingInterceptor struct {
    35  	mu     sync.Mutex
    36  	outLog io.Writer
    37  }
    38  
    39  // NewRecordingInterceptor returns a grpc.UnaryClientInterceptor that logs outgoing
    40  // requests to file.
    41  func NewRecordingInterceptor(filename string) (grpc.UnaryClientInterceptor, error) {
    42  	o, err := os.Create(filename)
    43  	if err != nil {
    44  		return nil, fmt.Errorf("failed to create log file: %v", err)
    45  	}
    46  	ri := recordingInterceptor{outLog: o}
    47  	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    48  		return ri.invoke(ctx, method, req, reply, cc, invoker, opts...)
    49  	}, nil
    50  }
    51  
    52  func (ri *recordingInterceptor) invoke(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    53  	if msg, ok := req.(proto.Message); ok {
    54  		ri.dumpMessage(msg)
    55  	} else {
    56  		glog.Warningf("failed to convert request %T to proto.Message", req)
    57  	}
    58  	err := invoker(ctx, method, req, reply, cc, opts...)
    59  	if err == nil {
    60  		if msg, ok := reply.(proto.Message); ok {
    61  			ri.dumpMessage(msg)
    62  		} else {
    63  			glog.Warningf("failed to convert response %T to proto.Message", req)
    64  		}
    65  	}
    66  	return err
    67  }
    68  
    69  func (ri *recordingInterceptor) dumpMessage(in proto.Message) {
    70  	ri.mu.Lock()
    71  	defer ri.mu.Unlock()
    72  	if err := writeMessage(ri.outLog, in); err != nil {
    73  		glog.Error(err.Error())
    74  	}
    75  }
    76  
    77  func writeMessage(w io.Writer, in proto.Message) error {
    78  	a, err := ptypes.MarshalAny(in)
    79  	if err != nil {
    80  		return fmt.Errorf("failed to marshal %T %+v to any.Any: %v", in, in, err)
    81  	}
    82  	data, err := proto.Marshal(a)
    83  	if err != nil {
    84  		return fmt.Errorf("failed to marshal any.Any: %v", err)
    85  	}
    86  	// Encode as [4-byte big-endian length, message]
    87  	lenData := make([]byte, 4)
    88  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
    89  	w.Write(lenData)
    90  	w.Write(data)
    91  	return nil
    92  }
    93  
    94  func readMessage(r io.Reader) (*any.Any, error) {
    95  	// Decode from [4-byte big-endian length, message]
    96  	var l uint32
    97  	if err := binary.Read(r, binary.BigEndian, &l); err != nil {
    98  		if err != io.EOF {
    99  			err = fmt.Errorf("corrupt data: expected 4-byte length: %v", err)
   100  		}
   101  		return nil, err
   102  	}
   103  	data := make([]byte, l)
   104  	n, err := r.Read(data)
   105  	if uint32(n) < l {
   106  		return nil, fmt.Errorf("corrupt data: expected %d bytes of data, only found %d bytes", n, l)
   107  	}
   108  	if err != nil {
   109  		return nil, fmt.Errorf("failed to read %d bytes of data: %v", l, err)
   110  	}
   111  	var a any.Any
   112  	if err := proto.Unmarshal(data, &a); err != nil {
   113  		return nil, fmt.Errorf("failed to unmarshal into any.Any: %v", err)
   114  	}
   115  	return &a, nil
   116  }
   117  
   118  // ReplayFile reads recorded gRPC requests and re-issues them using the given
   119  // client.  If a request has a MapId field, and its value is present in mapmap,
   120  // then the MapId field is replaced before replay.
   121  func ReplayFile(ctx context.Context, r io.Reader, cl trillian.TrillianMapClient, mapmap map[int64]int64) {
   122  	for {
   123  		a, err := readMessage(r)
   124  		if err != nil {
   125  			if err != io.EOF {
   126  				glog.Errorf("Error reading message: %v", err)
   127  			}
   128  			return
   129  		}
   130  		glog.V(2).Infof("Replay %q", a.TypeUrl)
   131  		replayMessage(ctx, cl, a, mapmap)
   132  	}
   133  }
   134  
   135  // convertMessage modifies msg in-place so that the contents of a "MapId" field
   136  // are updated according to mapmap.
   137  func convertMessage(msg proto.Message, mapmap map[int64]int64) {
   138  	// Look for a "MapId" field that we can overwrite if needed.
   139  	pVal := reflect.ValueOf(msg)
   140  	if fieldVal := pVal.Elem().FieldByName("MapId"); fieldVal.CanSet() {
   141  		from := fieldVal.Int()
   142  		if to, ok := mapmap[from]; ok {
   143  			glog.V(2).Infof("Replacing msg.MapId=%d with %d in %T", from, to, msg)
   144  			fieldVal.SetInt(to)
   145  		}
   146  	}
   147  }
   148  
   149  func replayMessage(ctx context.Context, cl trillian.TrillianMapClient, a *any.Any, mapmap map[int64]int64) error {
   150  	var da ptypes.DynamicAny
   151  	if err := ptypes.UnmarshalAny(a, &da); err != nil {
   152  		return fmt.Errorf("failed to unmarshal from any.Any: %v", err)
   153  	}
   154  	req := da.Message
   155  	convertMessage(req, mapmap)
   156  	glog.V(2).Infof("Request req=%T %+v", req, req)
   157  	var err error
   158  	var rsp proto.Message
   159  	if cl != nil {
   160  		switch req := req.(type) {
   161  		case *trillian.GetMapLeavesRequest:
   162  			rsp, err = cl.GetLeaves(ctx, req)
   163  		case *trillian.GetMapLeavesByRevisionRequest:
   164  			rsp, err = cl.GetLeavesByRevision(ctx, req)
   165  		case *trillian.SetMapLeavesRequest:
   166  			rsp, err = cl.SetLeaves(ctx, req)
   167  		case *trillian.GetSignedMapRootRequest:
   168  			rsp, err = cl.GetSignedMapRoot(ctx, req)
   169  		case *trillian.GetSignedMapRootByRevisionRequest:
   170  			rsp, err = cl.GetSignedMapRootByRevision(ctx, req)
   171  		case *trillian.InitMapRequest:
   172  			rsp, err = cl.InitMap(ctx, req)
   173  		}
   174  		if rsp != nil {
   175  			glog.V(1).Infof("Request:  req=%T %+v", req, req)
   176  			glog.V(1).Infof("Response: rsp=%T %+v err=%v", rsp, rsp, err)
   177  		}
   178  	}
   179  	return err
   180  }