conn.go (1,445 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bufio" "bytes" "context" "crypto/tls" "errors" "fmt" "io" "io/ioutil" "net" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/gocql/gocql/internal/lru" "github.com/gocql/gocql/internal/streams" ) // approve the authenticator with the list of allowed authenticators. If the provided list is empty, // the given authenticator is allowed. func approve(authenticator string, approvedAuthenticators []string) bool { if len(approvedAuthenticators) == 0 { return true } for _, s := range approvedAuthenticators { if authenticator == s { return true } } return false } // JoinHostPort is a utility to return an address string that can be used // by `gocql.Conn` to form a connection with a host. 
func JoinHostPort(addr string, port int) string { addr = strings.TrimSpace(addr) if _, _, err := net.SplitHostPort(addr); err != nil { addr = net.JoinHostPort(addr, strconv.Itoa(port)) } return addr } type Authenticator interface { Challenge(req []byte) (resp []byte, auth Authenticator, err error) Success(data []byte) error } // PasswordAuthenticator specifies credentials to be used when authenticating. // It can be configured with an "allow list" of authenticator class names to avoid // attempting to authenticate with Cassandra if it doesn't provide an expected authenticator. type PasswordAuthenticator struct { Username string Password string // Setting this to nil or empty will allow authenticating with any authenticator // provided by the server. This is the default behavior of most other driver // implementations. AllowedAuthenticators []string } func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) { if !approve(string(req), p.AllowedAuthenticators) { return nil, nil, fmt.Errorf("unexpected authenticator %q", req) } resp := make([]byte, 2+len(p.Username)+len(p.Password)) resp[0] = 0 copy(resp[1:], p.Username) resp[len(p.Username)+1] = 0 copy(resp[2+len(p.Username):], p.Password) return resp, nil, nil } func (p PasswordAuthenticator) Success(data []byte) error { return nil } // SslOptions configures TLS use. // // Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification // to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config. 
// SslOptions and Config.InsecureSkipVerify interact as follows: // // Config.InsecureSkipVerify | EnableHostVerification | Result // Config is nil | false | do not verify host // Config is nil | true | verify host // false | false | verify host // true | false | do not verify host // false | true | verify host // true | true | verify host type SslOptions struct { *tls.Config // CertPath and KeyPath are optional depending on server // config, but both fields must be omitted to avoid using a // client certificate CertPath string KeyPath string CaPath string //optional depending on server config // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this // on. // This option is basically the inverse of tls.Config.InsecureSkipVerify. // See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info. // // See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config. EnableHostVerification bool } type ConnConfig struct { ProtoVersion int CQLVersion string Timeout time.Duration WriteTimeout time.Duration ConnectTimeout time.Duration Dialer Dialer HostDialer HostDialer Compressor Compressor Authenticator Authenticator AuthProvider func(h *HostInfo) (Authenticator, error) Keepalive time.Duration Logger StdLogger tlsConfig *tls.Config disableCoalesce bool } func (c *ConnConfig) logger() StdLogger { if c.Logger == nil { return &defaultLogger{} } return c.Logger } type ConnErrorHandler interface { HandleError(conn *Conn, err error, closed bool) } type connErrorHandlerFn func(conn *Conn, err error, closed bool) func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) { fn(conn, err, closed) } // Conn is a single connection to a Cassandra node. It can be used to execute // queries, but users are usually advised to use a more reliable, higher // level API. 
type Conn struct { r ConnReader w contextWriter writeTimeout time.Duration cfg *ConnConfig frameObserver FrameHeaderObserver streamObserver StreamObserver headerBuf [maxFrameHeaderSize]byte streams *streams.IDGenerator mu sync.Mutex // calls stores a map from stream ID to callReq. // This map is protected by mu. // calls should not be used when closed is true, calls is set to nil when closed=true. calls map[int]*callReq errorHandler ConnErrorHandler compressor Compressor auth Authenticator addr string version uint8 currentKeyspace string host *HostInfo isSchemaV2 bool session *Session // true if connection close process for the connection started. // closed is protected by mu. closed bool ctx context.Context cancel context.CancelFunc timeouts int64 logger StdLogger } // connect establishes a connection to a Cassandra node using session's connection config. func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) { return s.dial(ctx, host, s.connCfg, errorHandler) } // dial establishes a connection to a Cassandra node and notifies the session's connectObserver. func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { var obs ObservedConnect if s.connectObserver != nil { obs.Host = host obs.Start = time.Now() } conn, err := s.dialWithoutObserver(ctx, host, connConfig, errorHandler) if s.connectObserver != nil { obs.End = time.Now() obs.Err = err s.connectObserver.ObserveConnect(obs) } return conn, err } // dialWithoutObserver establishes connection to a Cassandra node. // // dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead. 
func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { dialedHost, err := cfg.HostDialer.DialHost(ctx, host) if err != nil { return nil, err } writeTimeout := cfg.Timeout if cfg.WriteTimeout > 0 { writeTimeout = cfg.WriteTimeout } ctx, cancel := context.WithCancel(ctx) c := &Conn{ r: &connReader{ conn: dialedHost.Conn, r: bufio.NewReader(dialedHost.Conn), }, cfg: cfg, calls: make(map[int]*callReq), version: uint8(cfg.ProtoVersion), addr: dialedHost.Conn.RemoteAddr().String(), errorHandler: errorHandler, compressor: cfg.Compressor, session: s, streams: streams.New(cfg.ProtoVersion), host: host, isSchemaV2: true, // Try using "system.peers_v2" until proven otherwise frameObserver: s.frameObserver, w: &deadlineContextWriter{ w: dialedHost.Conn, timeout: writeTimeout, semaphore: make(chan struct{}, 1), quit: make(chan struct{}), }, ctx: ctx, cancel: cancel, logger: cfg.logger(), streamObserver: s.streamObserver, writeTimeout: writeTimeout, } if err := c.init(ctx, dialedHost); err != nil { cancel() c.Close() return nil, err } return c, nil } func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error { if c.session.cfg.AuthProvider != nil { var err error c.auth, err = c.cfg.AuthProvider(c.host) if err != nil { return err } } else { c.auth = c.cfg.Authenticator } startup := &startupCoordinator{ frameTicker: make(chan struct{}), conn: c, } c.r.SetTimeout(c.cfg.ConnectTimeout) if err := startup.setupConn(ctx); err != nil { return err } c.r.SetTimeout(c.cfg.Timeout) // dont coalesce startup frames if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce { c.w = newWriteCoalescer(dialedHost.Conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done()) } go c.serve(ctx) go c.heartBeat(ctx) return nil } func (c *Conn) Write(p []byte) (n int, err error) { return c.w.writeContext(context.Background(), p) } type startupCoordinator 
struct { conn *Conn frameTicker chan struct{} } func (s *startupCoordinator) setupConn(ctx context.Context) error { var cancel context.CancelFunc if s.conn.r.GetTimeout() > 0 { ctx, cancel = context.WithTimeout(ctx, s.conn.r.GetTimeout()) } else { ctx, cancel = context.WithCancel(ctx) } defer cancel() // Only for proto v5+. // Indicates if STARTUP has been completed. // github.com/apache/cassandra/blob/trunk/doc/native_protocol_v5.spec // 2.3.1 Initial Handshake // In order to support both v5 and earlier formats, the v5 framing format is not // applied to message exchanges before an initial handshake is completed. startupCompleted := &atomic.Bool{} startupCompleted.Store(false) startupErr := make(chan error) go func() { for range s.frameTicker { err := s.conn.recv(ctx, startupCompleted.Load()) if err != nil { select { case startupErr <- err: case <-ctx.Done(): } return } } }() go func() { defer close(s.frameTicker) err := s.options(ctx, startupCompleted) select { case startupErr <- err: case <-ctx.Done(): } }() select { case err := <-startupErr: if err != nil { return err } case <-ctx.Done(): return errors.New("gocql: no response to connection startup within timeout") } return nil } func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder, startupCompleted *atomic.Bool) (frame, error) { select { case s.frameTicker <- struct{}{}: case <-ctx.Done(): return nil, ctx.Err() } framer, err := s.conn.execInternal(ctx, frame, nil, startupCompleted.Load()) if err != nil { return nil, err } return framer.parseFrame() } func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atomic.Bool) error { frame, err := s.write(ctx, &writeOptionsFrame{}, startupCompleted) if err != nil { return err } supported, ok := frame.(*supportedFrame) if !ok { return NewErrProtocol("Unknown type of response to startup frame: %T", frame) } return s.startup(ctx, supported.supported, startupCompleted) } func (s *startupCoordinator) startup(ctx context.Context, 
supported map[string][]string, startupCompleted *atomic.Bool) error { m := map[string]string{ "CQL_VERSION": s.conn.cfg.CQLVersion, "DRIVER_NAME": driverName, "DRIVER_VERSION": driverVersion, } if s.conn.compressor != nil { comp := supported["COMPRESSION"] name := s.conn.compressor.Name() for _, compressor := range comp { if compressor == name { m["COMPRESSION"] = compressor break } } if _, ok := m["COMPRESSION"]; !ok { s.conn.compressor = nil } } frame, err := s.write(ctx, &writeStartupFrame{opts: m}, startupCompleted) if err != nil { return err } switch v := frame.(type) { case error: return v case *readyFrame: // Startup is successfully completed, so we could use Native Protocol 5 startupCompleted.Store(true) return nil case *authenticateFrame: // Startup is successfully completed, so we could use Native Protocol 5 startupCompleted.Store(true) return s.authenticateHandshake(ctx, v, startupCompleted) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, startupCompleted *atomic.Bool) error { if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} for { frame, err := s.write(ctx, req, startupCompleted) if err != nil { return err } switch v := frame.(type) { case error: return v case *authSuccessFrame: if challenger != nil { return challenger.Success(v.data) } return nil case *authChallengeFrame: resp, challenger, err = challenger.Challenge(v.data) if err != nil { return err } req = &writeAuthResponseFrame{ data: resp, } default: return fmt.Errorf("unknown frame response during authentication: %v", v) } } } func (c *Conn) closeWithError(err error) { if c == nil { return } c.mu.Lock() if c.closed { c.mu.Unlock() return } c.closed = true var 
callsToClose map[int]*callReq // We should attempt to deliver the error back to the caller if it // exists. However, don't block c.mu while we are delivering the // error to outstanding calls. if err != nil { callsToClose = c.calls // It is safe to change c.calls to nil. Nobody should use it after c.closed is set to true. c.calls = nil } c.mu.Unlock() for _, req := range callsToClose { // we need to send the error to all waiting queries. select { case req.resp <- callResp{err: err}: case <-req.timeout: } if req.streamObserverContext != nil { req.streamObserverEndOnce.Do(func() { req.streamObserverContext.StreamAbandoned(ObservedStream{ Host: c.host, }) }) } } // if error was nil then unblock the quit channel c.cancel() cerr := c.r.Close() if err != nil { c.errorHandler.HandleError(c, err, true) } else if cerr != nil { // TODO(zariel): is it a good idea to do this? c.errorHandler.HandleError(c, cerr, true) } } func (c *Conn) Close() { c.closeWithError(nil) } // Serve starts the stream multiplexer for this connection, which is required // to execute any queries. This method runs as long as the connection is // open and is therefore usually called in a separate goroutine. 
func (c *Conn) serve(ctx context.Context) { var err error for err == nil { err = c.recv(ctx, true) } c.closeWithError(err) } func (c *Conn) discardFrame(r io.Reader, head frameHeader) error { _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) if err != nil { return err } return nil } type protocolError struct { frame frame } func (p *protocolError) Error() string { if err, ok := p.frame.(error); ok { return err.Error() } return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame) } func (c *Conn) heartBeat(ctx context.Context) { sleepTime := 1 * time.Second timer := time.NewTimer(sleepTime) defer timer.Stop() var failures int for { if failures > 5 { c.closeWithError(fmt.Errorf("gocql: heartbeat failed")) return } timer.Reset(sleepTime) select { case <-ctx.Done(): return case <-timer.C: } framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil) if err != nil { failures++ continue } resp, err := framer.parseFrame() if err != nil { // invalid frame failures++ continue } switch resp.(type) { case *supportedFrame: // Everything ok sleepTime = 5 * time.Second failures = 0 case error: // TODO: should we do something here? default: panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp)) } } } func (c *Conn) recv(ctx context.Context, startupCompleted bool) error { // If startup is completed and native proto 5+ is set up then we should // unwrap payload from compressed/uncompressed frame if startupCompleted && c.version > protoVersion4 { return c.recvSegment(ctx) } return c.processFrame(ctx, c.r) } func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { // not safe for concurrent reads // read a full header, ignore timeouts, as this is being ran in a loop // TODO: TCP level deadlines? or just query level deadlines? 
if c.r.GetTimeout() > 0 { c.r.SetReadDeadline(time.Time{}) } headStartTime := time.Now() // were just reading headers over and over and copy bodies head, err := readHeader(r, c.headerBuf[:]) headEndTime := time.Now() if err != nil { return err } if c.frameObserver != nil { c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{ Version: protoVersion(head.version), Flags: head.flags, Stream: int16(head.stream), Opcode: frameOp(head.op), Length: int32(head.length), Start: headStartTime, End: headEndTime, Host: c.host, }) } if head.stream > c.streams.NumStreams { return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c.compressor, c.version) if err := framer.readFrame(r, &head); err != nil { return err } go c.session.handleEvent(framer) return nil } else if head.stream <= 0 { // reserved stream that we dont use, probably due to a protocol error // or a bug in Cassandra, this should be an error, parse it and return. framer := newFramer(c.compressor, c.version) if err := framer.readFrame(r, &head); err != nil { return err } frame, err := framer.parseFrame() if err != nil { return err } return &protocolError{ frame: frame, } } c.mu.Lock() if c.closed { c.mu.Unlock() return ErrConnectionClosed } call, ok := c.calls[head.stream] delete(c.calls, head.stream) c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) return c.discardFrame(r, head) } else if head.stream != call.streamID { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } framer := newFramer(c.compressor, c.version) err = framer.readFrame(r, &head) if err != nil { // only net errors should cause the connection to be closed. Though // cassandra returning corrupt frames will be returned here as well. 
if _, ok := err.(net.Error); ok { return err } } // we either, return a response to the caller, the caller timedout, or the // connection has closed. Either way we should never block indefinatly here select { case call.resp <- callResp{framer: framer, err: err}: case <-call.timeout: c.releaseStream(call) case <-ctx.Done(): } return nil } func (c *Conn) releaseStream(call *callReq) { if call.timer != nil { call.timer.Stop() } c.streams.Clear(call.streamID) if call.streamObserverContext != nil { call.streamObserverEndOnce.Do(func() { call.streamObserverContext.StreamFinished(ObservedStream{ Host: c.host, }) }) } } func (c *Conn) handleTimeout() { if atomic.AddInt64(&c.timeouts, 1) > 0 { c.closeWithError(ErrTooManyTimeouts) } } func (c *Conn) recvSegment(ctx context.Context) error { var ( frame []byte isSelfContained bool err error ) // Read frame based on compression if c.compressor != nil { frame, isSelfContained, err = readCompressedSegment(c.r, c.compressor) } else { frame, isSelfContained, err = readUncompressedSegment(c.r) } if err != nil { return err } if isSelfContained { return c.processAllFramesInSegment(ctx, bytes.NewReader(frame)) } head, err := readHeader(bytes.NewReader(frame), c.headerBuf[:]) if err != nil { return err } const frameHeaderLength = 9 buf := bytes.NewBuffer(make([]byte, 0, head.length+frameHeaderLength)) buf.Write(frame) // Computing how many bytes of message left to read bytesToRead := head.length - len(frame) + frameHeaderLength err = c.recvPartialFrames(buf, bytesToRead) if err != nil { return err } return c.processFrame(ctx, buf) } // recvPartialFrames reads proto v5 segments from Conn.r and writes decoded partial frames to dst. // It reads data until the bytesToRead is reached. // If Conn.compressor is not nil, it processes Compressed Format segments. 
func (c *Conn) recvPartialFrames(dst *bytes.Buffer, bytesToRead int) error { var ( read int frame []byte isSelfContained bool err error ) for read != bytesToRead { // Read frame based on compression if c.compressor != nil { frame, isSelfContained, err = readCompressedSegment(c.r, c.compressor) } else { frame, isSelfContained, err = readUncompressedSegment(c.r) } if err != nil { return fmt.Errorf("gocql: failed to read non self-contained frame: %w", err) } if isSelfContained { return fmt.Errorf("gocql: received self-contained segment, but expected not") } if totalLength := dst.Len() + len(frame); totalLength > dst.Cap() { return fmt.Errorf("gocql: expected partial frame of length %d, got %d", dst.Cap(), totalLength) } // Write the frame to the destination writer n, _ := dst.Write(frame) read += n } return nil } func (c *Conn) processAllFramesInSegment(ctx context.Context, r *bytes.Reader) error { var err error for r.Len() > 0 && err == nil { err = c.processFrame(ctx, r) } return err } // ConnReader is like net.Conn but also allows to set timeout duration. type ConnReader interface { net.Conn // SetTimeout sets timeout duration for reading data form conn SetTimeout(timeout time.Duration) // GetTimeout returns timeout duration GetTimeout() time.Duration } // connReader implements ConnReader. // It retries to read data up to 5 times or returns error. 
type connReader struct {
	conn net.Conn
	r    *bufio.Reader

	// timeout, when positive, is applied as a fresh read deadline before
	// every read attempt.
	timeout time.Duration
}

// Read fills p completely via io.ReadFull. It makes up to maxAttempts tries,
// retrying only when the error is a net.Error whose Temporary() reports true,
// and continues filling p from where the previous attempt stopped.
func (c *connReader) Read(p []byte) (n int, err error) {
	const maxAttempts = 5

	for i := 0; i < maxAttempts; i++ {
		var nn int
		if c.timeout > 0 {
			c.conn.SetReadDeadline(time.Now().Add(c.timeout))
		}

		nn, err = io.ReadFull(c.r, p[n:])
		n += nn
		if err == nil {
			break
		}

		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
			break
		}
	}

	return
}

// Write writes directly to the underlying connection (unbuffered).
func (c *connReader) Write(b []byte) (n int, err error) {
	return c.conn.Write(b)
}

func (c *connReader) Close() error {
	return c.conn.Close()
}

func (c *connReader) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}

func (c *connReader) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}

func (c *connReader) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}

func (c *connReader) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}

func (c *connReader) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}

