conn.go (197 lines of code) (raw):
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
package zk
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/facebookincubator/zk/internal/proto"
"github.com/go-zookeeper/jute/lib/go/jute"
)
const (
	// defaultTimeout is the session timeout requested from the server when
	// the Client does not specify one.
	defaultTimeout = 2 * time.Second

	// overflowBitMask clears the sign bit of an int32, so a masked request ID
	// is never negative even after the counter wraps around.
	overflowBitMask = 1<<31 - 1
)
// Conn represents a client connection to a Zookeeper server and parameters needed to handle its lifetime.
//
// A Conn is created by DialContext, which authenticates the session and then
// starts the background handleReads and keepAlive goroutines. Close tears the
// connection down.
type Conn struct {
// underlying network connection returned by the Client's Dialer
conn net.Conn
// client-side request ID; advanced atomically by nextXid
xid int32
// the client sends a requested timeout, the server responds with the timeout that it can give the client;
// the same value bounds how long rpc waits for a reply, and half of it is the keepAlive ping interval
sessionTimeout time.Duration
// in-flight requests: request xid (int32) -> *pendingRequest
reqs sync.Map
// cancels sessionCtx; invoked by Close
cancelSession context.CancelFunc
// done once the session has been closed; watched by rpc, handleReads and keepAlive
sessionCtx context.Context
}
// pendingRequest tracks a single in-flight RPC: rpc creates it and stores it in
// Conn.reqs under the request's xid, and handleReads completes it when the
// matching reply arrives.
type pendingRequest struct {
// record that handleReads decodes the reply body into
reply jute.RecordReader
// receives one value from handleReads once the reply has been decoded or error has been set;
// rpc creates it with a buffer of one
done chan struct{}
// server error code taken from the reply header when it is non-zero; nil otherwise
error error
}
// isAlive reports whether the session is still open, i.e. whether the session
// context has not been cancelled yet. It never blocks.
func (c *Conn) isAlive() bool {
	// Err is nil for exactly as long as the context's Done channel is open.
	return c.sessionCtx.Err() == nil
}
// DialContext connects to the ZK server at the given address using a
// zero-value Client, i.e. with the default dialer and session timeout.
func DialContext(ctx context.Context, network, address string) (*Conn, error) {
	return new(Client).DialContext(ctx, network, address)
}
// DialContext connects the ZK client to the specified Zookeeper server,
// performs the session handshake and starts the background goroutines that
// read replies and keep the session alive.
//
// The provided context is used to determine the dial lifetime only; it is not
// consulted during the handshake or afterwards. If client.Dialer is nil it is
// set to the DialContext method of a zero-value net.Dialer. A non-zero
// client.SessionTimeout replaces the default requested session timeout.
//
// On success the caller owns the returned Conn and must Close it. On failure
// a nil Conn is returned and no resources are left open.
func (client *Client) DialContext(ctx context.Context, network, address string) (*Conn, error) {
	if client.Dialer == nil {
		defaultDialer := &net.Dialer{}
		client.Dialer = defaultDialer.DialContext
	}

	conn, err := client.Dialer(ctx, network, address)
	if err != nil {
		return nil, fmt.Errorf("could not dial ZK server: %w", err)
	}

	sessionCtx, cancel := context.WithCancel(context.Background())
	c := &Conn{
		conn:           conn,
		sessionTimeout: defaultTimeout,
		cancelSession:  cancel,
		sessionCtx:     sessionCtx,
	}
	if client.SessionTimeout != 0 {
		c.sessionTimeout = client.SessionTimeout
	}

	if err = c.authenticate(); err != nil {
		// The background goroutines have not been started yet, so nothing
		// else would ever release the socket or the session context.
		cancel()
		// Best effort: the handshake failure is the error worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("could not authenticate with ZK server: %w", err)
	}

	go c.handleReads()
	go c.keepAlive()

	return c, nil
}
// Close ends the session: it cancels the session context (waking any rpc call
// still waiting for a reply), forgets all pending requests and finally closes
// the underlying network connection, whose close error is returned.
func (c *Conn) Close() error {
	c.cancelSession()
	c.clearPendingRequests()

	closeErr := c.conn.Close()
	return closeErr
}
// authenticate performs the connect handshake: it sends a ConnectRequest
// carrying the requested session timeout in milliseconds and decodes the
// server's ConnectResponse from the same connection. A positive timeout in
// the response replaces the session timeout stored on the Conn.
func (c *Conn) authenticate() error {
	// ask the server for a session with our preferred timeout
	connectReq := &proto.ConnectRequest{TimeOut: int32(c.sessionTimeout.Milliseconds())}
	if err := WriteRecords(c.conn, connectReq); err != nil {
		return fmt.Errorf("could not write authentication request: %w", err)
	}

	// the reply comes back on the same socket
	decoder, err := createDecoder(c.conn)
	if err != nil {
		return fmt.Errorf("could not read auth response: %w", err)
	}

	var connectResp proto.ConnectResponse
	if err = connectResp.Read(decoder); err != nil {
		return fmt.Errorf("could not decode authentication response: %w", err)
	}

	// adopt the timeout the server answered with, when it gave one
	if granted := connectResp.TimeOut; granted > 0 {
		c.sessionTimeout = time.Duration(granted) * time.Millisecond
	}

	return nil
}
// GetData issues a GetData request for the node at the given path and returns
// the data carried by the server's response.
func (c *Conn) GetData(path string) ([]byte, error) {
	var resp proto.GetDataResponse

	err := c.rpc(opGetData, &proto.GetDataRequest{Path: path}, &resp)
	if err != nil {
		return nil, fmt.Errorf("error sending GetData request: %w", err)
	}

	return resp.Data, nil
}
// GetChildren issues a GetChildren request for the node at the given path and
// returns the children listed in the server's response.
func (c *Conn) GetChildren(path string) ([]string, error) {
	var resp proto.GetChildrenResponse

	err := c.rpc(opGetChildren, &proto.GetChildrenRequest{Path: path}, &resp)
	if err != nil {
		return nil, fmt.Errorf("error sending GetChildren request: %w", err)
	}

	return resp.Children, nil
}
// rpc sends a single request to the server and blocks until its reply has
// been decoded into r, the session is closed, or the session timeout elapses.
//
// opcode is the operation type placed in the request header, w is the request
// body written after the header, and r receives the decoded reply body.
//
// The returned error is the server's error code when the reply header carries
// one, a wrapped write error when the request could not be sent, a wrapped
// session context error when the session was closed while waiting, or a
// timeout error when no reply arrived within the session timeout.
func (c *Conn) rpc(opcode int32, w jute.RecordWriter, r jute.RecordReader) error {
	header := &proto.RequestHeader{
		Xid:  c.nextXid(),
		Type: opcode,
	}

	pending := &pendingRequest{
		reply: r,
		done:  make(chan struct{}, 1),
	}

	c.reqs.Store(header.Xid, pending)
	// handleReads removes the entry itself when the reply arrives, which makes
	// this a no-op. On every other exit (write failure, timeout, closed
	// session) the entry must be dropped here, otherwise it would stay in the
	// map forever and a late reply could be decoded into a record the caller
	// has already abandoned.
	defer c.reqs.Delete(header.Xid)

	if err := WriteRecords(c.conn, header, w); err != nil {
		return fmt.Errorf("could not write rpc request: %w", err)
	}

	// An explicit timer (rather than time.After) can be stopped as soon as
	// the call returns instead of lingering for the full session timeout.
	timeout := time.NewTimer(c.sessionTimeout)
	defer timeout.Stop()

	select {
	case <-pending.done:
		return pending.error
	case <-c.sessionCtx.Done():
		return fmt.Errorf("session closed: %w", c.sessionCtx.Err())
	case <-timeout.C:
		return fmt.Errorf("got a timeout waiting on response for xid %d", header.Xid)
	}
}
// handleReads is the connection's read loop. For every reply read from the
// socket it looks up the pending request with the same xid, records the
// server error code or decodes the reply body into the request's record, and
// then signals the waiting caller. Ping replies and replies without a
// matching request are skipped.
//
// The loop ends when the session context is done, the connection is closed or
// hits EOF, or a packet cannot be read or decoded; the Conn is closed on the
// way out.
func (c *Conn) handleReads() {
	defer c.Close()

	for c.sessionCtx.Err() == nil {
		dec, err := createDecoder(c.conn)
		switch {
		case errors.Is(err, net.ErrClosed), errors.Is(err, io.EOF):
			// the connection is gone: stop quietly, nothing left to read
			return
		case err != nil:
			log.Printf("could not read response packet: %v", err)
			return
		}

		var hdr proto.ReplyHeader
		if err = dec.ReadRecord(&hdr); err != nil {
			log.Printf("could not decode reply header: %v", err)
			return
		}

		if hdr.Xid == pingXID {
			// ping responses carry nothing a caller is waiting for
			continue
		}

		entry, found := c.reqs.LoadAndDelete(hdr.Xid)
		if !found {
			log.Printf("no matching request found for xid %d", hdr.Xid)
			continue
		}
		req := entry.(*pendingRequest)

		if hdr.Err == 0 {
			if err = dec.ReadRecord(req.reply); err != nil {
				log.Printf("could not decode reply record: %v", err)
				return
			}
		} else {
			// the server rejected the request: hand its error code to the caller
			code := Error(hdr.Err)
			req.error = &code
		}

		req.done <- struct{}{}
	}
}
// keepAlive sends a ping request to the server at a fixed interval until the
// session context is done or a ping cannot be written. The Conn is closed
// when the loop exits.
func (c *Conn) keepAlive() {
	// Used only when the session timeout is too small to derive an interval.
	const fallbackPingInterval = time.Second

	// set the ping interval to half of the session timeout, according to Zookeeper documentation
	interval := c.sessionTimeout / 2
	if interval <= 0 {
		// time.NewTicker panics on a non-positive duration, and a panic in
		// this background goroutine would take down the whole process.
		interval = fallbackPingInterval
	}

	pingTicker := time.NewTicker(interval)
	defer pingTicker.Stop()

	defer c.Close()

	for {
		select {
		case <-pingTicker.C:
			header := &proto.RequestHeader{
				Xid:  pingXID,
				Type: opPing,
			}
			if err := WriteRecords(c.conn, header); err != nil {
				log.Printf("error writing ping request: %v", err)
				return
			}
		case <-c.sessionCtx.Done():
			return
		}
	}
}
// clearPendingRequests removes every entry from the pending request map.
func (c *Conn) clearPendingRequests() {
	// returning true keeps Range going until every key has been visited
	drop := func(xid, _ interface{}) bool {
		c.reqs.Delete(xid)
		return true
	}
	c.reqs.Range(drop)
}
// nextXid atomically advances the request ID counter and returns the new
// value with its sign bit cleared, so the result is never negative even after
// the counter overflows.
func (c *Conn) nextXid() int32 {
	next := atomic.AddInt32(&c.xid, 1)
	return next & overflowBitMask
}