github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/protobuf/plugin/raftproxy/raftproxy.go (about)

     1  package raftproxy
     2  
     3  import (
     4  	"strings"
     5  
     6  	"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
     7  	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
     8  )
     9  
    10  type raftProxyGen struct {
    11  	gen *generator.Generator
    12  }
    13  
    14  func init() {
    15  	generator.RegisterPlugin(new(raftProxyGen))
    16  }
    17  
    18  func (g *raftProxyGen) Init(gen *generator.Generator) {
    19  	g.gen = gen
    20  }
    21  
    22  func (g *raftProxyGen) Name() string {
    23  	return "raftproxy"
    24  }
    25  
    26  func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) {
    27  	g.gen.P("type " + serviceTypeName(s) + " struct {")
    28  	g.gen.P("\tlocal " + s.GetName() + "Server")
    29  	g.gen.P("\tconnSelector raftselector.ConnProvider")
    30  	g.gen.P("\tlocalCtxMods, remoteCtxMods []func(context.Context)(context.Context, error)")
    31  	g.gen.P("}")
    32  }
    33  
    34  func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) {
    35  	g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {")
    36  	g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) {
    37  		p, ok := peer.FromContext(ctx)
    38  		if !ok {
    39  			return ctx, status.Errorf(codes.InvalidArgument, "remote addr is not found in context")
    40  		}
    41  		addr := p.Addr.String()
    42  		md, ok := metadata.FromIncomingContext(ctx)
    43  		if ok && len(md["redirect"]) != 0 {
    44  			return ctx, status.Errorf(codes.ResourceExhausted, "more than one redirect to leader from: %s", md["redirect"])
    45  		}
    46  		if !ok {
    47  			md = metadata.New(map[string]string{})
    48  		}
    49  		md["redirect"] = append(md["redirect"], addr)
    50  		return metadata.NewOutgoingContext(ctx, md), nil
    51  	}
    52  	remoteMods := []func(context.Context)(context.Context, error){redirectChecker}
    53  	remoteMods = append(remoteMods, remoteCtxMod)
    54  
    55  	var localMods []func(context.Context)(context.Context, error)
    56  	if localCtxMod != nil {
    57  		localMods = []func(context.Context)(context.Context, error){localCtxMod}
    58  	}
    59  	`)
    60  	g.gen.P("return &" + serviceTypeName(s) + `{
    61  		local: local,
    62  		connSelector: connSelector,
    63  		localCtxMods: localMods,
    64  		remoteCtxMods: remoteMods,
    65  	}`)
    66  	g.gen.P("}")
    67  }
    68  
    69  func (g *raftProxyGen) genRunCtxMods(s *descriptor.ServiceDescriptorProto) {
    70  	g.gen.P("func (p *" + serviceTypeName(s) + `) runCtxMods(ctx context.Context, ctxMods []func(context.Context)(context.Context, error)) (context.Context, error) {
    71  	var err error
    72  	for _, mod := range ctxMods {
    73  		ctx, err = mod(ctx)
    74  		if err != nil {
    75  			return ctx, err
    76  		}
    77  	}
    78  	return ctx, nil
    79  }`)
    80  }
    81  
    82  func getInputTypeName(m *descriptor.MethodDescriptorProto) string {
    83  	parts := strings.Split(m.GetInputType(), ".")
    84  	return parts[len(parts)-1]
    85  }
    86  
    87  func getOutputTypeName(m *descriptor.MethodDescriptorProto) string {
    88  	parts := strings.Split(m.GetOutputType(), ".")
    89  	return parts[len(parts)-1]
    90  }
    91  
    92  func serviceTypeName(s *descriptor.ServiceDescriptorProto) string {
    93  	return "raftProxy" + s.GetName() + "Server"
    94  }
    95  
    96  func sigPrefix(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) string {
    97  	return "func (p *" + serviceTypeName(s) + ") " + m.GetName() + "("
    98  }
    99  
   100  func (g *raftProxyGen) genStreamWrapper(streamType string) {
   101  	// Generate stream wrapper that returns a modified context
   102  	g.gen.P(`type ` + streamType + `Wrapper struct {
   103  	` + streamType + `
   104  	ctx context.Context
   105  }
   106  `)
   107  	g.gen.P(`func (s ` + streamType + `Wrapper) Context() context.Context {
   108  	return s.ctx
   109  }
   110  `)
   111  }
   112  
   113  func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
   114  	streamType := s.GetName() + "_" + m.GetName() + "Server"
   115  
   116  	// Generate stream wrapper that returns a modified context
   117  	g.genStreamWrapper(streamType)
   118  
   119  	g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error {
   120  	ctx := stream.Context()
   121  	conn, err := p.connSelector.LeaderConn(ctx)
   122  	if err != nil {
   123  		if err == raftselector.ErrIsLeader {
   124  			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
   125  			if err != nil {
   126  				return err
   127  			}
   128  			streamWrapper := ` + streamType + `Wrapper{
   129  				` + streamType + `: stream,
   130  				ctx: ctx,
   131  			}
   132  			return p.local.` + m.GetName() + `(streamWrapper)
   133  		}
   134  		return err
   135  	}
   136  	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
   137  	if err != nil {
   138  		return err
   139  	}`)
   140  	g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)")
   141  	g.gen.P(`
   142  	if err != nil {
   143  			return err
   144  	}`)
   145  	g.gen.P(`
   146  	for {
   147  		msg, err := stream.Recv()
   148  		if err == io.EOF {
   149  			break
   150  		}
   151  		if err != nil {
   152  			return err
   153  		}
   154  		if err := clientStream.Send(msg); err != nil {
   155  			return err
   156  		}
   157  	}
   158  
   159  	reply, err := clientStream.CloseAndRecv()
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	return stream.SendAndClose(reply)`)
   165  	g.gen.P("}")
   166  }
   167  
   168  func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
   169  	streamType := s.GetName() + "_" + m.GetName() + "Server"
   170  
   171  	g.genStreamWrapper(streamType)
   172  
   173  	g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error {
   174  	ctx := stream.Context()
   175  	conn, err := p.connSelector.LeaderConn(ctx)
   176  	if err != nil {
   177  		if err == raftselector.ErrIsLeader {
   178  			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
   179  			if err != nil {
   180  				return err
   181  			}
   182  			streamWrapper := ` + streamType + `Wrapper{
   183  				` + streamType + `: stream,
   184  				ctx: ctx,
   185  			}
   186  			return p.local.` + m.GetName() + `(r, streamWrapper)
   187  		}
   188  		return err
   189  	}
   190  	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
   191  	if err != nil {
   192  		return err
   193  	}`)
   194  	g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx, r)")
   195  	g.gen.P(`
   196  	if err != nil {
   197  			return err
   198  	}`)
   199  	g.gen.P(`
   200  	for {
   201  		msg, err := clientStream.Recv()
   202  		if err == io.EOF {
   203  			break
   204  		}
   205  		if err != nil {
   206  			return err
   207  		}
   208  		if err := stream.Send(msg); err != nil {
   209  			return err
   210  		}
   211  	}
   212  	return nil`)
   213  	g.gen.P("}")
   214  }
   215  
   216  func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
   217  	streamType := s.GetName() + "_" + m.GetName() + "Server"
   218  
   219  	g.genStreamWrapper(streamType)
   220  
   221  	g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error {
   222  	ctx := stream.Context()
   223  	conn, err := p.connSelector.LeaderConn(ctx)
   224  	if err != nil {
   225  		if err == raftselector.ErrIsLeader {
   226  			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
   227  			if err != nil {
   228  				return err
   229  			}
   230  			streamWrapper := ` + streamType + `Wrapper{
   231  				` + streamType + `: stream,
   232  				ctx: ctx,
   233  			}
   234  			return p.local.` + m.GetName() + `(streamWrapper)
   235  		}
   236  		return err
   237  	}
   238  	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
   239  	if err != nil {
   240  		return err
   241  	}`)
   242  	g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)")
   243  	g.gen.P(`
   244  	if err != nil {
   245  			return err
   246  	}`)
   247  	g.gen.P(`errc := make(chan error, 1)
   248  	go func() {
   249  		msg, err := stream.Recv()
   250  		if err == io.EOF {
   251  			close(errc)
   252  			return
   253  		}
   254  		if err != nil {
   255  			errc <- err
   256  			return
   257  		}
   258  		if err := clientStream.Send(msg); err != nil {
   259  			errc <- err
   260  			return
   261  		}
   262  	}()`)
   263  	g.gen.P(`
   264  	for {
   265  		msg, err := clientStream.Recv()
   266  		if err == io.EOF {
   267  			break
   268  		}
   269  		if err != nil {
   270  			return err
   271  		}
   272  		if err := stream.Send(msg); err != nil {
   273  			return err
   274  		}
   275  	}
   276  	clientStream.CloseSend()
   277  	return <-errc`)
   278  	g.gen.P("}")
   279  }
   280  
   281  func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
   282  	g.gen.P(sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {")
   283  	g.gen.P(`
   284  	conn, err := p.connSelector.LeaderConn(ctx)
   285  	if err != nil {
   286  		if err == raftselector.ErrIsLeader {
   287  			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
   288  			if err != nil {
   289  				return nil, err
   290  			}
   291  			return p.local.` + m.GetName() + `(ctx, r)
   292  		}
   293  		return nil, err
   294  	}
   295  	modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
   296  	if err != nil {
   297  		return nil, err
   298  	}`)
   299  	g.gen.P(`
   300  	resp, err := New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r)
   301  	if err != nil {
   302  		if !strings.Contains(err.Error(), "is closing") && !strings.Contains(err.Error(), "the connection is unavailable") && !strings.Contains(err.Error(), "connection error") {
   303  			return resp, err
   304  		}
   305  		conn, err := p.pollNewLeaderConn(ctx)
   306  		if err != nil {
   307  			if err == raftselector.ErrIsLeader {
   308  				return p.local.` + m.GetName() + `(ctx, r)
   309  			}
   310  			return nil, err
   311  		}
   312  		return New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r)
   313  	}`)
   314  	g.gen.P("return resp, err")
   315  	g.gen.P("}")
   316  }
   317  
   318  func (g *raftProxyGen) genProxyMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
   319  	g.gen.P()
   320  	switch {
   321  	case m.GetServerStreaming() && m.GetClientStreaming():
   322  		g.genClientServerStreamingMethod(s, m)
   323  	case m.GetServerStreaming():
   324  		g.genServerStreamingMethod(s, m)
   325  	case m.GetClientStreaming():
   326  		g.genClientStreamingMethod(s, m)
   327  	default:
   328  		g.genSimpleMethod(s, m)
   329  	}
   330  	g.gen.P()
   331  }
   332  
   333  func (g *raftProxyGen) genPollNewLeaderConn(s *descriptor.ServiceDescriptorProto) {
   334  	g.gen.P(`func (p *` + serviceTypeName(s) + `) pollNewLeaderConn(ctx context.Context) (*grpc.ClientConn, error) {
   335  		ticker := rafttime.NewTicker(500 * rafttime.Millisecond)
   336  		defer ticker.Stop()
   337  		for {
   338  			select {
   339  			case <-ticker.C:
   340  				conn, err := p.connSelector.LeaderConn(ctx)
   341  				if err != nil {
   342  					return nil, err
   343  				}
   344  
   345  				client := NewHealthClient(conn)
   346  
   347  				resp, err := client.Check(ctx, &HealthCheckRequest{Service: "Raft"})
   348  				if err != nil || resp.Status != HealthCheckResponse_SERVING {
   349  					continue
   350  				}
   351  				return conn, nil
   352  			case <-ctx.Done():
   353  				return nil, ctx.Err()
   354  			}
   355  		}
   356  	}`)
   357  }
   358  
   359  func (g *raftProxyGen) Generate(file *generator.FileDescriptor) {
   360  	if len(file.FileDescriptorProto.Service) == 0 {
   361  		return
   362  	}
   363  	g.gen.P()
   364  	for _, s := range file.Service {
   365  		g.genProxyStruct(s)
   366  		g.genProxyConstructor(s)
   367  		g.genRunCtxMods(s)
   368  		g.genPollNewLeaderConn(s)
   369  		for _, m := range s.Method {
   370  			g.genProxyMethod(s, m)
   371  		}
   372  	}
   373  	g.gen.P()
   374  }
   375  
   376  func (g *raftProxyGen) GenerateImports(file *generator.FileDescriptor) {
   377  	if len(file.Service) == 0 {
   378  		return
   379  	}
   380  	g.gen.PrintImport("raftselector", "github.com/docker/swarmkit/manager/raftselector")
   381  	g.gen.PrintImport("codes", "google.golang.org/grpc/codes")
   382  	g.gen.PrintImport("status", "google.golang.org/grpc/status")
   383  	g.gen.PrintImport("metadata", "google.golang.org/grpc/metadata")
   384  	g.gen.PrintImport("peer", "google.golang.org/grpc/peer")
   385  	// don't conflict with import added by ptypes
   386  	g.gen.PrintImport("rafttime", "time")
   387  }