// SetTimeout sets the per-read timeout used by Read.
func (c *connReader) SetTimeout(timeout time.Duration) {
	c.timeout = timeout
}

// GetTimeout returns the per-read timeout used by Read.
func (c *connReader) GetTimeout() time.Duration {
	return c.timeout
}

// callReq tracks one in-flight request on a single stream.
type callReq struct {
	// resp will receive the frame that was sent as a response to this stream.
	resp     chan callResp
	timeout  chan struct{} // indicates to recv() that a call has timed out
	streamID int           // current stream in use

	// timer bounds the wait for a response in execInternal when the reader
	// has a positive timeout.
	timer *time.Timer

	// streamObserverContext is notified about events regarding this stream
	streamObserverContext StreamObserverContext

	// streamObserverEndOnce ensures that either StreamAbandoned or StreamFinished is called,
	// but not both.
	streamObserverEndOnce sync.Once
}

// callResp is delivered on callReq.resp.
type callResp struct {
	// framer is the response frame.
	// May be nil if err is not nil.
	framer *framer
	// err is error encountered, if any.
	err error
}

// contextWriter is like io.Writer, but takes context as well.
type contextWriter interface {
	// writeContext writes p to the connection.
	//
	// If ctx is canceled before we start writing p (e.g. during waiting while another write is currently in progress),
	// p is not written and ctx.Err() is returned. Context is ignored after we start writing p (i.e. we don't interrupt
	// blocked writes that are in progress) so that we always either write the full frame or not write it at all.
	//
	// It returns the number of bytes written from p (0 <= n <= len(p)) and any error that caused the write to stop
	// early. writeContext must return a non-nil error if it returns n < len(p). writeContext must not modify the
	// data in p, even temporarily.
	writeContext(ctx context.Context, p []byte) (n int, err error)
}

