dax/internal/client/single.go (516 lines of code) (raw):

/* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at http://www.apache.org/licenses/LICENSE-2.0 or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package client import ( "bytes" "context" "errors" "time" "github.com/aws/aws-dax-go-v2/dax/internal/cbor" "github.com/aws/aws-dax-go-v2/dax/internal/lru" "github.com/aws/aws-dax-go-v2/dax/utils" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/aws/smithy-go" "github.com/aws/smithy-go/logging" ) const ( userAgent = "DaxGoClient-1.0.0" daxAddress = "https://dax.amazonaws.com" authTtlSecs = 5 * 60 tubeAuthWindowScalar = 0.75 emptyAttributeListId = 1 ) const ( serviceName = "dax" opDefineAttributeList = "DefineAttributeList" opDefineAttributeListId = "DefineAttributeListId" opDefineKeySchema = "DefineKeySchema" opEndpoints = "Endpoints" OpGetItem = "GetItem" OpPutItem = "PutItem" OpUpdateItem = "UpdateItem" OpDeleteItem = "DeleteItem" OpBatchGetItem = "BatchGetItem" OpBatchWriteItem = "BatchWriteItem" OpTransactGetItems = "TransactGetItems" OpTransactWriteItems = "TransactWriteItems" OpQuery = "Query" OpScan = "Scan" ) const ( keySchemaLruCacheSize = 100 attributeListLruCacheSize = 1000 ) type SingleDaxClient struct { region string credentials aws.CredentialsProvider tubeAuthWindowSecs int64 executor *taskExecutor pool *tubePool keySchema *lru.Lru attrNamesListToId *lru.Lru attrListIdToNames *lru.Lru healthStatus HealthStatus } func NewSingleClient(endpoint string, connConfigData connConfig, region string, credentials aws.CredentialsProvider, routeListener RouteListener) (*SingleDaxClient, error) { return newSingleClientWithOptions(endpoint, connConfigData, region, credentials, -1, defaultDialer.DialContext, routeListener) } func newSingleClientWithOptions(endpoint string, connConfigData connConfig, region string, credentials aws.CredentialsProvider, maxPendingConnections int, dialContextFn dialContext, routeListener RouteListener) (*SingleDaxClient, error) { po := defaultTubePoolOptions if maxPendingConnections > 0 { po.maxConcurrentConnAttempts = maxPendingConnections } po.dialContext = dialContextFn client := &SingleDaxClient{ region: region, credentials: credentials, tubeAuthWindowSecs: authTtlSecs * tubeAuthWindowScalar, pool: newTubePoolWithOptions(endpoint, po, connConfigData), executor: newExecutor(), healthStatus: newHealthStatus(endpoint, routeListener), } client.keySchema = &lru.Lru{ MaxEntries: keySchemaLruCacheSize, LoadFunc: func(ctx context.Context, key lru.Key) (interface{}, error) { table, ok := key.(string) if !ok { return nil, &smithy.SerializationError{Err: errors.New("unexpected type for table name")} } if ctx == nil { ctx = context.Background() } return client.defineKeySchema(ctx, table) }, } client.attrNamesListToId = &lru.Lru{ MaxEntries: attributeListLruCacheSize, LoadFunc: func(ctx context.Context, key lru.Key) (interface{}, error) { attrNames, ok := key.([]string) if !ok { return nil, &smithy.SerializationError{Err: errors.New("unexpected type for attribute list")} } if ctx == nil { ctx = context.Background() } return client.defineAttributeListId(ctx, attrNames) }, KeyMarshaller: func(key lru.Key) lru.Key { var buf bytes.Buffer w := cbor.NewWriter(&buf) defer w.Close() for _, v := range key.([]string) { w.WriteString(v) } w.Flush() return string(buf.Bytes()) }, } client.attrListIdToNames = &lru.Lru{ MaxEntries: attributeListLruCacheSize, LoadFunc: func(ctx context.Context, key lru.Key) (interface{}, error) { id, ok := key.(int64) if !ok { return nil, &smithy.SerializationError{Err: errors.New("unexpected type for attribute list id")} } if ctx == nil { ctx = context.Background() } return client.defineAttributeList(ctx, id) }, } return client, nil } func (client *SingleDaxClient) Close() error { client.executor.stopAll() if client.pool != nil { return client.pool.Close() } return nil } func (client *SingleDaxClient) startHealthChecks(cc *cluster, host hostPort) { cc.debugLog("Starting health checks for :: " + host.host) client.executor.start(cc.config.ClientHealthCheckInterval, func() error { ctx, cfn := context.WithTimeout(context.Background(), 1*time.Second) defer cfn() var err error opts := RequestOptions{} opts.RetryMaxAttempts = 3 _, err = client.endpoints(ctx, opts) if err != nil { cc.debugLog("Health checks failed with error " + err.Error() + " for host :: " + host.host) cc.onHealthCheckFailed(host) } else { client.healthStatus.onHealthCheckSuccess(client) cc.debugLog("Health checks succeeded for host:: " + host.host) } return nil }) } func (client *SingleDaxClient) endpoints(ctx context.Context, opt RequestOptions) ([]serviceEndpoint, error) { encoder := func(writer *cbor.Writer) error { return encodeEndpointsInput(writer) } var out []serviceEndpoint var err error decoder := func(reader *cbor.Reader) error { out, err = decodeEndpointsOutput(reader) return err } if err = client.executeWithRetries(ctx, opEndpoints, opt, encoder, decoder); err != nil { return nil, err } return out, nil } func (client *SingleDaxClient) defineAttributeListId(ctx context.Context, attrNames []string) (int64, error) { if len(attrNames) == 0 { return emptyAttributeListId, nil } encoder := func(writer *cbor.Writer) error { return encodeDefineAttributeListIdInput(attrNames, writer) } var out int64 var err error decoder := func(reader *cbor.Reader) error { out, err = decodeDefineAttributeListIdOutput(reader) return err } opt := RequestOptions{} if err = client.executeWithRetries(ctx, opDefineAttributeListId, opt, encoder, decoder); err != nil { return 0, err } return out, nil } func (client *SingleDaxClient) defineAttributeList(ctx context.Context, id int64) ([]string, error) { if id == emptyAttributeListId { return []string{}, nil } encoder := func(writer *cbor.Writer) error { return encodeDefineAttributeListInput(id, writer) } var out []string var err error decoder := func(reader *cbor.Reader) error { out, err = decodeDefineAttributeListOutput(reader) return err } opt := RequestOptions{} if err = client.executeWithRetries(ctx, opDefineAttributeList, opt, encoder, decoder); err != nil { return nil, err } return out, nil } func (client *SingleDaxClient) defineKeySchema(ctx context.Context, table string) ([]types.AttributeDefinition, error) { encoder := func(writer *cbor.Writer) error { return encodeDefineKeySchemaInput(table, writer) } var out []types.AttributeDefinition var err error decoder := func(reader *cbor.Reader) error { out, err = decodeDefineKeySchemaOutput(reader) return err } opt := RequestOptions{} if err = client.executeWithRetries(ctx, opDefineKeySchema, opt, encoder, decoder); err != nil { return nil, err } return out, nil } func (client *SingleDaxClient) PutItemWithOptions(ctx context.Context, input *dynamodb.PutItemInput, output *dynamodb.PutItemOutput, opt RequestOptions) (*dynamodb.PutItemOutput, error) { encoder := func(writer *cbor.Writer) error { return encodePutItemInput(ctx, input, client.keySchema, client.attrNamesListToId, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodePutItemOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpPutItem, opt, encoder, decoder); err != nil { return output, err } return output, nil } func (client *SingleDaxClient) DeleteItemWithOptions(ctx context.Context, input *dynamodb.DeleteItemInput, output *dynamodb.DeleteItemOutput, opt RequestOptions) (*dynamodb.DeleteItemOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeDeleteItemInput(ctx, input, client.keySchema, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeDeleteItemOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpDeleteItem, opt, encoder, decoder); err != nil { return output, err } return output, nil } func (client *SingleDaxClient) UpdateItemWithOptions(ctx context.Context, input *dynamodb.UpdateItemInput, output *dynamodb.UpdateItemOutput, opt RequestOptions) (*dynamodb.UpdateItemOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeUpdateItemInput(ctx, input, client.keySchema, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeUpdateItemOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpUpdateItem, opt, encoder, decoder); err != nil { return output, err } return output, nil } func (client *SingleDaxClient) GetItemWithOptions(ctx context.Context, input *dynamodb.GetItemInput, output *dynamodb.GetItemOutput, opt RequestOptions) (*dynamodb.GetItemOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeGetItemInput(ctx, input, client.keySchema, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeGetItemOutput(ctx, reader, input, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpGetItem, opt, encoder, decoder); err != nil { client.healthStatus.onErrorInReadRequest(err, client) return output, err } client.healthStatus.onSuccessInReadRequest() return output, nil } func (client *SingleDaxClient) ScanWithOptions(ctx context.Context, input *dynamodb.ScanInput, output *dynamodb.ScanOutput, opt RequestOptions) (*dynamodb.ScanOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeScanInput(ctx, input, client.keySchema, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeScanOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpScan, opt, encoder, decoder); err != nil { client.healthStatus.onErrorInReadRequest(err, client) return output, err } client.healthStatus.onSuccessInReadRequest() return output, nil } func (client *SingleDaxClient) QueryWithOptions(ctx context.Context, input *dynamodb.QueryInput, output *dynamodb.QueryOutput, opt RequestOptions) (*dynamodb.QueryOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeQueryInput(ctx, input, client.keySchema, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeQueryOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpQuery, opt, encoder, decoder); err != nil { client.healthStatus.onErrorInReadRequest(err, client) return output, err } client.healthStatus.onSuccessInReadRequest() return output, nil } func (client *SingleDaxClient) BatchWriteItemWithOptions(ctx context.Context, input *dynamodb.BatchWriteItemInput, output *dynamodb.BatchWriteItemOutput, opt RequestOptions) (*dynamodb.BatchWriteItemOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeBatchWriteItemInput(ctx, input, client.keySchema, client.attrNamesListToId, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeBatchWriteItemOutput(ctx, reader, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpBatchWriteItem, opt, encoder, decoder); err != nil { return output, err } return output, nil } func (client *SingleDaxClient) BatchGetItemWithOptions(ctx context.Context, input *dynamodb.BatchGetItemInput, output *dynamodb.BatchGetItemOutput, opt RequestOptions) (*dynamodb.BatchGetItemOutput, error) { encoder := func(writer *cbor.Writer) error { return encodeBatchGetItemInput(ctx, input, client.keySchema, writer) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeBatchGetItemOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpBatchGetItem, opt, encoder, decoder); err != nil { client.healthStatus.onErrorInReadRequest(err, client) return output, err } client.healthStatus.onSuccessInReadRequest() return output, nil } func (client *SingleDaxClient) TransactWriteItemsWithOptions(ctx context.Context, input *dynamodb.TransactWriteItemsInput, output *dynamodb.TransactWriteItemsOutput, opt RequestOptions) (*dynamodb.TransactWriteItemsOutput, error) { extractedKeys := make([]map[string]types.AttributeValue, len(input.TransactItems)) encoder := func(writer *cbor.Writer) error { return encodeTransactWriteItemsInput(ctx, input, client.keySchema, client.attrNamesListToId, writer, extractedKeys) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeTransactWriteItemsOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpBatchWriteItem, opt, encoder, decoder); err != nil { if failure, ok := err.(*daxTransactionCanceledFailure); ok { var cancellationReasons []types.CancellationReason if cancellationReasons, err = decodeTransactionCancellationReasons(ctx, failure, extractedKeys, client.attrListIdToNames); err != nil { return output, err } failure.cancellationReasons = cancellationReasons return output, failure } return output, err } return output, nil } func (client *SingleDaxClient) TransactGetItemsWithOptions(ctx context.Context, input *dynamodb.TransactGetItemsInput, output *dynamodb.TransactGetItemsOutput, opt RequestOptions) (*dynamodb.TransactGetItemsOutput, error) { extractedKeys := make([]map[string]types.AttributeValue, len(input.TransactItems)) encoder := func(writer *cbor.Writer) error { return encodeTransactGetItemsInput(ctx, input, client.keySchema, writer, extractedKeys) } var err error decoder := func(reader *cbor.Reader) error { output, err = decodeTransactGetItemsOutput(ctx, reader, input, client.keySchema, client.attrListIdToNames, output) return err } if err = client.executeWithRetries(ctx, OpBatchWriteItem, opt, encoder, decoder); err != nil { if failure, ok := err.(*daxTransactionCanceledFailure); ok { var cancellationReasons []types.CancellationReason if cancellationReasons, err = decodeTransactionCancellationReasons(ctx, failure, extractedKeys, client.attrListIdToNames); err != nil { return output, err } failure.cancellationReasons = cancellationReasons return output, failure } return output, err } return output, nil } func (client *SingleDaxClient) newContext(ctx context.Context, o RequestOptions) context.Context { if o.Context != nil { return o.Context } if ctx != nil { return ctx } return context.Background() } func (client *SingleDaxClient) executeWithRetries(ctx context.Context, op string, o RequestOptions, encoder func(writer *cbor.Writer) error, decoder func(reader *cbor.Reader) error) error { ctx = client.newContext(ctx, o) var err error attempts := o.RetryMaxAttempts // Start from 0 to accommodate for the initial request for i := 0; i <= attempts; i++ { if i > 0 && o.Logger != nil && o.LogLevel.Matches(utils.LogDebugWithRequestRetries) { o.Logger.Logf(logging.Debug, "Retrying Request %s/%s, attempt %d", service, op, i) } err = client.executeWithContext(ctx, op, encoder, decoder, o) if err == nil { return nil } if errors.Is(err, context.Canceled) { return &smithy.CanceledError{Err: err} } if i != attempts { delay := o.RetryDelay if sleepErr := SleepWithContext(ctx, op, delay); sleepErr != nil { return &smithy.OperationError{Err: sleepErr, ServiceID: service, OperationName: op} } if o.Logger != nil && o.LogLevel.Matches(utils.LogDebugWithRequestRetries) { o.Logger.Logf(logging.Debug, "Error in executing %s%s : %s", service, op, err) } } } // Return the last error occurred return translateError(err) } func (client *SingleDaxClient) executeWithContext(ctx context.Context, op string, encoder func(writer *cbor.Writer) error, decoder func(reader *cbor.Reader) error, opt RequestOptions) error { t, err := client.pool.getWithContext(ctx, client.isHighPriority(op), opt) if err != nil { return err } if err = client.pool.setDeadline(ctx, t); err != nil { // If the error is just due to context cancelled or timeout // then the tube is still usable because we have not written anything to tube if err == ctx.Err() { client.pool.put(t) return err } // If we get error while setting deadline of tube // probably something is wrong with the tube client.pool.closeTube(t) return err } if err = client.auth(ctx, t); err != nil { // Auth method writes in the tube and // it is not guaranteed that it will be drained completely on error client.pool.closeTube(t) return err } writer := t.CborWriter() if err = encoder(writer); err != nil { // Validation errors will cause connection to be closed as there is no guarantee // that the validation was performed before any data was written into tube client.pool.closeTube(t) return err } if err := writer.Flush(); err != nil { client.pool.closeTube(t) return err } reader := t.CborReader() ex, err := decodeError(reader) if err != nil { // decode or network error - doesn't guarantee completely drained tube client.pool.closeTube(t) return err } if ex != nil { // user or server error client.recycleTube(t, ex) return ex } err = decoder(reader) if err != nil { // we are not able to completely drain tube client.pool.closeTube(t) } else { client.pool.put(t) } return err } func (client *SingleDaxClient) isHighPriority(op string) bool { switch op { case opDefineAttributeListId, opDefineAttributeList, opDefineKeySchema: return true default: return false } } func (client *SingleDaxClient) recycleTube(t tube, err error) { if t == nil { return } var recycle bool if err == nil { recycle = true } else { // IO streams are guaranteed to be completely drained only on daxRequestException d, ok := err.(*daxRequestFailure) recycle = ok if ok && d.authError() { t.SetAuthExpiryUnix(time.Now().Unix()) } } if recycle { client.pool.put(t) } else { client.pool.closeTube(t) } } func (client *SingleDaxClient) auth(ctx context.Context, t tube) error { // TODO credentials.Get() cause a throughput drop of ~25 with 250 goroutines with DefaultCredentialChain (only instance profile credentials available) creds, err := client.credentials.Retrieve(ctx) if err != nil { return err } now := time.Now().UTC() if t.CompareAndSwapAuthID(creds.AccessKeyID) || t.AuthExpiryUnix() <= now.Unix() { stringToSign, signature := generateSigV4WithTime(creds, daxAddress, client.region, "", now) writer := t.CborWriter() if err := encodeAuthInput(creds.AccessKeyID, creds.SessionToken, stringToSign, signature, userAgent, writer); err != nil { return err } if err := writer.Flush(); err != nil { return err } t.SetAuthExpiryUnix(now.Unix() + client.tubeAuthWindowSecs) } return nil } func (client *SingleDaxClient) reapIdleConnections() { client.pool.reapIdleConnections() } type HealthCheckDaxAPI interface { startHealthChecks(cc *cluster, host hostPort) }