oss/transport/dialer.go (74 lines of code) (raw):
package transport
import (
"context"
"net"
"time"
)
// Dialer extends net.Dialer with a read/write timeout and with hooks that
// observe the outcome of every Read and Write performed on the connections
// it creates.
type Dialer struct {
	net.Dialer

	// timeout is the default read/write timeout given to each dialed
	// connection.
	timeout time.Duration

	// postRead and postWrite are called, in order, after each Read and Write
	// on a dialed connection with the byte count and error that the
	// operation returned.
	postRead, postWrite []func(n int, err error)
}
// newDialer builds a Dialer from cfg.
//
// The connect and keep-alive timeouts configure the embedded net.Dialer; the
// read/write timeout and the post-read/post-write hooks are stored on the
// Dialer itself. cfg and its ConnectTimeout, KeepAliveTimeout and
// ReadWriteTimeout pointers are dereferenced unconditionally, so all of them
// must be non-nil.
func newDialer(cfg *Config) *Dialer {
	base := net.Dialer{
		Timeout:   *cfg.ConnectTimeout,
		KeepAlive: *cfg.KeepAliveTimeout,
	}
	return &Dialer{
		Dialer:    base,
		timeout:   *cfg.ReadWriteTimeout,
		postRead:  cfg.PostRead,
		postWrite: cfg.PostWrite,
	}
}
// Dial connects to address on the named network. It is shorthand for
// DialContext with a background context, so no per-call cancellation or
// context-supplied timeout override applies.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
	ctx := context.Background()
	return d.DialContext(ctx, network, address)
}
// DialContext connects to address on the named network using the embedded
// net.Dialer and wraps the resulting connection so that its deadline is
// extended after every Read/Write that transfers data.
//
// The read/write timeout defaults to the Dialer's configured value. If ctx
// carries a non-nil *time.Duration under the key "OpReadWriteTimeout", that
// value is used for this connection instead.
//
// On any failure the returned net.Conn is nil; a connection that was already
// established is closed before returning.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	c, err := d.Dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}

	timeout := d.timeout
	// NOTE(review): a bare string is a collision-prone context key, but callers
	// set it under this exact value, so it is kept for compatibility.
	// The nil check guards against a typed-nil pointer stored in the context,
	// which would satisfy the type assertion and then panic on dereference.
	if u, ok := ctx.Value("OpReadWriteTimeout").(*time.Duration); ok && u != nil {
		timeout = *u
	}

	t := &timeoutConn{
		Conn:    c,
		timeout: timeout,
		dialer:  d,
	}
	if err := t.nudgeDeadline(); err != nil {
		// Callers drop the connection when an error is returned, so release
		// the socket here. The deadline error is the one worth reporting; the
		// Close error is intentionally discarded.
		_ = c.Close()
		return nil, err
	}
	return t, nil
}
// timeoutConn wraps a net.Conn. After each Read or Write that transfers data
// without error it pushes the connection deadline forward by timeout, and it
// reports every Read/Write result to the hooks registered on its Dialer.
type timeoutConn struct {
net.Conn
// timeout is the read/write timeout; a value <= 0 disables deadline updates.
timeout time.Duration
// dialer is the Dialer that created this connection; its postRead and
// postWrite hooks are invoked from Read and Write.
dialer *Dialer
}
// nudgeDeadline moves the connection's read and write deadline to now plus
// the configured timeout. It does nothing and returns nil when the timeout is
// not positive; otherwise it returns the result of SetDeadline.
func (c *timeoutConn) nudgeDeadline() error {
	if c.timeout <= 0 {
		return nil
	}
	deadline := time.Now().Add(c.timeout)
	return c.SetDeadline(deadline)
}
// Read reads from the underlying connection and passes the resulting byte
// count and error to every post-read hook of the owning Dialer, in order.
// When data was read without error and a timeout is configured, the deadline
// is then extended; a failure to extend it is returned alongside n.
func (c *timeoutConn) Read(b []byte) (n int, err error) {
	n, err = c.Conn.Read(b)

	// Hooks observe the raw read result, before any deadline update.
	hooks := c.dialer.postRead
	for i := range hooks {
		hooks[i](n, err)
	}

	if err != nil || n <= 0 || c.timeout <= 0 {
		return n, err
	}
	return n, c.nudgeDeadline()
}
// Write writes to the underlying connection and passes the resulting byte
// count and error to every post-write hook of the owning Dialer, in order.
// When data was written without error and a timeout is configured, the
// deadline is then extended; a failure to extend it is returned alongside n.
func (c *timeoutConn) Write(b []byte) (n int, err error) {
	n, err = c.Conn.Write(b)

	// Hooks observe the raw write result, before any deadline update.
	for _, hook := range c.dialer.postWrite {
		hook(n, err)
	}

	progressed := err == nil && n > 0
	if progressed && c.timeout > 0 {
		return n, c.nudgeDeadline()
	}
	return n, err
}