// deadlineWriter is a writer whose write deadline can be set.
type deadlineWriter interface {
	SetWriteDeadline(time.Time) error
	io.Writer
}

// deadlineContextWriter serializes writes and applies a write deadline of
// timeout (when positive) before each one.
type deadlineContextWriter struct {
	w       deadlineWriter
	timeout time.Duration
	// semaphore protects critical section for SetWriteDeadline/Write.
	// It is a channel with capacity 1.
	semaphore chan struct{}

	// quit closed once the connection is closed.
	// NOTE(review): no code in this file section closes quit (dialWithoutObserver
	// only creates it); confirm where, if anywhere, it is closed.
	quit chan struct{}
}

// writeContext implements contextWriter.
func (c *deadlineContextWriter) writeContext(ctx context.Context, p []byte) (int, error) {
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-c.quit:
		return 0, ErrConnectionClosed
	case c.semaphore <- struct{}{}:
		// acquired
	}

	defer func() {
		// release
		<-c.semaphore
	}()

	if c.timeout > 0 {
		err := c.w.SetWriteDeadline(time.Now().Add(c.timeout))
		if err != nil {
			return 0, err
		}
	}
	return c.w.Write(p)
}

// newWriteCoalescer builds a writeCoalescer over conn and starts its flusher
// goroutine, which runs until quit is closed.
func newWriteCoalescer(conn deadlineWriter, writeTimeout, coalesceDuration time.Duration,
	quit <-chan struct{}) *writeCoalescer {
	wc := &writeCoalescer{
		writeCh: make(chan writeRequest),
		c:       conn,
		quit:    quit,
		timeout: writeTimeout,
	}
	go wc.writeFlusher(coalesceDuration)
	return wc
}

