response during file download error

Open • adabuleanu opened this issue 1 year ago • 1 comment

I am trying to implement an endpoint to download a file from a bucket. Following the Swagger documentation, I created the following OpenAPI schema:

openapi: 3.0.3
paths:
  /documents/{id}:
    get:
      operationId: getDocument
      parameters:
        - name: id
          description: The id of the document
          in: path
          required: true
          schema:
            type: string
            format: uuid
      responses:
        '200':
          description: OK
          content:
            application/octet-stream:
              schema:
                type: string
                format: binary
        '401':
          $ref: '#/components/responses/Unauthorized'
        '404':
          $ref: '#/components/responses/NotFound'
        default:
          $ref: '#/components/responses/DefaultError'
components:
  responses:
    NotFound:
      description: The specified resource was not found
      content:
        application/json:
          schema:
            $ref: '#/components/schemas/Error'
    Unauthorized:
      description: Unauthorized
      content:
        application/json:
          schema:
            $ref: '#/components/schemas/Error'
    DefaultError:
      description: Error
      content:
        application/json:
          schema:
            $ref: '#/components/schemas/Error'
  schemas:
    Error:
      type: object
      properties:
        code:
          type: string
        message:
          type: string
      required:
        - code
        - message

This generates the following code with github.com/deepmap/oapi-codegen/v2 v2.1.0:

// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen/v2 version v2.1.0 DO NOT EDIT.
package api

import (
	"bytes"
	"compress/gzip"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"path"
	"strings"

	"github.com/getkin/kin-openapi/openapi3"
	"github.com/go-chi/chi/v5"
	"github.com/oapi-codegen/runtime"
	strictnethttp "github.com/oapi-codegen/runtime/strictmiddleware/nethttp"
	openapi_types "github.com/oapi-codegen/runtime/types"
)

// Error defines model for Error.
type Error struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

// DefaultError defines model for DefaultError.
type DefaultError = Error

// NotFound defines model for NotFound.
type NotFound = Error

// Unauthorized defines model for Unauthorized.
type Unauthorized = Error

// ServerInterface represents all server handlers.
type ServerInterface interface {

	// (GET /documents/{id})
	GetDocument(w http.ResponseWriter, r *http.Request, id openapi_types.UUID)
}

// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.

type Unimplemented struct{}

// (GET /documents/{id})
func (_ Unimplemented) GetDocument(w http.ResponseWriter, r *http.Request, id openapi_types.UUID) {
	w.WriteHeader(http.StatusNotImplemented)
}

// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
	Handler            ServerInterface
	HandlerMiddlewares []MiddlewareFunc
	ErrorHandlerFunc   func(w http.ResponseWriter, r *http.Request, err error)
}

type MiddlewareFunc func(http.Handler) http.Handler

// GetDocument operation middleware
func (siw *ServerInterfaceWrapper) GetDocument(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	var err error

	// ------------- Path parameter "id" -------------
	var id openapi_types.UUID

	err = runtime.BindStyledParameterWithOptions("simple", "id", chi.URLParam(r, "id"), &id, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "id", Err: err})
		return
	}

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetDocument(w, r, id)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r.WithContext(ctx))
}

type UnescapedCookieParamError struct {
	ParamName string
	Err       error
}

func (e *UnescapedCookieParamError) Error() string {
	return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
}

func (e *UnescapedCookieParamError) Unwrap() error {
	return e.Err
}

type UnmarshalingParamError struct {
	ParamName string
	Err       error
}

func (e *UnmarshalingParamError) Error() string {
	return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
}

func (e *UnmarshalingParamError) Unwrap() error {
	return e.Err
}

type RequiredParamError struct {
	ParamName string
}

func (e *RequiredParamError) Error() string {
	return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
}

type RequiredHeaderError struct {
	ParamName string
	Err       error
}

func (e *RequiredHeaderError) Error() string {
	return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
}

func (e *RequiredHeaderError) Unwrap() error {
	return e.Err
}

type InvalidParamFormatError struct {
	ParamName string
	Err       error
}

func (e *InvalidParamFormatError) Error() string {
	return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
}

func (e *InvalidParamFormatError) Unwrap() error {
	return e.Err
}

type TooManyValuesForParamError struct {
	ParamName string
	Count     int
}

func (e *TooManyValuesForParamError) Error() string {
	return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
}

// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{})
}

type ChiServerOptions struct {
	BaseURL          string
	BaseRouter       chi.Router
	Middlewares      []MiddlewareFunc
	ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{
		BaseRouter: r,
	})
}

func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{
		BaseURL:    baseURL,
		BaseRouter: r,
	})
}

// HandlerWithOptions creates http.Handler with additional options
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
	r := options.BaseRouter

	if r == nil {
		r = chi.NewRouter()
	}
	if options.ErrorHandlerFunc == nil {
		options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		}
	}
	wrapper := ServerInterfaceWrapper{
		Handler:            si,
		HandlerMiddlewares: options.Middlewares,
		ErrorHandlerFunc:   options.ErrorHandlerFunc,
	}

	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/documents/{id}", wrapper.GetDocument)
	})

	return r
}

type DefaultErrorJSONResponse Error

type NotFoundJSONResponse Error

type UnauthorizedJSONResponse Error

type GetDocumentRequestObject struct {
	Id openapi_types.UUID `json:"id"`
}

type GetDocumentResponseObject interface {
	VisitGetDocumentResponse(w http.ResponseWriter) error
}

type GetDocument200ApplicationoctetStreamResponse struct {
	Body          io.Reader
	ContentLength int64
}

func (response GetDocument200ApplicationoctetStreamResponse) VisitGetDocumentResponse(w http.ResponseWriter) error {
	w.Header().Set("Content-Type", "application/octet-stream")
	if response.ContentLength != 0 {
		w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength))
	}
	w.WriteHeader(200)

	if closer, ok := response.Body.(io.ReadCloser); ok {
		defer closer.Close()
	}
	_, err := io.Copy(w, response.Body)
	return err
}

type GetDocument401JSONResponse struct{ UnauthorizedJSONResponse }

func (response GetDocument401JSONResponse) VisitGetDocumentResponse(w http.ResponseWriter) error {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(401)

	return json.NewEncoder(w).Encode(response)
}

type GetDocument404JSONResponse struct{ NotFoundJSONResponse }

func (response GetDocument404JSONResponse) VisitGetDocumentResponse(w http.ResponseWriter) error {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(404)

	return json.NewEncoder(w).Encode(response)
}

type GetDocumentdefaultJSONResponse struct {
	Body       Error
	StatusCode int
}

func (response GetDocumentdefaultJSONResponse) VisitGetDocumentResponse(w http.ResponseWriter) error {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(response.StatusCode)

	return json.NewEncoder(w).Encode(response.Body)
}

// StrictServerInterface represents all server handlers.
type StrictServerInterface interface {

	// (GET /documents/{id})
	GetDocument(ctx context.Context, request GetDocumentRequestObject) (GetDocumentResponseObject, error)
}

type StrictHandlerFunc = strictnethttp.StrictHTTPHandlerFunc
type StrictMiddlewareFunc = strictnethttp.StrictHTTPMiddlewareFunc

type StrictHTTPServerOptions struct {
	RequestErrorHandlerFunc  func(w http.ResponseWriter, r *http.Request, err error)
	ResponseErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

func NewStrictHandler(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc) ServerInterface {
	return &strictHandler{ssi: ssi, middlewares: middlewares, options: StrictHTTPServerOptions{
		RequestErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		},
		ResponseErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		},
	}}
}

func NewStrictHandlerWithOptions(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc, options StrictHTTPServerOptions) ServerInterface {
	return &strictHandler{ssi: ssi, middlewares: middlewares, options: options}
}

type strictHandler struct {
	ssi         StrictServerInterface
	middlewares []StrictMiddlewareFunc
	options     StrictHTTPServerOptions
}

// GetDocument operation middleware
func (sh *strictHandler) GetDocument(w http.ResponseWriter, r *http.Request, id openapi_types.UUID) {
	var request GetDocumentRequestObject

	request.Id = id

	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
		return sh.ssi.GetDocument(ctx, request.(GetDocumentRequestObject))
	}
	for _, middleware := range sh.middlewares {
		handler = middleware(handler, "GetDocument")
	}

	response, err := handler(r.Context(), w, r, request)

	if err != nil {
		sh.options.ResponseErrorHandlerFunc(w, r, err)
	} else if validResponse, ok := response.(GetDocumentResponseObject); ok {
		if err := validResponse.VisitGetDocumentResponse(w); err != nil {
			sh.options.ResponseErrorHandlerFunc(w, r, err)
		}
	} else if response != nil {
		sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
	}
}

