--- /dev/null
+// Copyright (c) 2012-2016 The Revel Framework Authors, All rights reserved.
+// Revel Framework source code and usage is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package revel
+
+import (
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+)
+
+var compressionTypes = [...]string{
+ "gzip",
+ "deflate",
+}
+
+var compressableMimes = [...]string{
+ "text/plain",
+ "text/html",
+ "text/xml",
+ "text/css",
+ "application/json",
+ "application/xml",
+ "application/xhtml+xml",
+ "application/rss+xml",
+ "application/javascript",
+ "application/x-javascript",
+}
+
+// Local log instance for this class
+var compressLog = RevelLog.New("section", "compress")
+
+// WriteFlusher interface for compress writer
+type WriteFlusher interface {
+ io.Writer // An IO Writer
+ io.Closer // A closure
+ Flush() error /// A flush function
+}
+
+// The compressed writer
+type CompressResponseWriter struct {
+ Header *BufferedServerHeader // The header
+ ControllerResponse *Response // The response
+ OriginalWriter io.Writer // The writer
+ compressWriter WriteFlusher // The flushed writer
+ compressionType string // The compression type
+ headersWritten bool // True if written
+ closeNotify chan bool // The notify channel to close
+ parentNotify <-chan bool // The parent chanel to receive the closed event
+ closed bool // True if closed
+}
+
+// CompressFilter does compression of response body in gzip/deflate if
+// `results.compressed=true` in the app.conf
+func CompressFilter(c *Controller, fc []Filter) {
+ if c.Response.Out.internalHeader.Server != nil && Config.BoolDefault("results.compressed", false) {
+ if c.Response.Status != http.StatusNoContent && c.Response.Status != http.StatusNotModified {
+ if found, compressType, compressWriter := detectCompressionType(c.Request, c.Response); found {
+ writer := CompressResponseWriter{
+ ControllerResponse: c.Response,
+ OriginalWriter: c.Response.GetWriter(),
+ compressWriter: compressWriter,
+ compressionType: compressType,
+ headersWritten: false,
+ closeNotify: make(chan bool, 1),
+ closed: false,
+ }
+ // Swap out the header with our own
+ writer.Header = NewBufferedServerHeader(c.Response.Out.internalHeader.Server)
+ c.Response.Out.internalHeader.Server = writer.Header
+ if w, ok := c.Response.GetWriter().(http.CloseNotifier); ok {
+ writer.parentNotify = w.CloseNotify()
+ }
+ c.Response.SetWriter(&writer)
+ }
+ } else {
+ compressLog.Debug("CompressFilter: Compression disabled for response ", "status", c.Response.Status)
+ }
+ }
+ fc[0](c, fc[1:])
+}
+
+// Called to notify the writer is closing
+func (c CompressResponseWriter) CloseNotify() <-chan bool {
+ if c.parentNotify != nil {
+ return c.parentNotify
+ }
+ return c.closeNotify
+}
+
+// Cancel the writer
+func (c *CompressResponseWriter) cancel() {
+ c.closed = true
+}
+
+// Prepare the headers
+func (c *CompressResponseWriter) prepareHeaders() {
+ if c.compressionType != "" {
+ responseMime := ""
+ if t := c.Header.Get("Content-Type"); len(t) > 0 {
+ responseMime = t[0]
+ }
+ responseMime = strings.TrimSpace(strings.SplitN(responseMime, ";", 2)[0])
+ shouldEncode := false
+
+ if len(c.Header.Get("Content-Encoding")) == 0 {
+ for _, compressableMime := range compressableMimes {
+ if responseMime == compressableMime {
+ shouldEncode = true
+ c.Header.Set("Content-Encoding", c.compressionType)
+ c.Header.Del("Content-Length")
+ break
+ }
+ }
+ }
+
+ if !shouldEncode {
+ c.compressWriter = nil
+ c.compressionType = ""
+ }
+ }
+ c.Header.Release()
+}
+
+// Write the headers
+func (c *CompressResponseWriter) WriteHeader(status int) {
+ if c.closed {
+ return
+ }
+ c.headersWritten = true
+ c.prepareHeaders()
+ c.Header.SetStatus(status)
+}
+
+// Close the writer
+func (c *CompressResponseWriter) Close() error {
+ if c.closed {
+ return nil
+ }
+ if !c.headersWritten {
+ c.prepareHeaders()
+ }
+ if c.compressionType != "" {
+ c.Header.Del("Content-Length")
+ if err := c.compressWriter.Close(); err != nil {
+ // TODO When writing directly to stream, an error will be generated
+ compressLog.Error("Close: Error closing compress writer", "type", c.compressionType, "error", err)
+ }
+
+ }
+ // Non-blocking write to the closenotifier, if we for some reason should
+ // get called multiple times
+ select {
+ case c.closeNotify <- true:
+ default:
+ }
+ c.closed = true
+ return nil
+}
+
+// Write to the underling buffer
+func (c *CompressResponseWriter) Write(b []byte) (int, error) {
+ if c.closed {
+ return 0, io.ErrClosedPipe
+ }
+ // Abort if parent has been closed
+ if c.parentNotify != nil {
+ select {
+ case <-c.parentNotify:
+ return 0, io.ErrClosedPipe
+ default:
+ }
+ }
+ // Abort if we ourselves have been closed
+ if c.closed {
+ return 0, io.ErrClosedPipe
+ }
+
+ if !c.headersWritten {
+ c.prepareHeaders()
+ c.headersWritten = true
+ }
+ if c.compressionType != "" {
+ return c.compressWriter.Write(b)
+ }
+ return c.OriginalWriter.Write(b)
+}
+
+// DetectCompressionType method detects the compression type
+// from header "Accept-Encoding"
+func detectCompressionType(req *Request, resp *Response) (found bool, compressionType string, compressionKind WriteFlusher) {
+ if Config.BoolDefault("results.compressed", false) {
+ acceptedEncodings := strings.Split(req.GetHttpHeader("Accept-Encoding"), ",")
+
+ largestQ := 0.0
+ chosenEncoding := len(compressionTypes)
+
+ // I have fixed one edge case for issue #914
+ // But it's better to cover all possible edge cases or
+ // Adapt to https://github.com/golang/gddo/blob/master/httputil/header/header.go#L172
+ for _, encoding := range acceptedEncodings {
+ encoding = strings.TrimSpace(encoding)
+ encodingParts := strings.SplitN(encoding, ";", 2)
+
+ // If we are the format "gzip;q=0.8"
+ if len(encodingParts) > 1 {
+ q := strings.TrimSpace(encodingParts[1])
+ if len(q) == 0 || !strings.HasPrefix(q, "q=") {
+ continue
+ }
+
+ // Strip off the q=
+ num, err := strconv.ParseFloat(q[2:], 32)
+ if err != nil {
+ continue
+ }
+
+ if num >= largestQ && num > 0 {
+ if encodingParts[0] == "*" {
+ chosenEncoding = 0
+ largestQ = num
+ continue
+ }
+ for i, encoding := range compressionTypes {
+ if encoding == encodingParts[0] {
+ if i < chosenEncoding {
+ largestQ = num
+ chosenEncoding = i
+ }
+ break
+ }
+ }
+ }
+ } else {
+ // If we can accept anything, chose our preferred method.
+ if encodingParts[0] == "*" {
+ chosenEncoding = 0
+ largestQ = 1
+ break
+ }
+ // This is for just plain "gzip"
+ for i, encoding := range compressionTypes {
+ if encoding == encodingParts[0] {
+ if i < chosenEncoding {
+ largestQ = 1.0
+ chosenEncoding = i
+ }
+ break
+ }
+ }
+ }
+ }
+
+ if largestQ == 0 {
+ return
+ }
+
+ compressionType = compressionTypes[chosenEncoding]
+
+ switch compressionType {
+ case "gzip":
+ compressionKind = gzip.NewWriter(resp.GetWriter())
+ found = true
+ case "deflate":
+ compressionKind = zlib.NewWriter(resp.GetWriter())
+ found = true
+ }
+ }
+ return
+}
+
+// BufferedServerHeader will not send content out until the Released is called, from that point on it will act normally
+// It implements all the ServerHeader
+type BufferedServerHeader struct {
+ cookieList []string // The cookie list
+ headerMap map[string][]string // The header map
+ status int // The status
+ released bool // True if released
+ original ServerHeader // The original header
+}
+
+// Creates a new instance based on the ServerHeader
+func NewBufferedServerHeader(o ServerHeader) *BufferedServerHeader {
+ return &BufferedServerHeader{original: o, headerMap: map[string][]string{}}
+}
+
+// Sets the cookie
+func (bsh *BufferedServerHeader) SetCookie(cookie string) {
+ if bsh.released {
+ bsh.original.SetCookie(cookie)
+ } else {
+ bsh.cookieList = append(bsh.cookieList, cookie)
+ }
+}
+
+// Returns a cookie
+func (bsh *BufferedServerHeader) GetCookie(key string) (ServerCookie, error) {
+ return bsh.original.GetCookie(key)
+}
+
+// Sets (replace) the header key
+func (bsh *BufferedServerHeader) Set(key string, value string) {
+ if bsh.released {
+ bsh.original.Set(key, value)
+ } else {
+ bsh.headerMap[key] = []string{value}
+ }
+}
+
+// Add (append) to a key this value
+func (bsh *BufferedServerHeader) Add(key string, value string) {
+ if bsh.released {
+ bsh.original.Set(key, value)
+ } else {
+ old := []string{}
+ if v, found := bsh.headerMap[key]; found {
+ old = v
+ }
+ bsh.headerMap[key] = append(old, value)
+ }
+}
+
+// Delete this key
+func (bsh *BufferedServerHeader) Del(key string) {
+ if bsh.released {
+ bsh.original.Del(key)
+ } else {
+ delete(bsh.headerMap, key)
+ }
+}
+
+// Get this key
+func (bsh *BufferedServerHeader) Get(key string) (value []string) {
+ if bsh.released {
+ value = bsh.original.Get(key)
+ } else {
+ if v, found := bsh.headerMap[key]; found && len(v) > 0 {
+ value = v
+ } else {
+ value = bsh.original.Get(key)
+ }
+ }
+ return
+}
+
+// Get all header keys
+func (bsh *BufferedServerHeader) GetKeys() (value []string) {
+ if bsh.released {
+ value = bsh.original.GetKeys()
+ } else {
+ value = bsh.original.GetKeys()
+ for key := range bsh.headerMap {
+ found := false
+ for _,v := range value {
+ if v==key {
+ found = true
+ break
+ }
+ }
+ if !found {
+ value = append(value,key)
+ }
+ }
+ }
+ return
+}
+
+// Set the status
+func (bsh *BufferedServerHeader) SetStatus(statusCode int) {
+ if bsh.released {
+ bsh.original.SetStatus(statusCode)
+ } else {
+ bsh.status = statusCode
+ }
+}
+
+// Release the header and push the results to the original
+func (bsh *BufferedServerHeader) Release() {
+ bsh.released = true
+ for k, v := range bsh.headerMap {
+ for _, r := range v {
+ bsh.original.Set(k, r)
+ }
+ }
+ for _, c := range bsh.cookieList {
+ bsh.original.SetCookie(c)
+ }
+ if bsh.status > 0 {
+ bsh.original.SetStatus(bsh.status)
+ }
+}