github.com/philippseith/signalr@v0.6.3/hubprotocol_test.go (about)

     1  package signalr
     2  
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io"
	"reflect"
	"sort"
	"testing"
	"time"

	"github.com/dave/jennifer/jen"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)
    18  
    19  var _ = Describe("Protocol", func() {
    20  	for _, p := range []hubProtocol{
    21  		&jsonHubProtocol{},
    22  		&messagePackHubProtocol{},
    23  	} {
    24  		protocol := p
    25  		protocol.setDebugLogger(testLogger())
    26  		Describe(fmt.Sprintf("%T: WriteMessage/ParseMessages roundtrip", protocol), func() {
    27  			Context("InvocationMessage", func() {
    28  				for _, a := range [][]interface{}{
    29  					make([]interface{}, 0),
    30  					{1, 2, 3},
    31  					{1, 0xffffff},
    32  					{-5, []int{1000, 2}, simpleStruct{AsInt: 3, AsString: "3"}},
    33  					{[]simpleStruct{
    34  						{AsInt: 3, AsString: "3"},
    35  						{AsInt: 40, AsString: "40"},
    36  					}},
    37  					{map[string]int{"1": 2, "2": 4, "3": 8}},
    38  					{map[int]simpleStruct{1: {AsInt: 1, AsString: "1"}, 2: {AsInt: 2, AsString: "2"}}},
    39  				} {
    40  					arguments := a
    41  					want := invocationMessage{
    42  						Type:         1,
    43  						Target:       "A",
    44  						InvocationID: "B",
    45  						Arguments:    arguments,
    46  						StreamIds:    []string{"C", "D"},
    47  					}
    48  					It(fmt.Sprintf("be equal after roundtrip with arguments %v", arguments), func() {
    49  						buf := bytes.Buffer{}
    50  						Expect(protocol.WriteMessage(want, &buf)).NotTo(HaveOccurred())
    51  						Expect(len(msg)).NotTo(Equal(0))
    52  						var remainBuf bytes.Buffer
    53  						got, err := protocol.ParseMessages(&buf, &remainBuf)
    54  						Expect(err).NotTo(HaveOccurred())
    55  						Expect(len(got)).To(Equal(1))
    56  						Expect(got[0]).To(BeAssignableToTypeOf(invocationMessage{}))
    57  						gotMsg := got[0].(invocationMessage)
    58  						Expect(gotMsg.Target).To(Equal(want.Target))
    59  						Expect(gotMsg.InvocationID).To(Equal(want.InvocationID))
    60  						Expect(gotMsg.StreamIds).To(Equal(want.StreamIds))
    61  						Expect(len(gotMsg.Arguments)).To(Equal(len(want.Arguments)))
    62  						for i, gotArg := range gotMsg.Arguments {
    63  							// We can not directly compare gotArg and want.Arguments[i]
    64  							// because msgpack serializes numbers to the shortest possible type
    65  							t := reflect.TypeOf(want.Arguments[i])
    66  							value := reflect.New(t)
    67  							Expect(protocol.UnmarshalArgument(gotArg, value.Interface())).NotTo(HaveOccurred())
    68  							Expect(reflect.Indirect(value).Interface()).To(Equal(want.Arguments[i]))
    69  						}
    70  					})
    71  				}
    72  			})
    73  			Context("StreamItemMessage", func() {
    74  				for _, w := range []streamItemMessage{
    75  					{Type: 2, InvocationID: "1", Item: "3"},
    76  					{Type: 2, InvocationID: "2", Item: 3},
    77  					{Type: 2, InvocationID: "3", Item: uint(3)},
    78  					{Type: 2, InvocationID: "4", Item: simpleStruct{AsInt: 3, AsString: "3"}},
    79  					{Type: 2, InvocationID: "5", Item: []int64{1, 2, 3}},
    80  					{Type: 2, InvocationID: "6", Item: []int{1, 2, 3}},
    81  					{Type: 2, InvocationID: "7", Item: map[string]int{"1": 4, "2": 5, "3": 6}},
    82  					{Type: 2, InvocationID: "9"},
    83  				} {
    84  					want := w
    85  					It(fmt.Sprintf("should be equal after roundtrip of %#v", want), func(done Done) {
    86  						buf := bytes.Buffer{}
    87  						Expect(protocol.WriteMessage(want, &buf)).NotTo(HaveOccurred())
    88  						var remainBuf bytes.Buffer
    89  						got, err := protocol.ParseMessages(&buf, &remainBuf)
    90  						Expect(err).NotTo(HaveOccurred())
    91  						Expect(len(got)).To(Equal(1))
    92  						Expect(got[0]).To(BeAssignableToTypeOf(streamItemMessage{}))
    93  						gotMsg := got[0].(streamItemMessage)
    94  						Expect(gotMsg.InvocationID).To(Equal(want.InvocationID))
    95  						if want.Item == nil {
    96  							var v interface{}
    97  							Expect(protocol.UnmarshalArgument(gotMsg.Item, &v)).NotTo(HaveOccurred())
    98  							Expect(v).To(BeNil())
    99  						} else {
   100  							// We can not directly compare gotArg and want.Arguments[i]
   101  							// because msgpack serializes numbers to the shortest possible type
   102  							t := reflect.TypeOf(want.Item)
   103  							value := reflect.New(t)
   104  							Expect(protocol.UnmarshalArgument(gotMsg.Item, value.Interface())).NotTo(HaveOccurred())
   105  							Expect(reflect.Indirect(value).Interface()).To(Equal(want.Item))
   106  						}
   107  						close(done)
   108  					})
   109  				}
   110  			})
   111  			Context("CompletionMessage", func() {
   112  				for _, w := range []completionMessage{
   113  					{Type: 3, InvocationID: "1", Result: "3"},
   114  					{Type: 3, InvocationID: "2", Result: 3},
   115  					{Type: 3, InvocationID: "3", Result: uint(3)},
   116  					{Type: 3, InvocationID: "4", Result: simpleStruct{AsInt: 3, AsString: "3"}},
   117  					{Type: 3, InvocationID: "5", Result: []int64{1, 2, 3}},
   118  					{Type: 3, InvocationID: "6", Result: []int{1, 2, 3}},
   119  					{Type: 3, InvocationID: "7", Result: map[string]int{"1": 4, "2": 5, "3": 6}},
   120  					{Type: 3, InvocationID: "8"},
   121  					{Type: 3, InvocationID: "9", Error: "Failed"},
   122  				} {
   123  					want := w
   124  					It(fmt.Sprintf("should be equal after roundtrip of %#v", want), func(done Done) {
   125  						buf := bytes.Buffer{}
   126  						Expect(protocol.WriteMessage(want, &buf)).NotTo(HaveOccurred())
   127  						var remainBuf bytes.Buffer
   128  						got, err := protocol.ParseMessages(&buf, &remainBuf)
   129  						Expect(err).NotTo(HaveOccurred())
   130  						Expect(len(got)).To(Equal(1))
   131  						Expect(got[0]).To(BeAssignableToTypeOf(completionMessage{}))
   132  						gotMsg := got[0].(completionMessage)
   133  						Expect(gotMsg.InvocationID).To(Equal(want.InvocationID))
   134  						if want.Result == nil {
   135  							// Important: In contrast to StreamItemMessage a nil Result is not transmitted
   136  							// So if a stream ends with a nil item,
   137  							// a sender can not send a completionMessage with nil result to transmit this!
   138  							Expect(gotMsg.Result).To(BeNil())
   139  							Expect(gotMsg.Error).To(Equal(want.Error))
   140  						} else {
   141  							// We can not directly compare gotArg and want.Arguments[i]
   142  							// because msgpack serializes numbers to the shortest possible type
   143  							t := reflect.TypeOf(want.Result)
   144  							value := reflect.New(t)
   145  							Expect(protocol.UnmarshalArgument(gotMsg.Result, value.Interface())).NotTo(HaveOccurred())
   146  							Expect(reflect.Indirect(value).Interface()).To(Equal(want.Result))
   147  							Expect(gotMsg.Error).To(Equal(want.Error))
   148  						}
   149  						close(done)
   150  					})
   151  				}
   152  			})
   153  			Context("Multiple messages", func() {
   154  				It("should parse multiple messages sent in one step", func(done Done) {
   155  					buf := bytes.Buffer{}
   156  					streamItem := streamItemMessage{Type: 2, InvocationID: "2", Item: "A"}
   157  					completion := completionMessage{Type: 3, InvocationID: "2", Result: "B"}
   158  					Expect(protocol.WriteMessage(streamItem, &buf)).NotTo(HaveOccurred())
   159  					Expect(protocol.WriteMessage(completion, &buf)).NotTo(HaveOccurred())
   160  					var remainBuf bytes.Buffer
   161  					got, err := protocol.ParseMessages(&buf, &remainBuf)
   162  					Expect(err).NotTo(HaveOccurred())
   163  					Expect(len(got)).To(Equal(2))
   164  					Expect(got[0]).To(BeAssignableToTypeOf(streamItemMessage{}))
   165  					gotStreamItem := got[0].(streamItemMessage)
   166  					var item string
   167  					Expect(protocol.UnmarshalArgument(gotStreamItem.Item, &item)).NotTo(HaveOccurred())
   168  					Expect(item).To(Equal(streamItem.Item))
   169  					Expect(got[1]).To(BeAssignableToTypeOf(completionMessage{}))
   170  					gotCompletion := got[1].(completionMessage)
   171  					var result string
   172  					Expect(protocol.UnmarshalArgument(gotCompletion.Result, &result)).NotTo(HaveOccurred())
   173  					Expect(result).To(Equal(completion.Result))
   174  					close(done)
   175  				})
   176  			})
   177  			Context("Partial messages", func() {
   178  				It("should parse a message sent in two steps", func(done Done) {
   179  					messageBuf := &bytes.Buffer{}
   180  					streamItem := streamItemMessage{Type: 2, InvocationID: "2", Item: "A"}
   181  					Expect(protocol.WriteMessage(streamItem, messageBuf)).NotTo(HaveOccurred())
   182  					reader, writer := io.Pipe()
   183  					var remainBuf bytes.Buffer
   184  					// Store incomplete frame
   185  					go func() {
   186  						defer GinkgoRecover()
   187  						_, err := writer.Write(messageBuf.Bytes()[:messageBuf.Len()-2])
   188  						Expect(err).NotTo(HaveOccurred())
   189  					}()
   190  					up := make(chan struct{}, 1)
   191  					go func() {
   192  						defer GinkgoRecover()
   193  						up <- struct{}{}
   194  						got, err := protocol.ParseMessages(reader, &remainBuf)
   195  						Expect(err).NotTo(HaveOccurred())
   196  						Expect(len(got)).To(Equal(1))
   197  						Expect(got[0]).To(BeAssignableToTypeOf(streamItemMessage{}))
   198  						gotStreamItem := got[0].(streamItemMessage)
   199  						var item string
   200  						Expect(protocol.UnmarshalArgument(gotStreamItem.Item, &item)).NotTo(HaveOccurred())
   201  						Expect(item).To(Equal(streamItem.Item))
   202  						close(done)
   203  					}()
   204  					// Wait for parse to be started
   205  					<-up
   206  					// Let parse hang a while
   207  					<-time.After(time.Millisecond * 200)
   208  					// Write the rest of the frame
   209  					_, err := writer.Write(messageBuf.Bytes()[messageBuf.Len()-2:])
   210  					Expect(err).NotTo(HaveOccurred())
   211  				}, 2.0)
   212  			})
   213  		})
   214  	}
   215  })
   216  
   217  func TestDevParse(t *testing.T) {
   218  	if err := devParse(); err != nil {
   219  		t.Error(err)
   220  	}
   221  }
   222  
   223  //type simplestStruct struct {
   224  //	AsInt int
   225  //}
   226  