// Base64 encoded, gzipped, json marshaled Swagger object
var swaggerSpec = []string{

	"H4sIAAAAAAAC/7xUwW4UMQz9lZHhGDoL7SnnAkJIcIFT1UOaeHZS7cTBcYTKav4dObPsdmChJ7qXTWI/",
	"vxe/ePbgacqUMEkBuwfGkikVbJtrHFzdyVtmYt17SoJJdOly3kXvJFLq7wslPSt+xMnp6iXjABZe9Kfi",
	"/RIt/VJtnmcDAYvnmLUIWDgEDHwieUc1hf9P+WXErmT0cYgYOsZClT12313pEkk3NBWzga/JVRmJ4w98",
	"BlUrNg0fEFrw6EVmysgSF6c8BdR/ecgIFopwTFtVPmEpbnsuNhtg/FYj65Vulgqn/FvzK5/u7tHLojOm",
	"gcCmutsZoIzJ5QgWLi82F5dgIDsZm5o+kK9Tu+w+hlmPtti6pZpbrz4EsPAe5fqQ2eDsJhTkAvZmf8ap",
	"GDoaOhmxCydU1Kgyg4HkJlUcAzy+m3BF88iTgXhyAhZqbZm/9+XWrKfgzWbzD8/JC8qrIoxuWnt/5LmL",
	"yfHDGaY/nP/8UT272rz+29s5CuvXj0RBV0+DjpPVmNtwPw1afQXm9vsZAAD//z2nJZM2BAAA",
}

// GetSwagger returns the content of the embedded swagger specification file
// or error if failed to decode
func decodeSpec() ([]byte, error) {
	zipped, err := base64.StdEncoding.DecodeString(strings.Join(swaggerSpec, ""))
	if err != nil {
		return nil, fmt.Errorf("error base64 decoding spec: %w", err)
	}
	zr, err := gzip.NewReader(bytes.NewReader(zipped))
	if err != nil {
		return nil, fmt.Errorf("error decompressing spec: %w", err)
	}
	var buf bytes.Buffer
	_, err = buf.ReadFrom(zr)
	if err != nil {
		return nil, fmt.Errorf("error decompressing spec: %w", err)
	}

	return buf.Bytes(), nil
}

var rawSpec = decodeSpecCached()

// a naive cached of a decoded swagger spec
func decodeSpecCached() func() ([]byte, error) {
	data, err := decodeSpec()
	return func() ([]byte, error) {
		return data, err
	}
}

// Constructs a synthetic filesystem for resolving external references when loading openapi specifications.
func PathToRawSpec(pathToFile string) map[string]func() ([]byte, error) {
	res := make(map[string]func() ([]byte, error))
	if len(pathToFile) > 0 {
		res[pathToFile] = rawSpec
	}

	return res
}

// GetSwagger returns the Swagger specification corresponding to the generated code
// in this file. The external references of Swagger specification are resolved.
// The logic of resolving external references is tightly connected to "import-mapping" feature.
// Externally referenced files must be embedded in the corresponding golang packages.
// Urls can be supported but this task was out of the scope.
func GetSwagger() (swagger *openapi3.T, err error) {
	resolvePath := PathToRawSpec("")

	loader := openapi3.NewLoader()
	loader.IsExternalRefsAllowed = true
	loader.ReadFromURIFunc = func(loader *openapi3.Loader, url *url.URL) ([]byte, error) {
		pathToFile := url.String()
		pathToFile = path.Clean(pathToFile)
		getSpec, ok := resolvePath[pathToFile]
		if !ok {
			err1 := fmt.Errorf("path not found: %s", pathToFile)
			return nil, err1
		}
		return getSpec()
	}
	var specData []byte
	specData, err = rawSpec()
	if err != nil {
		return
	}
	swagger, err = loader.LoadFromData(specData)
	if err != nil {
		return
	}
	return
}

My implementation of the download does not load the file into memory; instead, it returns a ReadSeekCloser:

func GetDocument(ctx context.Context, request api.GetDocumentRequestObject) (api.GetDocumentResponseObject, error) {
	// Stream the object straight from the bucket instead of buffering it.
	rsc := &GSReadSeekCloser{
		ObjectHandle: c.client.Bucket(c.bucketName).Object(request.Id.String()),
		Context:      ctx,
	}

	// some more code
	if errors.Is(err, ErrNotFound) {
		return api.GetDocument404JSONResponse{}, nil
	}
	if err != nil {
		return api.GetDocumentdefaultJSONResponse{Body: api.Error{Message: err.Error()}, StatusCode: http.StatusInternalServerError}, nil
	}

	return api.GetDocument200ApplicationoctetStreamResponse{Body: rsc}, nil
}

GSReadSeekCloser is adapted from https://github.com/googleapis/google-cloud-go/issues/1124#issuecomment-860092646.

Everything works as expected until an error occurs mid-download (for example, on large files). In that case, the response code is still 200 and the error message is appended to the response body.

$ curl -I -X 'GET'   'http://localhost/documents/34e2d1e6-2488-48de-98b3-b7f3ae5b1b5b'   -H 'accept: application/octet-stream' 
HTTP/1.1 200 OK
Content-Type: application/octet-stream
Vary: Origin
Date: Wed, 05 Jun 2024 15:14:27 GMT
Content-Length: 190

This happens because the generated code writes the 200 status before copying the body, so the error occurs after the header has already been sent:

func (response GetDocument200ApplicationoctetStreamResponse) VisitGetDocumentResponse(w http.ResponseWriter) error {
	w.Header().Set("Content-Type", "application/octet-stream")
	if response.ContentLength != 0 {
		w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength))
	}
	w.WriteHeader(200) // we first write 200 header

	if closer, ok := response.Body.(io.ReadCloser); ok {
		defer closer.Close()
	}
	_, err := io.Copy(w, response.Body) // error occurs here
	return err
}

The error is then passed to the response error handler:

func (sh *strictHandler) GetDocument(w http.ResponseWriter, r *http.Request, id openapi_types.UUID) {
	var request GetDocumentRequestObject

	request.Id = id

	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
		return sh.ssi.GetDocument(ctx, request.(GetDocumentRequestObject))
	}
	for _, middleware := range sh.middlewares {
		handler = middleware(handler, "GetDocument")
	}

	response, err := handler(r.Context(), w, r, request)

	if err != nil {
		sh.options.ResponseErrorHandlerFunc(w, r, err)
	} else if validResponse, ok := response.(GetDocumentResponseObject); ok {
		if err := validResponse.VisitGetDocumentResponse(w); err != nil {
			sh.options.ResponseErrorHandlerFunc(w, r, err) // error is handled here
		}
	} else if response != nil {
		sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
	}
}

By default, that handler responds with http.Error:

func NewStrictHandler(ssi StrictServerInterface, middlewares []StrictMiddlewareFunc) ServerInterface {
	return &strictHandler{ssi: ssi, middlewares: middlewares, options: StrictHTTPServerOptions{
		RequestErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		},
		ResponseErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusInternalServerError) // it should give a http error
		},
	}}
}

http.Error tries to write a 500 status, but since the header has already been written, that call is ignored and the following log entry is produced:

2024/06/06 17:56:58 http: superfluous response.WriteHeader call from ......

The behavior is somewhat understandable, since the error happens partway through the transfer. But the end result is that the client receives a 200 response with a partial body (the file content transferred before the error, followed by the error message), which is not what I want at all.

Is this the intended behavior? If so, what is the proper way to handle file downloads with an OpenAPI v3 specification?

adabuleanu • Jun 06 '24 16:06

This is a more general question of how you want to handle your response: do you want to stream it from the upstream object store, or can you fetch the whole object from the upstream store before sending the response?

In the first case, you cannot avoid the problem of having already sent HTTP/200 when something goes wrong, because many things can make the upstream download fail mid-stream. The way I've handled this in my APIs that proxy from S3 is to first query the object store for the object size and make sure I include a Content-Length header. Then, when an upstream error occurs, I truncate the response and close the connection, so the client knows the download was truncated despite the HTTP/200. I haven't tried this yet, but HTTP/1.1 also supports a Trailer field at the end of a message.

In the second case, API latency is higher, because you have to download the object in its entirety before sending it downstream, but error handling is really easy.

It does look like our generated code is structured so that it can call WriteHeader more than once, which is a bug.

mromaszewicz • Jun 06 '24 19:06