10 "github.com/golang/protobuf/proto"
11 "github.com/golang/protobuf/ptypes/any"
12 "github.com/grpc-ecosystem/grpc-gateway/internal"
13 "google.golang.org/grpc/codes"
14 "google.golang.org/grpc/grpclog"
15 "google.golang.org/grpc/status"
18 // ForwardResponseStream forwards the stream from gRPC server to REST client.
19 func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
20 f, ok := w.(http.Flusher)
22 grpclog.Infof("Flush not supported in %T", w)
23 http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
27 md, ok := ServerMetadataFromContext(ctx)
29 grpclog.Infof("Failed to extract ServerMetadata from context")
30 http.Error(w, "unexpected error", http.StatusInternalServerError)
33 handleForwardResponseServerMetadata(w, mux, md)
35 w.Header().Set("Transfer-Encoding", "chunked")
36 w.Header().Set("Content-Type", marshaler.ContentType())
37 if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
38 HTTPError(ctx, mux, marshaler, w, req, err)
43 if d, ok := marshaler.(Delimited); ok {
44 delimiter = d.Delimiter()
46 delimiter = []byte("\n")
56 handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
59 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
60 handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
64 buf, err := marshaler.Marshal(streamChunk(resp, nil))
66 grpclog.Infof("Failed to marshal response chunk: %v", err)
67 handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
70 if _, err = w.Write(buf); err != nil {
71 grpclog.Infof("Failed to send response chunk: %v", err)
75 if _, err = w.Write(delimiter); err != nil {
76 grpclog.Infof("Failed to send delimiter chunk: %v", err)
83 func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
84 for k, vs := range md.HeaderMD {
85 if h, ok := mux.outgoingHeaderMatcher(k); ok {
86 for _, v := range vs {
93 func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
94 for k := range md.TrailerMD {
95 tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
96 w.Header().Add("Trailer", tKey)
100 func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
101 for k, vs := range md.TrailerMD {
102 tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
103 for _, v := range vs {
104 w.Header().Add(tKey, v)
// responseBody is implemented by generated response types when the
// `google.api.HttpRule` sets `response_body`. XXX_ResponseBody returns the
// value that is marshaled as the HTTP response body in place of the whole
// response message.
type responseBody interface {
	XXX_ResponseBody() interface{}
}
115 // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
116 func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
117 md, ok := ServerMetadataFromContext(ctx)
119 grpclog.Infof("Failed to extract ServerMetadata from context")
122 handleForwardResponseServerMetadata(w, mux, md)
123 handleForwardResponseTrailerHeader(w, md)
125 contentType := marshaler.ContentType()
126 // Check marshaler on run time in order to keep backwards compatability
127 // An interface param needs to be added to the ContentType() function on
128 // the Marshal interface to be able to remove this check
129 if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok {
130 contentType = httpBodyMarshaler.ContentTypeFromMessage(resp)
132 w.Header().Set("Content-Type", contentType)
134 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
135 HTTPError(ctx, mux, marshaler, w, req, err)
140 if rb, ok := resp.(responseBody); ok {
141 buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
143 buf, err = marshaler.Marshal(resp)
146 grpclog.Infof("Marshal error: %v", err)
147 HTTPError(ctx, mux, marshaler, w, req, err)
151 if _, err = w.Write(buf); err != nil {
152 grpclog.Infof("Failed to write response: %v", err)
155 handleForwardResponseTrailer(w, md)
158 func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
162 for _, opt := range opts {
163 if err := opt(ctx, w, resp); err != nil {
164 grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
171 func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
172 buf, merr := marshaler.Marshal(streamChunk(nil, err))
174 grpclog.Infof("Failed to marshal an error: %v", merr)
178 s, ok := status.FromError(err)
180 s = status.New(codes.Unknown, err.Error())
182 w.WriteHeader(HTTPStatusFromCode(s.Code()))
184 if _, werr := w.Write(buf); werr != nil {
185 grpclog.Infof("Failed to notify error to client: %v", werr)
190 func streamChunk(result proto.Message, err error) map[string]proto.Message {
192 grpcCode := codes.Unknown
193 grpcMessage := err.Error()
194 var grpcDetails []*any.Any
195 if s, ok := status.FromError(err); ok {
197 grpcMessage = s.Message()
198 grpcDetails = s.Proto().GetDetails()
200 httpCode := HTTPStatusFromCode(grpcCode)
201 return map[string]proto.Message{
202 "error": &internal.StreamError{
203 GrpcCode: int32(grpcCode),
204 HttpCode: int32(httpCode),
205 Message: grpcMessage,
206 HttpStatus: http.StatusText(httpCode),
207 Details: grpcDetails,
212 return streamChunk(nil, fmt.Errorf("empty response"))
214 return map[string]proto.Message{"result": result}