jeudi 17 août 2017

How to test for leaking connections in a Go test?

After I've found leaks in my program, I've solved the problem. However, now I'm trying to find out how to test leaking connections in a Go test.

I've tried to change the number of requests in my test, didn't matter. No matter what I do, the current number of TCP connections in my test stay the same.

func TestLeakingConnections(t *testing.T) {
    getter := myhttp.New()

    s := newServer(ok)
    defer s.Close()

    cur := tcps(t)
    for i := 0; i < 1000; i++ {
        getter.GetWithTimeout(s.URL, time.Millisecond*10)
    }

    for tries := 10; tries >= 0; tries-- {
        growth := tcps(t) - cur
        if growth > 5 {
            t.Error("leaked")
        }
    }
}

// find tcp connections
func tcps(t *testing.T) (conns int) {
    lsof, err := exec.Command("lsof", "-n", "-p", strconv.Itoa(os.Getpid())).Output()
    if err != nil {
        t.Skip("skipping test; error finding or running lsof")
    }

    for _, ls := range strings.Split(string(lsof), "\n") {
        if strings.Contains(ls, "TCP") {
            conns++
        }
    }
    return
}

func newServer(f http.HandlerFunc) *httptest.Server {
    return httptest.NewServer(http.HandlerFunc(f))
}

func ok(w http.ResponseWriter, r *http.Request) {
    w.Header().Add("Content-Type", "application/xml")
    io.WriteString(w, "<xml></xml>")
}

// myhttp package

// ...other code omitted for clarification

func (g *Getter) GetWithTimeout(
    url string,
    timeout time.Duration,
) (
    *http.Response, error,
) {
    // this is the leaking part
    // moving this out of here will stop leaks
    var transport = http.Transport{
        Dial:                  dialTimeout(timeout),
        TLSHandshakeTimeout:   timeout,
        ResponseHeaderTimeout: timeout,
        ExpectContinueTimeout: timeout,
    }

    client := http.Client{
        Timeout:   timeout,
        Transport: &transport,
    }

    res, err := client.Get(url)
    if err == nil {
        if res.StatusCode != 200 {
            err = errors.Errorf(ErrStatus.Error(), res.StatusCode)
        }
    }
    return res, err
}

type dialFunc func(network, addr string) (net.Conn, error)

func dialTimeout(timeout time.Duration) dialFunc {
    // this will be used once for each connection in transport roundtrip
    return func(network, addr string) (net.Conn, error) {
        return net.DialTimeout(network, addr, timeout)
    }
}

Aucun commentaire:

Enregistrer un commentaire