// writeCoalescer batches writes received within one coalescing interval and
// sends them to c as a single net.Buffers write.
type writeCoalescer struct {
	c deadlineWriter

	// NOTE(review): mu is not used by any method in this file section.
	mu sync.Mutex

	quit    <-chan struct{}
	writeCh chan writeRequest

	timeout time.Duration

	// Test hooks; nil outside tests.
	testEnqueuedHook func()
	testFlushedHook  func()
}

type writeRequest struct {
	// resultChan is a channel (with buffer size 1) where to send results of the write.
	resultChan chan<- writeResult
	// data to write.
	data []byte
}

type writeResult struct {
	n   int
	err error
}

// writeContext implements contextWriter. It blocks until the flusher has
// written (or failed to write) p.
func (w *writeCoalescer) writeContext(ctx context.Context, p []byte) (int, error) {
	resultChan := make(chan writeResult, 1)
	wr := writeRequest{
		resultChan: resultChan,
		data:       p,
	}

	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-w.quit:
		return 0, io.EOF // TODO: better error here?
	case w.writeCh <- wr:
		// enqueued for writing
	}

	if w.testEnqueuedHook != nil {
		w.testEnqueuedHook()
	}

	result := <-resultChan
	return result.n, result.err
}

// writeFlusher runs the flush loop with a real timer that is stopped until the
// first write of each batch arrives.
func (w *writeCoalescer) writeFlusher(interval time.Duration) {
	timer := time.NewTimer(interval)
	defer timer.Stop()

	// Stop (and drain) the timer so it only starts on the first write.
	if !timer.Stop() {
		<-timer.C
	}

	w.writeFlusherImpl(timer.C, func() { timer.Reset(interval) })
}

