github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/protobuf/plugin/raftproxy/raftproxy.go (about) 1 package raftproxy 2 3 import ( 4 "strings" 5 6 "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" 7 "github.com/gogo/protobuf/protoc-gen-gogo/generator" 8 ) 9 10 type raftProxyGen struct { 11 gen *generator.Generator 12 } 13 14 func init() { 15 generator.RegisterPlugin(new(raftProxyGen)) 16 } 17 18 func (g *raftProxyGen) Init(gen *generator.Generator) { 19 g.gen = gen 20 } 21 22 func (g *raftProxyGen) Name() string { 23 return "raftproxy" 24 } 25 26 func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) { 27 g.gen.P("type " + serviceTypeName(s) + " struct {") 28 g.gen.P("\tlocal " + s.GetName() + "Server") 29 g.gen.P("\tconnSelector raftselector.ConnProvider") 30 g.gen.P("\tlocalCtxMods, remoteCtxMods []func(context.Context)(context.Context, error)") 31 g.gen.P("}") 32 } 33 34 func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) { 35 g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {") 36 g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) { 37 p, ok := peer.FromContext(ctx) 38 if !ok { 39 return ctx, status.Errorf(codes.InvalidArgument, "remote addr is not found in context") 40 } 41 addr := p.Addr.String() 42 md, ok := metadata.FromIncomingContext(ctx) 43 if ok && len(md["redirect"]) != 0 { 44 return ctx, status.Errorf(codes.ResourceExhausted, "more than one redirect to leader from: %s", md["redirect"]) 45 } 46 if !ok { 47 md = metadata.New(map[string]string{}) 48 } 49 md["redirect"] = append(md["redirect"], addr) 50 return metadata.NewOutgoingContext(ctx, md), nil 51 } 52 remoteMods := []func(context.Context)(context.Context, error){redirectChecker} 53 remoteMods = append(remoteMods, remoteCtxMod) 54 55 var localMods []func(context.Context)(context.Context, error) 56 if localCtxMod != nil { 57 localMods = []func(context.Context)(context.Context, error){localCtxMod} 58 } 59 `) 60 g.gen.P("return &" + serviceTypeName(s) + `{ 61 local: local, 62 connSelector: connSelector, 63 localCtxMods: localMods, 64 remoteCtxMods: remoteMods, 65 }`) 66 g.gen.P("}") 67 } 68 69 func (g *raftProxyGen) genRunCtxMods(s *descriptor.ServiceDescriptorProto) { 70 g.gen.P("func (p *" + serviceTypeName(s) + `) runCtxMods(ctx context.Context, ctxMods []func(context.Context)(context.Context, error)) (context.Context, error) { 71 var err error 72 for _, mod := range ctxMods { 73 ctx, err = mod(ctx) 74 if err != nil { 75 return ctx, err 76 } 77 } 78 return ctx, nil 79 }`) 80 } 81 82 func getInputTypeName(m *descriptor.MethodDescriptorProto) string { 83 parts := strings.Split(m.GetInputType(), ".") 84 return parts[len(parts)-1] 85 } 86 87 func getOutputTypeName(m *descriptor.MethodDescriptorProto) string { 88 parts := strings.Split(m.GetOutputType(), ".") 89 return parts[len(parts)-1] 90 } 91 92 func serviceTypeName(s *descriptor.ServiceDescriptorProto) string { 93 return "raftProxy" + s.GetName() + "Server" 94 } 95 96 func sigPrefix(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) string { 97 return "func (p *" + serviceTypeName(s) + ") " + m.GetName() + "(" 98 } 99 100 func (g *raftProxyGen) genStreamWrapper(streamType string) { 101 // Generate stream wrapper that returns a modified context 102 g.gen.P(`type ` + streamType + `Wrapper struct { 103 ` + streamType + ` 104 ctx context.Context 105 } 106 `) 107 g.gen.P(`func (s ` + streamType + `Wrapper) Context() context.Context { 108 return s.ctx 109 } 110 `) 111 } 112 113 func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { 114 streamType := s.GetName() + "_" + m.GetName() + "Server" 115 116 // Generate stream wrapper that returns a modified context 117 g.genStreamWrapper(streamType) 118 119 g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { 120 ctx := stream.Context() 121 conn, err := p.connSelector.LeaderConn(ctx) 122 if err != nil { 123 if err == raftselector.ErrIsLeader { 124 ctx, err = p.runCtxMods(ctx, p.localCtxMods) 125 if err != nil { 126 return err 127 } 128 streamWrapper := ` + streamType + `Wrapper{ 129 ` + streamType + `: stream, 130 ctx: ctx, 131 } 132 return p.local.` + m.GetName() + `(streamWrapper) 133 } 134 return err 135 } 136 ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) 137 if err != nil { 138 return err 139 }`) 140 g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)") 141 g.gen.P(` 142 if err != nil { 143 return err 144 }`) 145 g.gen.P(` 146 for { 147 msg, err := stream.Recv() 148 if err == io.EOF { 149 break 150 } 151 if err != nil { 152 return err 153 } 154 if err := clientStream.Send(msg); err != nil { 155 return err 156 } 157 } 158 159 reply, err := clientStream.CloseAndRecv() 160 if err != nil { 161 return err 162 } 163 164 return stream.SendAndClose(reply)`) 165 g.gen.P("}") 166 } 167 168 func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { 169 streamType := s.GetName() + "_" + m.GetName() + "Server" 170 171 g.genStreamWrapper(streamType) 172 173 g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error { 174 ctx := stream.Context() 175 conn, err := p.connSelector.LeaderConn(ctx) 176 if err != nil { 177 if err == raftselector.ErrIsLeader { 178 ctx, err = p.runCtxMods(ctx, p.localCtxMods) 179 if err != nil { 180 return err 181 } 182 streamWrapper := ` + streamType + `Wrapper{ 183 ` + streamType + `: stream, 184 ctx: ctx, 185 } 186 return p.local.` + m.GetName() + `(r, streamWrapper) 187 } 188 return err 189 } 190 ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) 191 if err != nil { 192 return err 193 }`) 194 g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx, r)") 195 g.gen.P(` 196 if err != nil { 197 return err 198 }`) 199 g.gen.P(` 200 for { 201 msg, err := clientStream.Recv() 202 if err == io.EOF { 203 break 204 } 205 if err != nil { 206 return err 207 } 208 if err := stream.Send(msg); err != nil { 209 return err 210 } 211 } 212 return nil`) 213 g.gen.P("}") 214 } 215 216 func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { 217 streamType := s.GetName() + "_" + m.GetName() + "Server" 218 219 g.genStreamWrapper(streamType) 220 221 g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { 222 ctx := stream.Context() 223 conn, err := p.connSelector.LeaderConn(ctx) 224 if err != nil { 225 if err == raftselector.ErrIsLeader { 226 ctx, err = p.runCtxMods(ctx, p.localCtxMods) 227 if err != nil { 228 return err 229 } 230 streamWrapper := ` + streamType + `Wrapper{ 231 ` + streamType + `: stream, 232 ctx: ctx, 233 } 234 return p.local.` + m.GetName() + `(streamWrapper) 235 } 236 return err 237 } 238 ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) 239 if err != nil { 240 return err 241 }`) 242 g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)") 243 g.gen.P(` 244 if err != nil { 245 return err 246 }`) 247 g.gen.P(`errc := make(chan error, 1) 248 go func() { 249 msg, err := stream.Recv() 250 if err == io.EOF { 251 close(errc) 252 return 253 } 254 if err != nil { 255 errc <- err 256 return 257 } 258 if err := clientStream.Send(msg); err != nil { 259 errc <- err 260 return 261 } 262 }()`) 263 g.gen.P(` 264 for { 265 msg, err := clientStream.Recv() 266 if err == io.EOF { 267 break 268 } 269 if err != nil { 270 return err 271 } 272 if err := stream.Send(msg); err != nil { 273 return err 274 } 275 } 276 clientStream.CloseSend() 277 return <-errc`) 278 g.gen.P("}") 279 } 280 281 func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { 282 g.gen.P(sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {") 283 g.gen.P(` 284 conn, err := p.connSelector.LeaderConn(ctx) 285 if err != nil { 286 if err == raftselector.ErrIsLeader { 287 ctx, err = p.runCtxMods(ctx, p.localCtxMods) 288 if err != nil { 289 return nil, err 290 } 291 return p.local.` + m.GetName() + `(ctx, r) 292 } 293 return nil, err 294 } 295 modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) 296 if err != nil { 297 return nil, err 298 }`) 299 g.gen.P(` 300 resp, err := New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r) 301 if err != nil { 302 if !strings.Contains(err.Error(), "is closing") && !strings.Contains(err.Error(), "the connection is unavailable") && !strings.Contains(err.Error(), "connection error") { 303 return resp, err 304 } 305 conn, err := p.pollNewLeaderConn(ctx) 306 if err != nil { 307 if err == raftselector.ErrIsLeader { 308 return p.local.` + m.GetName() + `(ctx, r) 309 } 310 return nil, err 311 } 312 return New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r) 313 }`) 314 g.gen.P("return resp, err") 315 g.gen.P("}") 316 } 317 318 func (g *raftProxyGen) genProxyMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { 319 g.gen.P() 320 switch { 321 case m.GetServerStreaming() && m.GetClientStreaming(): 322 g.genClientServerStreamingMethod(s, m) 323 case m.GetServerStreaming(): 324 g.genServerStreamingMethod(s, m) 325 case m.GetClientStreaming(): 326 g.genClientStreamingMethod(s, m) 327 default: 328 g.genSimpleMethod(s, m) 329 } 330 g.gen.P() 331 } 332 333 func (g *raftProxyGen) genPollNewLeaderConn(s *descriptor.ServiceDescriptorProto) { 334 g.gen.P(`func (p *` + serviceTypeName(s) + `) pollNewLeaderConn(ctx context.Context) (*grpc.ClientConn, error) { 335 ticker := rafttime.NewTicker(500 * rafttime.Millisecond) 336 defer ticker.Stop() 337 for { 338 select { 339 case <-ticker.C: 340 conn, err := p.connSelector.LeaderConn(ctx) 341 if err != nil { 342 return nil, err 343 } 344 345 client := NewHealthClient(conn) 346 347 resp, err := client.Check(ctx, &HealthCheckRequest{Service: "Raft"}) 348 if err != nil || resp.Status != HealthCheckResponse_SERVING { 349 continue 350 } 351 return conn, nil 352 case <-ctx.Done(): 353 return nil, ctx.Err() 354 } 355 } 356 }`) 357 } 358 359 func (g *raftProxyGen) Generate(file *generator.FileDescriptor) { 360 if len(file.FileDescriptorProto.Service) == 0 { 361 return 362 } 363 g.gen.P() 364 for _, s := range file.Service { 365 g.genProxyStruct(s) 366 g.genProxyConstructor(s) 367 g.genRunCtxMods(s) 368 g.genPollNewLeaderConn(s) 369 for _, m := range s.Method { 370 g.genProxyMethod(s, m) 371 } 372 } 373 g.gen.P() 374 } 375 376 func (g *raftProxyGen) GenerateImports(file *generator.FileDescriptor) { 377 if len(file.Service) == 0 { 378 return 379 } 380 g.gen.PrintImport("raftselector", "github.com/docker/swarmkit/manager/raftselector") 381 g.gen.PrintImport("codes", "google.golang.org/grpc/codes") 382 g.gen.PrintImport("status", "google.golang.org/grpc/status") 383 g.gen.PrintImport("metadata", "google.golang.org/grpc/metadata") 384 g.gen.PrintImport("peer", "google.golang.org/grpc/peer") 385 // don't conflict with import added by ptypes 386 g.gen.PrintImport("rafttime", "time") 387 }