common/persistence/sql/sql_testing_util.go (161 lines of code) (raw):
// Copyright (c) 2017 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package sql
import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"strings"
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/dynamicconfig"
"github.com/uber/cadence/common/persistence/persistence-tests/testcluster"
"github.com/uber/cadence/environment"
)
// testCluster allows executing SQL database operations in testing.
// It bundles the test database name, the schema directory and the SQL
// connection config assembled by NewTestCluster.
type testCluster struct {
// dbName is the name of the test database; it is what DatabaseName returns.
dbName string
// schemaDir is the directory from which SetupTestDatabase loads schema files.
schemaDir string
// cfg is the SQL connection config used to open admin connections.
cfg config.SQL
}
// Compile-time check that *testCluster satisfies testcluster.PersistenceTestCluster.
var _ testcluster.PersistenceTestCluster = (*testCluster)(nil)
// NewTestCluster builds a SQL-backed persistence test cluster.
//
// If port is 0, the port is taken from environment.GetMySQLPort and any error
// from that lookup is returned unchanged. An empty schemaDir is rejected with
// an error. The resulting connection config uses the "tcp" protocol, 4 shards
// and the "thriftrw" encoding for both encoding and decoding.
func NewTestCluster(pluginName, dbName, username, password, host string, port int, schemaDir string) (testcluster.PersistenceTestCluster, error) {
	// Resolve the default port first; this check precedes schemaDir validation.
	if port == 0 {
		envPort, err := environment.GetMySQLPort()
		if err != nil {
			return nil, err
		}
		port = envPort
	}

	if schemaDir == "" {
		return nil, errors.New("schemaDir is empty")
	}

	cluster := &testCluster{
		dbName:    dbName,
		schemaDir: schemaDir,
		cfg: config.SQL{
			PluginName:      pluginName,
			DatabaseName:    dbName,
			User:            username,
			Password:        password,
			ConnectAddr:     fmt.Sprintf("%v:%v", host, port),
			ConnectProtocol: "tcp",
			NumShards:       4,
			EncodingType:    "thriftrw",
			DecodingTypes:   []string{"thriftrw"},
		},
	}
	return cluster, nil
}
// DatabaseName from PersistenceTestCluster interface.
// It returns the database name this test cluster was constructed with.
func (s *testCluster) DatabaseName() string {
return s.dbName
}
// SetupTestDatabase from PersistenceTestCluster interface.
//
// It creates the test database and then loads "schema.sql" for both the
// cadence and the visibility schema. A schema directory that starts with
// neither "/" nor "../" is resolved against the cadence package directory;
// if that directory cannot be determined the process exits via log.Fatal.
func (s *testCluster) SetupTestDatabase() {
	s.createDatabase()

	root := s.schemaDir + "/"
	isAbsolute := strings.HasPrefix(root, "/")
	isParentRelative := strings.HasPrefix(root, "../")
	if !(isAbsolute || isParentRelative) {
		base, err := getCadencePackageDir()
		if err != nil {
			log.Fatal(err)
		}
		root = base + root
	}

	schemaFiles := []string{"schema.sql"}
	s.loadSchema(schemaFiles, root)
	s.loadVisibilitySchema(schemaFiles, root)
}
// Config returns the persistence config for connecting to this test cluster.
//
// A single data store named "test" serves as both the default store and the
// visibility store. The transaction size limit is
// common.DefaultTransactionSizeLimit and the error injection rate is 0.
func (s *testCluster) Config() config.Persistence {
	const storeName = "test"

	// Hand out a pointer to a (shallow) copy rather than to s.cfg itself.
	sqlCfg := s.cfg

	stores := make(map[string]config.DataStore, 1)
	stores[storeName] = config.DataStore{SQL: &sqlCfg}

	return config.Persistence{
		DefaultStore:         storeName,
		VisibilityStore:      storeName,
		DataStores:           stores,
		TransactionSizeLimit: dynamicconfig.GetIntPropertyFn(common.DefaultTransactionSizeLimit),
		ErrorInjectionRate:   dynamicconfig.GetFloatPropertyFn(0),
	}
}
// TearDownTestDatabase from PersistenceTestCluster interface.
// It drops the test database by delegating to dropDatabase.
func (s *testCluster) TearDownTestDatabase() {
s.dropDatabase()
}
// createDatabase creates the database named in s.cfg.DatabaseName.
//
// It panics if the admin connection cannot be opened, if the database cannot
// be created, or if closing the connection fails.
func (s *testCluster) createDatabase() {
	// NOTE need to connect with empty name to create new database
	adminCfg := s.cfg
	adminCfg.DatabaseName = ""

	adminDB, err := NewSQLAdminDB(&adminCfg)
	if err != nil {
		panic(err)
	}
	defer func() {
		if closeErr := adminDB.Close(); closeErr != nil {
			panic(closeErr)
		}
	}()

	if err = adminDB.CreateDatabase(s.cfg.DatabaseName); err != nil {
		panic(err)
	}
}
// dropDatabase drops the database named in s.cfg.DatabaseName.
//
// It panics if the admin connection cannot be opened, if the drop fails, or
// if closing the connection fails.
func (s *testCluster) dropDatabase() {
	target := s.cfg.DatabaseName

	// NOTE need to connect with empty name to drop the database
	connCfg := s.cfg
	connCfg.DatabaseName = ""

	conn, err := NewSQLAdminDB(&connCfg)
	if err != nil {
		panic(err)
	}
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			panic(closeErr)
		}
	}()

	if err = conn.DropDatabase(target); err != nil {
		panic(err)
	}
}
// loadSchema loads the given schema files from the "cadence" subdirectory of
// schemaDir into the test database. Any failure terminates the process via
// log.Fatal.
func (s *testCluster) loadSchema(fileNames []string, schemaDir string) {
	if err := s.loadDatabaseSchema(schemaDir+"/cadence", fileNames, true); err != nil {
		log.Fatal(err)
	}
}
// loadVisibilitySchema loads the given schema files from the "visibility"
// subdirectory of schemaDir into the test database. Any failure terminates
// the process via log.Fatal.
func (s *testCluster) loadVisibilitySchema(fileNames []string, schemaDir string) {
	visibilityDir := schemaDir + "/visibility"
	if err := s.loadDatabaseSchema(visibilityDir, fileNames, true); err != nil {
		log.Fatal(err)
	}
}
// getCadencePackageDir returns the path of the innermost "cadence" directory
// that contains (or is) the current working directory, including a trailing
// slash, e.g. "/home/user/go/src/github.com/uber/cadence/".
//
// It returns an error, and an empty path, when the working directory cannot
// be determined or when it does not lie inside a "/cadence/" path.
func getCadencePackageDir() (string, error) {
	const marker = "/cadence/"

	cwd, err := os.Getwd()
	if err != nil {
		return "", err
	}

	// Search with a trailing separator so that a working directory that is
	// exactly the repository root (".../cadence") is matched as well.
	searchPath := cwd
	if !strings.HasSuffix(searchPath, "/") {
		searchPath += "/"
	}

	idx := strings.LastIndex(searchPath, marker)
	if idx < 0 {
		// Without this guard a -1 index would silently yield a bogus prefix
		// of the working directory (or an out-of-range slice panic).
		return "", fmt.Errorf("unable to locate cadence package directory: working directory %q is not inside a %q path", cwd, marker)
	}
	return searchPath[:idx+len(marker)], nil
}
// loadDatabaseSchema loads the schema from the given .sql files on this database.
//
// Each entry of fileNames is read from dir and its entire content is executed
// through ExecSchemaOperationQuery, in order; processing stops at the first
// failure. The override parameter is currently unused and is kept only so
// existing callers keep compiling.
//
// All failures, including opening and closing the admin connection, are
// reported through the returned error rather than by panicking, so the caller
// decides how to abort. A close failure is reported only when no earlier
// error occurred.
func (s *testCluster) loadDatabaseSchema(dir string, fileNames []string, override bool) (err error) {
	db, err := NewSQLAdminDB(&s.cfg)
	if err != nil {
		return fmt.Errorf("error connecting to database to load schema: %w", err)
	}
	defer func() {
		// Do not let a close failure mask the more relevant load error.
		if closeErr := db.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing database connection after loading schema: %w", closeErr)
		}
	}()

	for _, file := range fileNames {
		// This is only used in tests. Excluding it from security scanners
		// #nosec
		content, readErr := ioutil.ReadFile(dir + "/" + file)
		if readErr != nil {
			return fmt.Errorf("error reading contents of file %v:%w", file, readErr)
		}
		if execErr := db.ExecSchemaOperationQuery(context.Background(), string(content)); execErr != nil {
			return fmt.Errorf("error loading schema from %v: %w", file, execErr)
		}
	}
	return nil
}