// writeFlusherImpl accumulates requests and flushes them when timerC fires.
// On quit, pending requests are failed with io.EOF and the loop returns.
func (w *writeCoalescer) writeFlusherImpl(timerC <-chan time.Time, resetTimer func()) {
	running := false

	var buffers net.Buffers
	var resultChans []chan<- writeResult

	for {
		select {
		case req := <-w.writeCh:
			buffers = append(buffers, req.data)
			resultChans = append(resultChans, req.resultChan)
			if !running {
				// Start timer on first write.
				resetTimer()
				running = true
			}
		case <-w.quit:
			result := writeResult{
				n:   0,
				err: io.EOF, // TODO: better error here?
			}
			// Unblock whoever was waiting.
			for _, resultChan := range resultChans {
				// resultChan has capacity 1, so it does not block.
				resultChan <- result
			}
			return
		case <-timerC:
			running = false
			w.flush(resultChans, buffers)
			buffers = nil
			resultChans = nil
			if w.testFlushedHook != nil {
				w.testFlushedHook()
			}
		}
	}
}

// flush writes buffers in one WriteTo call and reports to each request how
// many of its bytes were written; the request in which the write stopped (and
// every later one) receives the write error.
func (w *writeCoalescer) flush(resultChans []chan<- writeResult, buffers net.Buffers) {
	// Flush everything we have so far.
	if w.timeout > 0 {
		err := w.c.SetWriteDeadline(time.Now().Add(w.timeout))
		if err != nil {
			for i := range resultChans {
				resultChans[i] <- writeResult{
					n:   0,
					err: err,
				}
			}
			return
		}
	}
	// Copy buffers because WriteTo modifies buffers in-place.
	buffers2 := make(net.Buffers, len(buffers))
	copy(buffers2, buffers)
	n, err := buffers2.WriteTo(w.c)
	// Writes of bytes before n succeeded, writes of bytes starting from n failed with err.
	// Use n as remaining byte counter.
	for i := range buffers {
		if int64(len(buffers[i])) <= n {
			// this buffer was fully written.
			resultChans[i] <- writeResult{
				n:   len(buffers[i]),
				err: nil,
			}
			n -= int64(len(buffers[i]))
		} else {
			// this buffer was not (fully) written.
			resultChans[i] <- writeResult{
				n:   int(n),
				err: err,
			}
			n = 0
		}
	}
}

// addCall attempts to add a call to c.calls.
// It fails with error if the connection already started closing or if a call for the given stream
// already exists.
func (c *Conn) addCall(call *callReq) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return ErrConnectionClosed
	}
	existingCall := c.calls[call.streamID]
	if existingCall != nil {
		return fmt.Errorf("attempting to use stream already in use: %d -> %d", call.streamID,
			existingCall.streamID)
	}
	c.calls[call.streamID] = call
	return nil
}

// exec sends req on a new stream and waits for its response, assuming the
// startup handshake has completed.
func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
	return c.execInternal(ctx, req, tracer, true)
}