// simpleStruct is a small composite value used as an argument, stream item
// and completion result in the protocol roundtrip tests.
type simpleStruct struct {
	// AsInt is encoded under the JSON key "AI".
	AsInt    int    `json:"AI"`
	// AsString is encoded under the JSON key "AS".
	AsString string `json:"AS"`
}
   231  
   232  //type parserHub struct {
   233  //	Hub
   234  //}
   235  //
   236  //func (p *parserHub) Parse(fileName string) []string {
   237  //	return nil
   238  //}
   239  
   240  func devParse() error {
   241  	fSet := token.NewFileSet()
   242  	file, err := parser.ParseFile(fSet, "hubprotocol_test.go", nil, parser.AllErrors)
   243  	if err != nil {
   244  		return err
   245  	}
   246  	g := generator{hubs: make(map[string]*hubInfo)}
   247  	ast.Walk(&g, file)
   248  	g.Generate()
   249  	return nil
   250  }
   251  
// generator is an ast.Visitor that collects hub types and their methods while
// ast.Walk traverses a file (see Visit), and then renders Invoke dispatch
// methods for them (see Generate).
type generator struct {
	//packageName string
	// hubs maps a hub type name to the information collected for that hub.
	hubs map[string]*hubInfo
}
   256  
// hubInfo holds what the generator collected for a single hub type.
type hubInfo struct {
	// receiver is the receiver name of a collected hub method; it is used as
	// the receiver name of the generated Invoke methods.
	receiver  string
	// funcDecls are the collected method declarations of the hub, in the
	// order they were visited.
	funcDecls []*ast.FuncDecl
}
   261  
   262  func (g *generator) Generate() {
   263  	f := jen.NewFile("t1")
   264  	for hub, hubInfo := range g.hubs {
   265  		g.generateInvokeProtocol(f, "JSON", hub, hubInfo)
   266  		g.generateInvokeProtocol(f, "MessagePack", hub, hubInfo)
   267  	}
   268  	fmt.Printf("%#v", f)
   269  }
   270  
   271  func (g *generator) generateInvokeProtocol(f *jen.File, protocol string, hub string, hubInfo *hubInfo) {
   272  	targetCases := make([]jen.Code, 0)
   273  	for _, funcDecl := range hubInfo.funcDecls {
   274  		targetCases = append(targetCases, jen.Case(jen.Lit(funcDecl.Name.Name)).
   275  			Block(
   276  				jen.Return(jen.Id("Invoke"+funcDecl.Name.Name+protocol)).
   277  					Params(
   278  						jen.Id("arguments"),
   279  						jen.Id("streamIds"))))
   280  	}
   281  	f.Func().Params(jen.Id(hubInfo.receiver).Op("*").Id(hub)).Id("Invoke"+protocol).
   282  		Params(
   283  			jen.Id("target").String(),
   284  			jen.Id("arguments").Interface(),
   285  			jen.Id("streamIds").Index().String()).
   286  		Params(jen.Interface(), jen.Error()).
   287  		Block(
   288  			jen.Switch(jen.Id("target")).
   289  				Block(targetCases...),
   290  			jen.Return(jen.Nil(), jen.Qual("errors", "New").
   291  				Params(
   292  					jen.Lit("invalid target ").Op("+").Id("target"))))
   293  
   294  }
   295  
   296  func (g *generator) Visit(node ast.Node) (w ast.Visitor) {
   297  	if node == nil {
   298  		return nil
   299  	}
   300  	switch value := node.(type) {
   301  	case *ast.TypeSpec:
   302  		if structType, ok := value.Type.(*ast.StructType); ok {
   303  			if len(structType.Fields.List) > 0 {
   304  				if ident, ok := structType.Fields.List[0].Type.(*ast.Ident); ok && ident.Name == "Hub" {
   305  					if _, ok := g.hubs[value.Name.Name]; !ok {
   306  						g.hubs[value.Name.Name] = &hubInfo{
   307  							funcDecls: make([]*ast.FuncDecl, 0),
   308  						}
   309  					}
   310  				}
   311  			}
   312  		}
   313  	case *ast.FuncDecl:
   314  		if value.Recv != nil && len(value.Recv.List) == 1 {
   315  			switch recvType := value.Recv.List[0].Type.(type) {
   316  			case *ast.Ident:
   317  				if hubInfo, ok := g.hubs[recvType.Name]; ok {
   318  					hubInfo.receiver = value.Recv.List[0].Names[0].Name
   319  					hubInfo.funcDecls = append(hubInfo.funcDecls, value)
   320  				}
   321  			case *ast.StarExpr:
   322  				if recvType, ok := recvType.X.(*ast.Ident); ok {
   323  					if hubInfo, ok := g.hubs[recvType.Name]; ok {
   324  						hubInfo.receiver = value.Recv.List[0].Names[0].Name
   325  						hubInfo.funcDecls = append(hubInfo.funcDecls, value)
   326  					}
   327  				}
   328  			}
   329  		}
   330  	}
   331  	return g
   332  }