github.com/MetalBlockchain/metalgo@v1.11.9/vms/rpcchainvm/ghttp/gresponsewriter/writer_server.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package gresponsewriter
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"net/http"
    10  
    11  	"google.golang.org/protobuf/types/known/emptypb"
    12  
    13  	"github.com/MetalBlockchain/metalgo/vms/rpcchainvm/ghttp/gconn"
    14  	"github.com/MetalBlockchain/metalgo/vms/rpcchainvm/ghttp/greader"
    15  	"github.com/MetalBlockchain/metalgo/vms/rpcchainvm/ghttp/gwriter"
    16  	"github.com/MetalBlockchain/metalgo/vms/rpcchainvm/grpcutils"
    17  
    18  	responsewriterpb "github.com/MetalBlockchain/metalgo/proto/pb/http/responsewriter"
    19  	readerpb "github.com/MetalBlockchain/metalgo/proto/pb/io/reader"
    20  	writerpb "github.com/MetalBlockchain/metalgo/proto/pb/io/writer"
    21  	connpb "github.com/MetalBlockchain/metalgo/proto/pb/net/conn"
    22  )
    23  
    24  var (
    25  	errUnsupportedFlushing  = errors.New("response writer doesn't support flushing")
    26  	errUnsupportedHijacking = errors.New("response writer doesn't support hijacking")
    27  
    28  	_ responsewriterpb.WriterServer = (*Server)(nil)
    29  )
    30  
    31  // Server is an http.ResponseWriter that is managed over RPC.
    32  type Server struct {
    33  	responsewriterpb.UnsafeWriterServer
    34  	writer http.ResponseWriter
    35  }
    36  
    37  // NewServer returns an http.ResponseWriter instance managed remotely
    38  func NewServer(writer http.ResponseWriter) *Server {
    39  	return &Server{
    40  		writer: writer,
    41  	}
    42  }
    43  
    44  func (s *Server) Write(
    45  	_ context.Context,
    46  	req *responsewriterpb.WriteRequest,
    47  ) (*responsewriterpb.WriteResponse, error) {
    48  	headers := s.writer.Header()
    49  	clear(headers)
    50  	for _, header := range req.Headers {
    51  		headers[header.Key] = header.Values
    52  	}
    53  
    54  	n, err := s.writer.Write(req.Payload)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	return &responsewriterpb.WriteResponse{
    59  		Written: int32(n),
    60  	}, nil
    61  }
    62  
    63  func (s *Server) WriteHeader(
    64  	_ context.Context,
    65  	req *responsewriterpb.WriteHeaderRequest,
    66  ) (*emptypb.Empty, error) {
    67  	headers := s.writer.Header()
    68  	clear(headers)
    69  	for _, header := range req.Headers {
    70  		headers[header.Key] = header.Values
    71  	}
    72  	s.writer.WriteHeader(grpcutils.EnsureValidResponseCode(int(req.StatusCode)))
    73  	return &emptypb.Empty{}, nil
    74  }
    75  
    76  func (s *Server) Flush(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
    77  	flusher, ok := s.writer.(http.Flusher)
    78  	if !ok {
    79  		return nil, errUnsupportedFlushing
    80  	}
    81  	flusher.Flush()
    82  	return &emptypb.Empty{}, nil
    83  }
    84  
    85  func (s *Server) Hijack(context.Context, *emptypb.Empty) (*responsewriterpb.HijackResponse, error) {
    86  	hijacker, ok := s.writer.(http.Hijacker)
    87  	if !ok {
    88  		return nil, errUnsupportedHijacking
    89  	}
    90  	conn, readWriter, err := hijacker.Hijack()
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	serverListener, err := grpcutils.NewListener()
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	server := grpcutils.NewServer()
   101  	closer := grpcutils.ServerCloser{}
   102  	closer.Add(server)
   103  
   104  	connpb.RegisterConnServer(server, gconn.NewServer(conn, &closer))
   105  	readerpb.RegisterReaderServer(server, greader.NewServer(readWriter))
   106  	writerpb.RegisterWriterServer(server, gwriter.NewServer(readWriter))
   107  
   108  	go grpcutils.Serve(serverListener, server)
   109  
   110  	local := conn.LocalAddr()
   111  	remote := conn.RemoteAddr()
   112  
   113  	return &responsewriterpb.HijackResponse{
   114  		LocalNetwork:  local.Network(),
   115  		LocalString:   local.String(),
   116  		RemoteNetwork: remote.Network(),
   117  		RemoteString:  remote.String(),
   118  		ServerAddr:    serverListener.Addr().String(),
   119  	}, nil
   120  }