// execInternal allocates a stream, registers a callReq for it, writes req
// (using v5 framing when the protocol is above v4 and startupCompleted is
// true), and waits for the response, the read timeout, ctx, or connection
// shutdown, whichever happens first.
func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer, startupCompleted bool) (*framer, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, ctxErr
	}

	// TODO: move tracer onto conn
	stream, ok := c.streams.GetStream()
	if !ok {
		return nil, ErrNoStreams
	}

	// resp is basically a waiting semaphore protecting the framer
	framer := newFramer(c.compressor, c.version)

	call := &callReq{
		timeout:  make(chan struct{}),
		streamID: stream,
		resp:     make(chan callResp),
	}

	if c.streamObserver != nil {
		call.streamObserverContext = c.streamObserver.StreamContext(ctx)
	}

	if err := c.addCall(call); err != nil {
		return nil, err
	}

	// After this point, we need to either read from call.resp or close(call.timeout)
	// since closeWithError can try to write a connection close error to call.resp.
	// If we don't close(call.timeout) or read from call.resp, closeWithError can deadlock.

	if tracer != nil {
		framer.trace()
	}

	if call.streamObserverContext != nil {
		call.streamObserverContext.StreamStarted(ObservedStream{
			Host: c.host,
		})
	}

	err := req.buildFrame(framer, stream)
	if err != nil {
		// closeWithError will block waiting for this stream to either receive a response
		// or for us to timeout.
		close(call.timeout)
		// We failed to serialize the frame into a buffer.
		// This should not affect the connection as we didn't write anything. We just free the current call.
		c.mu.Lock()
		if !c.closed {
			delete(c.calls, call.streamID)
		}
		c.mu.Unlock()
		// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
		// check above could fail.
		c.releaseStream(call)
		return nil, err
	}

	var n int

	if c.version > protoVersion4 && startupCompleted {
		err = framer.prepareModernLayout()
	}
	if err == nil {
		n, err = c.w.writeContext(ctx, framer.buf)
	}

	if err != nil {
		// closeWithError will block waiting for this stream to either receive a response
		// or for us to timeout, close the timeout chan here. Im not entirely sure
		// but we should not get a response after an error on the write side.
		close(call.timeout)
		if (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) && n == 0 {
			// We have not started to write this frame.
			// Release the stream as no response can come from the server on the stream.
			c.mu.Lock()
			if !c.closed {
				delete(c.calls, call.streamID)
			}
			c.mu.Unlock()
			// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
			// check above could fail.
			c.releaseStream(call)
		} else {
			// I think this is the correct thing to do, im not entirely sure. It is not
			// ideal as readers might still get some data, but they probably wont.
			// Here we need to be careful as the stream is not available and if all
			// writes just timeout or fail then the pool might use this connection to
			// send a frame on, with all the streams used up and not returned.
			c.closeWithError(err)
		}
		return nil, err
	}

	var timeoutCh <-chan time.Time
	if timeout := c.r.GetTimeout(); timeout > 0 {
		// NOTE(review): call is created above without a timer, so only the
		// first branch appears reachable in this function.
		if call.timer == nil {
			call.timer = time.NewTimer(0)
			<-call.timer.C
		} else {
			if !call.timer.Stop() {
				select {
				case <-call.timer.C:
				default:
				}
			}
		}

		call.timer.Reset(timeout)
		timeoutCh = call.timer.C
	}

	// NOTE(review): ctx.Err() was already called at the top of this function,
	// so a nil ctx would have panicked before reaching this check.
	var ctxDone <-chan struct{}
	if ctx != nil {
		ctxDone = ctx.Done()
	}

	select {
	case resp := <-call.resp:
		close(call.timeout)
		if resp.err != nil {
			if !c.Closed() {
				// if the connection is closed then we cant release the stream,
				// this is because the request is still outstanding and we have
				// been handed another error from another stream which caused the
				// connection to close.
				c.releaseStream(call)
			}
			return nil, resp.err
		}
		// dont release the stream if detect a timeout as another request can reuse
		// that stream and get a response for the old request, which we have no
		// easy way of detecting.
		//
		// Ensure that the stream is not released if there are potentially outstanding
		// requests on the stream to prevent nil pointer dereferences in recv().
		defer c.releaseStream(call)

		if v := resp.framer.header.version.version(); v != c.version {
			return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
		}

		return resp.framer, nil
	case <-timeoutCh:
		// The stream is intentionally not released here (see comment above).
		close(call.timeout)
		c.handleTimeout()
		return nil, ErrTimeoutNoResponse
	case <-ctxDone:
		close(call.timeout)
		return nil, ctx.Err()
	case <-c.ctx.Done():
		close(call.timeout)
		return nil, ErrConnectionClosed
	}
}

// ObservedStream observes a single request/response stream.
type ObservedStream struct {
	// Host of the connection used to send the stream.
	Host *HostInfo
}

