expressions_api/expression_verify.go (215 lines of code) (raw):
package expressions_api
import (
"context"
"encoding/json"
"fmt"
"sync"
spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/task"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
)
const THREAD_POOL = 500
type ExpressionVerificationAccessor interface {
//Batch API which parallelizes expression verification calls
VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput
RefreshSpannerClient(ctx context.Context, project string, instance string) error
}
type ExpressionVerificationAccessorImpl struct {
SpannerAccessor *spanneraccessor.SpannerAccessorImpl
}
func NewExpressionVerificationAccessorImpl(ctx context.Context, project string, instance string) (*ExpressionVerificationAccessorImpl, error) {
var spannerAccessor *spanneraccessor.SpannerAccessorImpl
var err error
if project == "" || instance == "" {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImpl(ctx)
if err != nil {
return nil, err
}
} else {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf(constants.DB_URI, project, instance, constants.TEMP_DB))
if err != nil {
return nil, err
}
}
return &ExpressionVerificationAccessorImpl{
SpannerAccessor: spannerAccessor,
}, nil
}
// APIs to verify and process Spanner DLL features such as Default Values, Check Constraints
type DDLVerifier interface {
VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error)
GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
RefreshSpannerClient(ctx context.Context, project string, instance string) error
}
type DDLVerifierImpl struct {
Expressions ExpressionVerificationAccessor
}
func NewDDLVerifierImpl(ctx context.Context, project string, instance string) (*DDLVerifierImpl, error) {
expVerifier, err := NewExpressionVerificationAccessorImpl(ctx, project, instance)
return &DDLVerifierImpl{
Expressions: expVerifier,
}, err
}
func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
err := ev.validateRequest(verifyExpressionsInput)
if err != nil {
return internal.VerifyExpressionsOutput{Err: err}
}
dbURI := ev.SpannerAccessor.SpannerClient.DatabaseName()
dbExists, err := ev.SpannerAccessor.CheckExistingDb(ctx, dbURI)
if err != nil {
return internal.VerifyExpressionsOutput{Err: err}
}
if dbExists {
err := ev.SpannerAccessor.DropDatabase(ctx, dbURI)
if err != nil {
return internal.VerifyExpressionsOutput{Err: err}
}
}
verifyExpressionsInput.Conv, err = ev.removeExpressions(verifyExpressionsInput.Conv)
if err != nil {
return internal.VerifyExpressionsOutput{Err: err}
}
err = ev.SpannerAccessor.CreateDatabase(ctx, dbURI, verifyExpressionsInput.Conv, verifyExpressionsInput.Source, constants.DATAFLOW_MIGRATION)
if err != nil {
return internal.VerifyExpressionsOutput{Err: err}
}
//Drop the staging database after verifications are completed.
defer ev.SpannerAccessor.DropDatabase(ctx, dbURI)
//This recreates a spanner client for the staging database before doing operations on it.
ev.SpannerAccessor.Refresh(ctx, dbURI)
r := task.RunParallelTasksImpl[internal.ExpressionDetail, internal.ExpressionVerificationOutput]{}
expressionVerificationOutputList, _ := r.RunParallelTasks(verifyExpressionsInput.ExpressionDetailList, THREAD_POOL, ev.verifyExpressionInternal, true)
var verifyExpressionsOutput internal.VerifyExpressionsOutput
var errorCount int16 = 0
for _, expressionVerificationOutput := range expressionVerificationOutputList {
verifyExpressionsOutput.ExpressionVerificationOutputList = append(verifyExpressionsOutput.ExpressionVerificationOutputList, expressionVerificationOutput.Result)
if expressionVerificationOutput.Result.Err != nil {
errorCount++
}
}
if errorCount != 0 {
verifyExpressionsOutput.Err = fmt.Errorf("%d expressions either failed verification or did not get verified. Please look at the individual errors returned for each expression", errorCount)
}
return verifyExpressionsOutput
}
func (ev *ExpressionVerificationAccessorImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
spannerClient, err := spannerclient.NewSpannerClientImpl(ctx, fmt.Sprintf(constants.DB_URI, project, instance, constants.TEMP_DB))
if err != nil {
return err
}
ev.SpannerAccessor.SpannerClient = spannerClient
return nil
}
func (ev *ExpressionVerificationAccessorImpl) verifyExpressionInternal(expressionDetail internal.ExpressionDetail, mutex *sync.Mutex) task.TaskResult[internal.ExpressionVerificationOutput] {
var sqlStatement string
switch expressionDetail.Type {
case constants.CHECK_EXPRESSION:
sqlStatement = fmt.Sprintf("SELECT 1 from %s where %s;", expressionDetail.ReferenceElement.Name, expressionDetail.Expression)
case constants.DEFAULT_EXPRESSION:
sqlStatement = fmt.Sprintf("SELECT CAST(%s as %s)", expressionDetail.Expression, expressionDetail.ReferenceElement.Name)
default:
return task.TaskResult[internal.ExpressionVerificationOutput]{Result: internal.ExpressionVerificationOutput{Result: false, Err: fmt.Errorf("invalid expression type requested")}, Err: nil}
}
result, err := ev.SpannerAccessor.ValidateDML(context.Background(), sqlStatement)
return task.TaskResult[internal.ExpressionVerificationOutput]{Result: internal.ExpressionVerificationOutput{Result: result, Err: err, ExpressionDetail: expressionDetail}, Err: nil}
}
func (ev *ExpressionVerificationAccessorImpl) validateRequest(verifyExpressionsInput internal.VerifyExpressionsInput) error {
if verifyExpressionsInput.Conv == nil || verifyExpressionsInput.Source == "" {
return fmt.Errorf("one of conv or source is empty. These are mandatory fields = %v", verifyExpressionsInput)
}
for _, expressionDetail := range verifyExpressionsInput.ExpressionDetailList {
if expressionDetail.ExpressionId == "" || expressionDetail.Expression == "" || expressionDetail.Type == "" || expressionDetail.ReferenceElement.Name == "" {
return fmt.Errorf("one of expressionId, expression, type or referenceElement.Name is empty. These are mandatory fields = %v", expressionDetail)
}
}
return nil
}
// We simplify conv to remove any existing expressions that are part of the SpSchema to ensure that the stagingDB creation
// does not fail due to inconsistent, user configured expressions during a schema conversion session.
// The minimal conv object needed for stagingDB is one which contains all table and column definitions only.
func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *internal.Conv) (*internal.Conv, error) {
convCopy := &internal.Conv{}
convJSON, err := json.Marshal(inputConv)
if err != nil {
return nil, fmt.Errorf("error marshaling conv: %v", err)
}
err = json.Unmarshal(convJSON, convCopy)
if err != nil {
return nil, fmt.Errorf("error unmarshaling conv: %v", err)
}
//Set sequences as nil
//TODO: Implement similar checks for DEFAULT and CHECK constraints as well
convCopy.SpSequences = nil
for _, table := range convCopy.SpSchema {
table.CheckConstraints = []ddl.CheckConstraint{}
convCopy.SpSchema[table.Id] = table
for colName, colDef := range table.ColDefs {
colDef.AutoGen = ddl.AutoGenCol{}
colDef.DefaultValue = ddl.DefaultValue{}
table.ColDefs[colName] = colDef
}
}
return convCopy, nil
}
func (ddlv *DDLVerifierImpl) VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) {
ctx := context.Background()
verifyExpressionsInput := internal.VerifyExpressionsInput{
Conv: conv,
Source: conv.Source,
ExpressionDetailList: expressionDetails,
}
ddlv.RefreshSpannerClient(ctx, conv.SpProjectId, conv.SpInstanceId)
verificationResults := ddlv.Expressions.VerifyExpressions(ctx, verifyExpressionsInput)
return verificationResults, verificationResults.Err
}
func (ddlv *DDLVerifierImpl) GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
expressionDetails := []internal.ExpressionDetail{}
// Collect default values for verification
for _, tableId := range tableIds {
srcTable := conv.SrcSchema[tableId]
for _, srcColId := range srcTable.ColIds {
srcCol := srcTable.ColDefs[srcColId]
if srcCol.DefaultValue.IsPresent {
tyName := conv.SpSchema[tableId].ColDefs[srcColId].T.Name
if conv.SpDialect == constants.DIALECT_POSTGRESQL {
tyName = ddl.GetPGType(conv.SpSchema[tableId].ColDefs[srcColId].T)
}
defaultValueExp := internal.ExpressionDetail{
ReferenceElement: internal.ReferenceElement{
Name: tyName,
},
ExpressionId: srcCol.DefaultValue.Value.ExpressionId,
Expression: srcCol.DefaultValue.Value.Statement,
Type: constants.DEFAULT_EXPRESSION,
Metadata: map[string]string{"TableId": tableId, "ColId": srcColId},
}
expressionDetails = append(expressionDetails, defaultValueExp)
}
}
}
return expressionDetails
}
func (ddlv *DDLVerifierImpl) GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
expressionDetails := []internal.ExpressionDetail{}
// Collect default values for verification
for _, tableId := range tableIds {
spTable := conv.SpSchema[tableId]
for _, spColId := range spTable.ColIds {
spCol := spTable.ColDefs[spColId]
if spCol.DefaultValue.IsPresent {
tyName := conv.SpSchema[tableId].ColDefs[spColId].T.Name
if conv.SpDialect == constants.DIALECT_POSTGRESQL {
tyName = ddl.GetPGType(conv.SpSchema[tableId].ColDefs[spColId].T)
}
defaultValueExp := internal.ExpressionDetail{
ReferenceElement: internal.ReferenceElement{
Name: tyName,
},
ExpressionId: spCol.DefaultValue.Value.ExpressionId,
Expression: spCol.DefaultValue.Value.Statement,
Type: constants.DEFAULT_EXPRESSION,
Metadata: map[string]string{"TableId": tableId, "ColId": spColId},
}
expressionDetails = append(expressionDetails, defaultValueExp)
}
}
}
return expressionDetails
}
func (ddlv *DDLVerifierImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
return ddlv.Expressions.RefreshSpannerClient(ctx, project, instance)
}