Source code analysis of golang reverse proxy

This article is reproduced from: Source code analysis of golang reverse proxy [source code attached]_ Dream siege lion_ 51CTO blog

1. Example of a reverse proxy based on ReverseProxy

package main

import (
    "log"
    "net/http"
    "net/http/httputil"
    "net/url"
)

func main() {
    // Address rewrite example
    // http://127.0.0.1:8888/test?id=1  => http://127.0.0.1:8081/reverse/test?id=1

    rs1 := "http://127.0.0.1:8081/reverse"
    targetUrl, err := url.Parse(rs1)
    if err != nil {
        log.Fatal("Parse target URL failed, err: ", err)
    }
    proxy := httputil.NewSingleHostReverseProxy(targetUrl)
    log.Println("Reverse proxy server serve at : 127.0.0.1:8888")
    if err := http.ListenAndServe(":8888", proxy); err != nil {
        log.Fatal("Start server failed,err:", err)
    }
}
$ curl http://127.0.0.1:8888/hello?id=123 -s
http://127.0.0.1:8081/reverse/hello?id=123
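The example above needs a backend listening on port 8081, which the article does not show. As a rough, self-contained sketch of the same address rewrite, the program below replaces both servers with throwaway net/http/httptest servers. The helper name proxiedURL and the echoing backend are assumptions made for illustration, and the backend echoes only the path and query rather than the full URL.

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
)

// proxiedURL starts a temporary backend that echoes the path and query it
// received, puts a single-host reverse proxy rooted at /reverse in front of
// it, and returns what the backend saw for the given request path.
// (Hypothetical helper, not part of the original article.)
func proxiedURL(requestPath string) (string, error) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, r.URL.RequestURI())
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL + "/reverse")
	if err != nil {
		return "", err
	}
	front := httptest.NewServer(httputil.NewSingleHostReverseProxy(target))
	defer front.Close()

	resp, err := http.Get(front.URL + requestPath)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	return string(body), err
}

func main() {
	got, err := proxiedURL("/hello?id=123")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(got)
}
```

Because the servers pick random free ports, only the rewritten path and query are stable between runs.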

2. ReverseProxy source code analysis

Main structure: ReverseProxy

// ReverseProxy takes an incoming request, sends it to another server, and proxies the response back to the client
type ReverseProxy struct {
    // Director modifies the incoming request into the new request that is sent through Transport;
    // the response is then copied back to the original client
    Director func(*http.Request)

    // Transport performs the proxied request and reuses pooled connections. If it is nil, http.DefaultTransport is used
    Transport http.RoundTripper

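    // Interval at which the response body is flushed to the client while it is being copied.
    // Zero means no periodic flushing; a negative value flushes after every write.
    // It is ignored for responses recognized as streaming, which are always flushed immediately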
    FlushInterval time.Duration

    // Optional logger for proxy errors. If nil, the log package's standard logger is used (it writes to stderr by default)
    ErrorLog *log.Logger

    // Optional pool of byte slices used as the buffer for io.CopyBuffer when copying the response body
    BufferPool BufferPool

    // Optional function that modifies the backend response (status code, headers, body). If it returns a non-nil error, ErrorHandler is called
    ModifyResponse func(*http.Response) error

    // Handles errors from reaching the backend and errors returned by ModifyResponse. If nil, the default logs the error and replies with HTTP 502 Bad Gateway
    ErrorHandler func(http.ResponseWriter, *http.Request, error)
}

Main methods

// NewSingleHostReverseProxy instantiates a ReverseProxy that routes requests to the scheme, host and base path of target.
// If the target path is /base and the incoming request is for /dir,
// the request is reverse proxied to http://x.x.x.x/base/dir.
// It does not rewrite the Host header; if the Host must be rewritten, do it in a custom Director function

func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
    // Query string of the target URL. For example, if the target is http://x.x.x.x/base?id=123, then RawQuery is id=123
    targetQuery := target.RawQuery

    // Instantiate director
    director := func(req *http.Request) {
        // http or https
        req.URL.Scheme = target.Scheme

        // Host name (ip + port or domain name + port)
        req.URL.Host = target.Host

        // Request URL splicing
        req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)

        // Use the "&" symbol to splice request parameters
        if targetQuery == "" || req.URL.RawQuery == "" {
            req.URL.RawQuery = targetQuery + req.URL.RawQuery
        } else {
            req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
        }

        // If the User-Agent header is absent, set it to an empty value so the Transport does not add Go's default User-Agent
        if _, ok := req.Header["User-Agent"]; !ok {
            // explicitly disable User-Agent so it's not set to default value
            req.Header.Set("User-Agent", "")
        }
    }
    return &ReverseProxy{Director: director}
}
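Since NewSingleHostReverseProxy leaves the Host header untouched, one possible way to rewrite it is to wrap the default Director. The sketch below is only an illustration under that assumption: newHostRewritingProxy and hostSeenByBackend are hypothetical helpers, and the backend simply echoes the Host header it received.

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
)

// newHostRewritingProxy keeps the default Director and additionally sets the
// outgoing Host header to the target's host.
func newHostRewritingProxy(target *url.URL) *httputil.ReverseProxy {
	proxy := httputil.NewSingleHostReverseProxy(target)
	defaultDirector := proxy.Director
	proxy.Director = func(req *http.Request) {
		defaultDirector(req)
		req.Host = target.Host // without this line the client's Host is forwarded
	}
	return proxy
}

// hostSeenByBackend returns the Host header the backend received together
// with the backend's own host:port.
func hostSeenByBackend(rewrite bool) (string, string, error) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, r.Host)
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		return "", "", err
	}
	proxy := httputil.NewSingleHostReverseProxy(target)
	if rewrite {
		proxy = newHostRewritingProxy(target)
	}
	front := httptest.NewServer(proxy)
	defer front.Close()

	resp, err := http.Get(front.URL)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	return string(body), target.Host, err
}

func main() {
	for _, rewrite := range []bool{false, true} {
		seen, backendHost, err := hostSeenByBackend(rewrite)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("rewrite=%v backend=%s saw Host=%s\n", rewrite, backendHost, seen)
	}
}
```

Without the rewrite the backend sees the proxy's own address as Host; with it, the backend sees its own host and port.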

URL splicing method

func singleJoiningSlash(a, b string) string {
    aslash := strings.HasSuffix(a, "/")
    bslash := strings.HasPrefix(b, "/")
    switch {
    case aslash && bslash: // a ends with "/" and b starts with "/": drop b's leading slash so only one "/" remains
        return a + b[1:]
    case !aslash && !bslash: // a does not end with "/" and b does not start with "/": add "/" between them
        return a + "/" + b
    }
    return a + b // Exactly one slash is present, so splice directly
}
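The three branches are easiest to see with concrete inputs. Because singleJoiningSlash is unexported in net/http/httputil, the following sketch copies it from the listing above and prints a few joins; the sample paths are made up for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// singleJoiningSlash is copied from the listing above, since the original
// function is not exported by net/http/httputil.
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}

func main() {
	fmt.Println(singleJoiningSlash("/base/", "/dir")) // both slashes present: one is dropped
	fmt.Println(singleJoiningSlash("/base", "dir"))   // no slash present: one is inserted
	fmt.Println(singleJoiningSlash("/base", "/dir"))  // exactly one slash: plain concatenation
	fmt.Println(singleJoiningSlash("", "/dir"))       // empty target path leaves the request path as is
}
```

The first three calls all yield /base/dir, and the last one yields /dir.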

From the example in section 1, we know that the basic step is to instantiate a ReverseProxy object and pass it to the http.ListenAndServe function

proxy := NewSingleHostReverseProxy(targetUrl)
http.ListenAndServe(":8888", proxy)

The http.ListenAndServe method receives an address and handler, and the function signature is as follows:

func ListenAndServe(addr string, handler Handler) error {...}

Handler here is an interface whose only method is ServeHTTP

type Handler interface {
    ServeHTTP(ResponseWriter, *Request)
}

Therefore, we can be sure that the instantiated ReverseProxy object also implements the ServeHTTP method.
Its main steps are as follows (in this article, "upstream" means the client side and "downstream" means the backend):

  1. Copy the headers of the upstream request to the downstream request
  2. Modify the request (e.g. scheme, parameters, URL)
  3. Determine whether a protocol upgrade is requested
  4. Delete the hop-by-hop headers of the upstream request, i.e. the headers that must not be forwarded downstream
  5. Set the X-Forwarded-For header, appending the client IP
  6. Use the connection pool to send the request downstream
  7. Handle protocol upgrades (HTTP status 101)
  8. Delete the hop-by-hop headers that must not be returned upstream
  9. Modify the response (if necessary)
  10. Copy the downstream response headers to the upstream response
  11. Write the HTTP status code
  12. Copy the response body, flushing it periodically

Let's analyze the core method ServeHTTP

func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    transport := p.Transport
    if transport == nil {
        transport = http.DefaultTransport
    }

    // Detect early termination of the client request.
    // If the ResponseWriter implements http.CloseNotifier, derive a cancellable context
    // and start a goroutine that calls cancel() when the client connection closes
    // (browser closed, network interruption, aborted request).
    // The deferred cancel() stops that goroutine when the handler returns normally.
    ctx := req.Context()
    if cn, ok := rw.(http.CloseNotifier); ok {
        var cancel context.CancelFunc
        ctx, cancel = context.WithCancel(ctx)
        defer cancel()
        notifyChan := cn.CloseNotify()
        go func() {
            select {
            case <-notifyChan:
                cancel()
            case <-ctx.Done():
            }
        }()
    }

    // Create the downstream request: a shallow copy of the incoming request that carries the context
    outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay
    if req.ContentLength == 0 {
        outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
    }

    // Deep copy Header, that is, copy the upstream Header to the downstream request Header
    outreq.Header = cloneHeader(req.Header)

    // Call Director to modify the downstream request
    p.Director(outreq)
    outreq.Close = false

    // HTTP protocol upgrade:
    // check whether the Connection header contains "Upgrade" and, if so, remember the requested protocol
    reqUpType := upgradeType(outreq.Header)
    removeConnectionHeaders(outreq.Header)

    // Remove hop-by-hop headers to the backend. Especially
    // important is "Connection" because we want a persistent
    // connection, regardless of what the client sent to us.
    // Delete hop by hop headers, mainly some specified headers that do not need to be passed to the downstream
    for _, h := range hopHeaders {
        hv := outreq.Header.Get(h)
        if hv == "" {
            continue
        }
        // The Te header is not deleted when its value is "trailers"
        if h == "Te" && hv == "trailers" {
            // Issue 21096: tell backend applications that
            // care about trailer support that we support
            // trailers. (We do, but we don't go out of
            // our way to advertise that unless the
            // incoming client request thought it was
            // worth mentioning)
            continue
        }
        outreq.Header.Del(h)
    }

    // After stripping all the hop-by-hop connection headers above, add back any
    // necessary for protocol upgrades, such as for websockets.
    // If reqUpType is not empty, set Connection back to "Upgrade" and Upgrade to the requested protocol, e.g. for websocket
    if reqUpType != "" {
        outreq.Header.Set("Connection", "Upgrade")
        outreq.Header.Set("Upgrade", reqUpType)
    }

    // Set X-Forwarded-For, appending the client IP
    if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
        // If we aren't the first proxy retain prior
        // X-Forwarded-For information as a comma+space
        // separated list and fold multiple headers into one.
        if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
            clientIP = strings.Join(prior, ", ") + ", " + clientIP
        }
        outreq.Header.Set("X-Forwarded-For", clientIP)
    }

    // Request to downstream
    res, err := transport.RoundTrip(outreq)
    if err != nil {
        p.getErrorHandler()(rw, outreq, err)
        return
    }

    // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
    // Processing upgrade protocol requests
    if res.StatusCode == http.StatusSwitchingProtocols {
        if !p.modifyResponse(rw, res, outreq) {
            return
        }
        p.handleUpgradeResponse(rw, outreq, res)
        return
    }
    // Delete hop-by-hop headers from the response
    removeConnectionHeaders(res.Header)

    for _, h := range hopHeaders {
        res.Header.Del(h)
    }

    // Modify response content
    if !p.modifyResponse(rw, res, outreq) {
        return
    }

    // Copy response Header to upstream
    copyHeader(rw.Header(), res.Header)

    // The "Trailer" header isn't included in the Transport's response,
    // at least for *http.Transport. Build it up from Trailer.
    announcedTrailers := len(res.Trailer)
    if announcedTrailers > 0 {
        trailerKeys := make([]string, 0, len(res.Trailer))
        for k := range res.Trailer {
            trailerKeys = append(trailerKeys, k)
        }
        rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
    }
    // Write status code
    rw.WriteHeader(res.StatusCode)

    // Copy the response body to the client, flushing periodically
    err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
    if err != nil {
        defer res.Body.Close()
        // Since we're streaming the response, if we run into an error all we can do
        // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
        // on read error while copying body.
        if !shouldPanicOnCopyError(req) {
            p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
            return
        }
        panic(http.ErrAbortHandler)
    }
    res.Body.Close() // close now, instead of defer, to populate res.Trailer
    ......
}
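The hop-by-hop handling in ServeHTTP can be observed from the outside. In this sketch the client sends a Keep-Alive header, which is on the hop-by-hop list, together with an ordinary header, and the backend reports which of the two reached it. The helper seenByBackend and the header name X-Trace are assumptions chosen for illustration.

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
	"strings"
)

// seenByBackend sends the two header values through a reverse proxy and
// returns the values of Keep-Alive and X-Trace that the backend received.
func seenByBackend(keepAlive, trace string) (string, string, error) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "%s|%s", r.Header.Get("Keep-Alive"), r.Header.Get("X-Trace"))
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		return "", "", err
	}
	front := httptest.NewServer(httputil.NewSingleHostReverseProxy(target))
	defer front.Close()

	req, err := http.NewRequest(http.MethodGet, front.URL, nil)
	if err != nil {
		return "", "", err
	}
	req.Header.Set("Keep-Alive", keepAlive) // hop-by-hop: should be removed by the proxy
	req.Header.Set("X-Trace", trace)        // end-to-end: should be forwarded
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", "", err
	}
	parts := strings.SplitN(string(body), "|", 2)
	if len(parts) != 2 {
		return "", "", fmt.Errorf("unexpected body %q", body)
	}
	return parts[0], parts[1], nil
}

func main() {
	gotKeepAlive, gotTrace, err := seenByBackend("timeout=5", "abc")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Keep-Alive=%q X-Trace=%q\n", gotKeepAlive, gotTrace)
}
```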

3. Example: modifying the returned content

The core idea is to modify the response body and the content length in the ModifyResponse function of ReverseProxy

func NewSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy {
    targetQuery := target.RawQuery
    director := func(req *http.Request) {
        req.URL.Scheme = target.Scheme
        req.URL.Host = target.Host
        req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
        if targetQuery == "" || req.URL.RawQuery == "" {
            req.URL.RawQuery = targetQuery + req.URL.RawQuery
        } else {
            req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
        }
        if _, ok := req.Header["User-Agent"]; !ok {
            // explicitly disable User-Agent so it's not set to default value
            req.Header.Set("User-Agent", "")
        }
    }

    // Custom ModifyResponse
    modifyResp := func(resp *http.Response) error {
        var oldData, newData []byte
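        oldData, err := ioutil.ReadAll(resp.Body)
        if err != nil {
            return err
        }
        // Close the original body now that it has been read in full
        if err := resp.Body.Close(); err != nil {
            return err
        }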
        // Modify the returned content according to different status codes
        if resp.StatusCode == 200 {
            newData = []byte("[INFO] " + string(oldData))

        } else {
            newData = []byte("[ERROR] " + string(oldData))
        }

        // Modify the returned content and ContentLength
        resp.Body = ioutil.NopCloser(bytes.NewBuffer(newData))
        resp.ContentLength = int64(len(newData))
        resp.Header.Set("Content-Length", fmt.Sprint(len(newData)))
        return nil
    }
    // Pass in a custom ModifyResponse
    return &httputil.ReverseProxy{Director: director, ModifyResponse: modifyResp}
}

Test result

$ curl http://127.0.0.1:8888/test?id=123
[INFO] http://127.0.0.1:8081/reverse/test?id=123
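The listing above omits its imports and the backend. A self-contained variant might look like the following sketch, which sets ModifyResponse on the proxy returned by httputil.NewSingleHostReverseProxy instead of re-implementing the Director. The helpers taggingProxy and fetchTagged, the /missing path and the backend bodies are assumptions made for illustration.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
	"strconv"
)

// taggingProxy prefixes every response body with [INFO] or [ERROR],
// depending on the backend's status code.
func taggingProxy(target *url.URL) *httputil.ReverseProxy {
	proxy := httputil.NewSingleHostReverseProxy(target)
	proxy.ModifyResponse = func(resp *http.Response) error {
		oldData, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		if err := resp.Body.Close(); err != nil {
			return err
		}
		prefix := "[ERROR] "
		if resp.StatusCode == http.StatusOK {
			prefix = "[INFO] "
		}
		newData := append([]byte(prefix), oldData...)
		// Replace the body and keep both length fields consistent with it.
		resp.Body = io.NopCloser(bytes.NewReader(newData))
		resp.ContentLength = int64(len(newData))
		resp.Header.Set("Content-Length", strconv.Itoa(len(newData)))
		return nil
	}
	return proxy
}

// fetchTagged requests path through the tagging proxy. The backend answers
// 404 "missing" for /missing and 200 "ok" for everything else.
func fetchTagged(path string) (string, error) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/missing" {
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprint(w, "missing")
			return
		}
		fmt.Fprint(w, "ok")
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		return "", err
	}
	front := httptest.NewServer(taggingProxy(target))
	defer front.Close()

	resp, err := http.Get(front.URL + path)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	return string(body), err
}

func main() {
	for _, path := range []string{"/test", "/missing"} {
		body, err := fetchTagged(path)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(body)
	}
}
```

Reading the whole body into memory is acceptable for small responses like these; for large or streaming responses a different approach would be needed.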

4. Return the real IP address of the client

For security reasons, we usually do not expose the real servers directly to external users; instead, services are exposed through a reverse proxy, as shown in the following figure:

The problem is this: once a request has passed through one or more reverse proxy servers between the user and the real server, how does the real server obtain the user's real IP? In other words, how should the intermediate reverse proxy servers pass the user's real IP intact to the back-end real server?

This is usually implemented with HTTP headers. The commonly used fields are X-Real-IP and X-Forwarded-For.
X-Real-IP: usually set by the proxy closest to the user to record the user's real IP. Later reverse proxy nodes should not set it again, otherwise it would be overwritten with the IP of the previous reverse proxy.
X-Forwarded-For: each proxy appends the IP of the peer it received the request from, separated by ", ". For example, if the request chain is client -> proxy1 -> proxy2 -> webapp, the value received by webapp is "client, proxy1"; the address of proxy2 is visible to webapp only as the remote address of the connection.

if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
    // If we aren't the first proxy retain prior
    // X-Forwarded-For information as a comma+space
    // separated list and fold multiple headers into one.
    if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
        clientIP = strings.Join(prior, ", ") + ", " + clientIP
    }
    outreq.Header.Set("X-Forwarded-For", clientIP)
}


Posted on Sun, 05 Dec 2021 17:34:54 -0500 by MerlinJR