// StreamObserver is notified about request/response pairs.
// Streams are created for executing queries/batches or // internal requests to the database and might live longer than // execution of the query - the stream is still tracked until // response arrives so that stream IDs are not reused. type StreamObserver interface { // StreamContext is called before creating a new stream. // ctx is context passed to Session.Query / Session.Batch, // but might also be an internal context (for example // for internal requests that use control connection). // StreamContext might return nil if it is not interested // in the details of this stream. // StreamContext is called before the stream is created // and the returned StreamObserverContext might be discarded // without any methods called on the StreamObserverContext if // creation of the stream fails. // Note that if you don't need to track per-stream data, // you can always return the same StreamObserverContext. StreamContext(ctx context.Context) StreamObserverContext } // StreamObserverContext is notified about state of a stream. // A stream is started every time a request is written to the server // and is finished when a response is received. // It is abandoned when the underlying network connection is closed // before receiving a response. type StreamObserverContext interface { // StreamStarted is called when the stream is started. // This happens just before a request is written to the wire. StreamStarted(observedStream ObservedStream) // StreamAbandoned is called when we stop waiting for response. // This happens when the underlying network connection is closed. // StreamFinished won't be called if StreamAbandoned is. StreamAbandoned(observedStream ObservedStream) // StreamFinished is called when we receive a response for the stream. 
StreamFinished(observedStream ObservedStream) } type preparedStatment struct { id []byte resultMetadataID []byte request preparedMetadata response resultMetadata } type inflightPrepare struct { done chan struct{} err error preparedStatment *preparedStatment } func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), } lru.Add(stmtCacheKey, flight) return flight }) if !ok { go func() { defer close(flight.done) prep := &writePrepareFrame{ statement: stmt, } if c.version > protoVersion4 { prep.keyspace = keyspace } // we won the race to do the load, if our context is canceled we shouldnt // stop the load as other callers are waiting for it but this caller should get // their context cancelled error. framer, err := c.exec(c.ctx, prep, tracer) if err != nil { flight.err = err c.session.stmtsLRU.remove(stmtCacheKey) return } frame, err := framer.parseFrame() if err != nil { flight.err = err c.session.stmtsLRU.remove(stmtCacheKey) return } // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated // everytime we need to parse a frame. if len(framer.traceID) > 0 && tracer != nil { tracer.Trace(framer.traceID) } switch x := frame.(type) { case *resultPreparedFrame: flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. id: copyBytes(x.preparedID), resultMetadataID: copyBytes(x.resultMetadataID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. 
request: x.reqMeta, response: x.respMeta, } case error: flight.err = x default: flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x) } if flight.err != nil { c.session.stmtsLRU.remove(stmtCacheKey) } }() } select { case <-ctx.Done(): return nil, ctx.Err() case <-flight.done: return flight.preparedStatment, flight.err } } func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name value = named.value } if _, ok := value.(unsetColumn); !ok { val, err := Marshal(typ, value) if err != nil { return err } dst.value = val } else { dst.isUnset = true } return nil } func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params := queryParams{ consistency: qry.cons, } // frame checks that it is not 0 params.serialConsistency = qry.serialCons params.defaultTimestamp = qry.defaultTimestamp params.defaultTimestampValue = qry.defaultTimestampValue if len(qry.pageState) > 0 { params.pagingState = qry.pageState } if qry.pageSize > 0 { params.pageSize = qry.pageSize } if c.version > protoVersion4 { params.keyspace = qry.keyspace params.nowInSeconds = qry.nowInSecondsValue } // If a keyspace for the qry is overriden, // then we should use it to create stmt cache key usedKeyspace := c.currentKeyspace if qry.keyspace != "" { usedKeyspace = qry.keyspace } var ( frame frameBuilder info *preparedStatment ) if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. 
var err error info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) if err != nil { return &Iter{err: err} } values := qry.values if qry.binding != nil { values, err = qry.binding(&QueryInfo{ Id: info.id, Args: info.request.columns, Rval: info.response.columns, PKeyColumns: info.request.pkeyColumns, }) if err != nil { return &Iter{err: err} } } if len(values) != info.request.actualColCount { return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} } params.values = make([]queryValues, len(values)) for i := 0; i < len(values); i++ { v := &params.values[i] value := values[i] typ := info.request.columns[i].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { return &Iter{err: err} } } // if the metadata was not present in the response then we should not skip it params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) && info != nil && info.response.flags&flagNoMetaData == 0 frame = &writeExecuteFrame{ preparedID: info.id, params: params, customPayload: qry.customPayload, resultMetadataID: info.resultMetadataID, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata qry.routingInfo.mu.Lock() qry.routingInfo.keyspace = info.request.keyspace if info.request.keyspace == "" { qry.routingInfo.keyspace = usedKeyspace } qry.routingInfo.table = info.request.table qry.routingInfo.mu.Unlock() } else { frame = &writeQueryFrame{ statement: qry.stmt, params: params, customPayload: qry.customPayload, } } framer, err := c.exec(ctx, frame, qry.trace) if err != nil { return &Iter{err: err} } resp, err := framer.parseFrame() if err != nil { return &Iter{err: err} } if len(framer.traceID) > 0 && qry.trace != nil { qry.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: return &Iter{framer: framer} case *resultRowsFrame: if x.meta.newMetadataID != nil { // If a RESULT/Rows message reports // changed resultset metadata 
with the Metadata_changed flag, the reported new // resultset metadata must be used in subsequent executions stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) if ok { newInflight := &inflightPrepare{ done: make(chan struct{}), preparedStatment: &preparedStatment{ id: oldInflight.preparedStatment.id, resultMetadataID: x.meta.newMetadataID, request: oldInflight.preparedStatment.request, response: x.meta, }, } // The driver should close this done to avoid deadlocks of // other subsequent requests close(newInflight.done) c.session.stmtsLRU.add(stmtCacheKey, newInflight) // Updating info to ensure the code is looking at the updated // version of the prepared statement info = newInflight.preparedStatment } } iter := &Iter{ meta: x.meta, framer: framer, numRows: x.numRows, } if x.meta.noMetaData() { if info != nil { iter.meta = info.response iter.meta.pagingState = copyBytes(x.meta.pagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } } else { iter.meta = x.meta } if x.meta.morePages() && !qry.disableAutoPage { newQry := new(Query) *newQry = *qry newQry.pageState = copyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ qry: newQry, pos: int((1 - qry.prefetch) * float64(x.numRows)), } if iter.next.pos < 1 { iter.next.pos = 1 } } return iter case *resultKeyspaceFrame: return &Iter{framer: framer} case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: iter := &Iter{framer: framer} if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag c.logger.Println(err) } // dont return an error from this, might be a good idea to give a warning // though. The impact of this returning an error would be that the cluster // is not consistent with regards to its schema. 
return iter case *RequestErrUnprepared: stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: return &Iter{err: x, framer: framer} default: return &Iter{ err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), framer: framer, } } } func (c *Conn) Pick(qry *Query) *Conn { if c.Closed() { return nil } return c } func (c *Conn) Closed() bool { c.mu.Lock() defer c.mu.Unlock() return c.closed } func (c *Conn) Address() string { return c.addr } func (c *Conn) AvailableStreams() int { return c.streams.Available() } func (c *Conn) UseKeyspace(keyspace string) error { q := &writeQueryFrame{statement: `USE "` + keyspace + `"`} q.params.consistency = c.session.cons framer, err := c.exec(c.ctx, q, nil) if err != nil { return err } resp, err := framer.parseFrame() if err != nil { return err } switch x := resp.(type) { case *resultKeyspaceFrame: case error: return x default: return NewErrProtocol("unknown frame in response to USE: %v", x) } c.currentKeyspace = keyspace return nil } func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { if c.version == protoVersion1 { return &Iter{err: ErrUnsupported} } n := len(batch.Entries) req := &writeBatchFrame{ typ: batch.Type, statements: make([]batchStatment, n), consistency: batch.Cons, serialConsistency: batch.serialCons, defaultTimestamp: batch.defaultTimestamp, defaultTimestampValue: batch.defaultTimestampValue, customPayload: batch.CustomPayload, } if c.version > protoVersion4 { req.keyspace = batch.keyspace req.nowInSeconds = batch.nowInSeconds } usedKeyspace := c.currentKeyspace if batch.keyspace != "" { usedKeyspace = batch.keyspace } stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { entry := &batch.Entries[i] b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { info, err := 
c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) if err != nil { return &Iter{err: err} } var values []interface{} if entry.binding == nil { values = entry.Args } else { values, err = entry.binding(&QueryInfo{ Id: info.id, Args: info.request.columns, Rval: info.response.columns, PKeyColumns: info.request.pkeyColumns, }) if err != nil { return &Iter{err: err} } } if len(values) != info.request.actualColCount { return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} } b.preparedID = info.id stmts[string(info.id)] = entry.Stmt b.values = make([]queryValues, info.request.actualColCount) for j := 0; j < info.request.actualColCount; j++ { v := &b.values[j] value := values[j] typ := info.request.columns[j].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { return &Iter{err: err} } } } else { b.statement = entry.Stmt } } framer, err := c.exec(batch.Context(), req, batch.trace) if err != nil { return &Iter{err: err} } resp, err := framer.parseFrame() if err != nil { return &Iter{err: err, framer: framer} } if len(framer.traceID) > 0 && batch.trace != nil { batch.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: return &Iter{} case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) case *resultRowsFrame: iter := &Iter{ meta: x.meta, framer: framer, numRows: x.numRows, } return iter case error: return &Iter{err: x, framer: framer} default: return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} } } func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).Trace(nil) q.skipPrepare = true 
q.disableSkipMetadata = true // we want to keep the query on this connection q.conn = c return c.executeQuery(ctx, q) } func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter { const ( peerSchema = "SELECT * FROM system.peers" peerV2Schemas = "SELECT * FROM system.peers_v2" ) c.mu.Lock() isSchemaV2 := c.isSchemaV2 c.mu.Unlock() if version.AtLeast(4, 0, 0) && isSchemaV2 { // Try "system.peers_v2" and fallback to "system.peers" if it's not found iter := c.query(ctx, peerV2Schemas) err := iter.checkErrAndNotFound() if err != nil { if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers c.mu.Lock() c.isSchemaV2 = false c.mu.Unlock() return c.query(ctx, peerSchema) } else { return iter } } return iter } else { return c.query(ctx, peerSchema) } } func (c *Conn) querySystemLocal(ctx context.Context) *Iter { return c.query(ctx, "SELECT * FROM system.local WHERE key='local'") } func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { const localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" var versions map[string]struct{} var schemaVersion string endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) for time.Now().Before(endDeadline) { iter := c.querySystemPeers(ctx, c.host.version) versions = make(map[string]struct{}) rows, err := iter.SliceMap() if err != nil { goto cont } for _, row := range rows { h, err := NewHostInfo(c.host.ConnectAddress(), c.session.cfg.Port) if err != nil { goto cont } host, err := c.session.hostInfoFromMap(row, h) if err != nil { goto cont } if !isValidPeer(host) || host.schemaVersion == "" { c.logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) continue } versions[host.schemaVersion] = struct{}{} } if err = iter.Close(); err != nil { goto cont } iter = c.query(ctx, localSchemas) for iter.Scan(&schemaVersion) { versions[schemaVersion] = struct{}{} schemaVersion = "" } if err = 
iter.Close(); err != nil { goto cont } if len(versions) <= 1 { return nil } cont: select { case <-ctx.Done(): return ctx.Err() case <-time.After(200 * time.Millisecond): } } if err != nil { return err } schemas := make([]string, 0, len(versions)) for schema := range versions { schemas = append(schemas, schema) } // not exported return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) } var ( ErrQueryArgLength = errors.New("gocql: query argument length mismatch") ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection") ErrConnectionClosed = errors.New("gocql: connection closed waiting for response") ErrNoStreams = errors.New("gocql: no streams available on connection") )