traffic_ops/app/db/traffic_vault_migrate/postgres.go (650 lines of code) (raw):
package main
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/apache/trafficcontrol/v8/lib/go-log"
"github.com/apache/trafficcontrol/v8/lib/go-tc"
util "github.com/apache/trafficcontrol/v8/lib/go-util"
_ "github.com/lib/pq"
)
// PGConfig represents the configuration options available to the PG backend.
// The struct tags name the JSON key each field is associated with.
type PGConfig struct {
// Host and Port locate the Postgres server; Start places them in the
// connection string.
Host string `json:"host"`
Port string `json:"port"`
// User and Password are the credentials placed in the connection string.
User string `json:"user"`
Password string `json:"password"`
// SSLMode is passed as the sslmode connection parameter.
SSLMode string `json:"sslmode"`
// Database is the name of the database to connect to.
Database string `json:"database"`
// KeyBase64 is the base64 (standard encoding) form of the AES key.
KeyBase64 string `json:"aesKey"`
// AESKey is the decoded form of KeyBase64; ReadConfigFile fills it in.
AESKey []byte
}
// PGBackend is the Postgres implementation of TVBackend.
// It holds one in-memory table per key type plus the configuration and the
// database handle used to read and write those tables.
type PGBackend struct {
// In-memory copies of the sslkey, dnssec, uri_signing_key and url_sig_key tables.
sslKey pgSSLKeyTable
dnssec pgDNSSecTable
uriSigningKeys pgURISignKeyTable
urlSigKeys pgURLSigKeyTable
// cfg is populated by ReadConfigFile.
cfg PGConfig
// db is assigned by Start.
db *sql.DB
}
// String summarises the backend: the server it points at followed by the
// number of records currently held for each key type, one per line.
func (pg *PGBackend) String() string {
	parts := []string{
		fmt.Sprintf("PG server %s@%s:%s\n", pg.cfg.User, pg.cfg.Host, pg.cfg.Port),
		fmt.Sprintf("\tSSL Keys: %d\n", len(pg.sslKey.Records)),
		fmt.Sprintf("\tDNSSec Keys: %d\n", len(pg.dnssec.Records)),
		fmt.Sprintf("\tURI Signing Keys: %d\n", len(pg.uriSigningKeys.Records)),
		fmt.Sprintf("\tURL Sig Keys: %d\n", len(pg.urlSigKeys.Records)),
	}
	var summary string
	for _, part := range parts {
		summary += part
	}
	return summary
}
// Name reports the short identifier of this backend.
func (pg *PGBackend) Name() string {
	const backendName = "PG"
	return backendName
}
// ReadConfigFile reads configFile into the backend's config, then base64-decodes
// the configured AES key and validates it.
//
// It returns the error from UnmarshalConfig unchanged, or a wrapped error when
// the key cannot be decoded or fails validation.
//
// NOTE(review): both key-related messages echo the configured key text;
// consider redacting it if these errors can reach shared logs.
func (pg *PGBackend) ReadConfigFile(configFile string) error {
	var err error
	if err = UnmarshalConfig(configFile, &pg.cfg); err != nil {
		return err
	}

	if pg.cfg.AESKey, err = base64.StdEncoding.DecodeString(pg.cfg.KeyBase64); err != nil {
		return fmt.Errorf("unable to decode PG AESKey '%s': %w", pg.cfg.KeyBase64, err)
	}

	// Wrap the validation error so the caller can see why the key was rejected.
	if err = util.ValidateAESKey(pg.cfg.AESKey); err != nil {
		return fmt.Errorf("unable to validate PG AESKey '%s': %w", pg.cfg.KeyBase64, err)
	}
	return nil
}
// Insert writes the in-memory keys to the backend DB, one table at a time
// (SSL, DNSSEC, URL sig, URI signing), stopping at the first failure.
func (pg *PGBackend) Insert() error {
	err := pg.sslKey.insertKeys(pg.db)
	if err == nil {
		err = pg.dnssec.insertKeys(pg.db)
	}
	if err == nil {
		err = pg.urlSigKeys.insertKeys(pg.db)
	}
	if err == nil {
		err = pg.uriSigningKeys.insertKeys(pg.db)
	}
	return err
}
// Start opens the database handle for the configured Postgres server and
// resets the in-memory key tables.
//
// On failure the returned error contains the connection string with the
// password replaced by "*". The redacted string is built directly rather than
// by search-and-replace, so the mask always lands on the password field even
// when the password is empty or also appears in another field.
//
// NOTE(review): the config values are interpolated as-is; credentials that
// contain URL-reserved characters presumably need to be pre-escaped in the
// config file. Confirm against the deployment docs.
func (pg *PGBackend) Start() error {
	const connFmt = "postgres://%s:%s@%s:%s/%s?sslmode=%s"
	cfg := &pg.cfg

	sqlStr := fmt.Sprintf(connFmt, cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database, cfg.SSLMode)
	db, err := sql.Open("postgres", sqlStr)
	if err != nil {
		redacted := fmt.Sprintf(connFmt, cfg.User, "*", cfg.Host, cfg.Port, cfg.Database, cfg.SSLMode)
		return fmt.Errorf("unable to start PG client with connection string '%s': %w", redacted, err)
	}

	pg.db = db
	pg.sslKey = pgSSLKeyTable{}
	pg.dnssec = pgDNSSecTable{}
	pg.urlSigKeys = pgURLSigKeyTable{}
	pg.uriSigningKeys = pgURISignKeyTable{}
	return nil
}
// ValidateKey runs the validation of every key table (SSL, DNSSEC, URI
// signing, URL sig) and returns all reported problems in that order.
// The result is nil when nothing was reported.
func (pg *PGBackend) ValidateKey() []string {
	var problems []string
	for _, found := range [][]string{
		pg.sslKey.validate(),
		pg.dnssec.validate(),
		pg.uriSigningKeys.validate(),
		pg.urlSigKeys.validate(),
	} {
		problems = append(problems, found...)
	}
	return problems
}
// Close terminates the connection to the backend DB.
// It is a no-op when no database handle has been assigned yet, so it is safe
// to call even if Start was never run or failed.
func (pg *PGBackend) Close() error {
	if pg.db == nil {
		return nil
	}
	return pg.db.Close()
}
// Ping checks the connection to the backend DB.
// It returns an error instead of panicking when no database handle has been
// assigned yet (Start was never run or failed).
func (pg *PGBackend) Ping() error {
	if pg.db == nil {
		return errors.New("PG backend has not been started")
	}
	return pg.db.Ping()
}
// Fetch loads every key table from the backend DB into memory (SSL, DNSSEC,
// URL sig, URI signing), stopping at the first failure.
func (pg *PGBackend) Fetch() error {
	err := pg.sslKey.gatherKeys(pg.db)
	if err == nil {
		err = pg.dnssec.gatherKeys(pg.db)
	}
	if err == nil {
		err = pg.urlSigKeys.gatherKeys(pg.db)
	}
	if err == nil {
		err = pg.uriSigningKeys.gatherKeys(pg.db)
	}
	return err
}
// GetSSLKeys decrypts the stored SSL key records with the configured AES key
// and returns them in the common representation (SSLKey).
func (pg *PGBackend) GetSSLKeys() ([]SSLKey, error) {
	tbl := &pg.sslKey
	err := tbl.decrypt(pg.cfg.AESKey)
	if err != nil {
		return nil, err
	}
	keys := tbl.toGeneric()
	return keys, nil
}
// SetSSLKeys replaces the stored SSL key records with keys, converted to the
// backend's internal format, and encrypts them with the configured AES key.
func (pg *PGBackend) SetSSLKeys(keys []SSLKey) error {
	tbl := &pg.sslKey
	tbl.fromGeneric(keys)
	err := tbl.encrypt(pg.cfg.AESKey)
	return err
}
// GetDNSSecKeys decrypts the stored DNSSEC key records with the configured AES
// key and returns them in the common representation (DNSSecKey).
func (pg *PGBackend) GetDNSSecKeys() ([]DNSSecKey, error) {
	tbl := &pg.dnssec
	err := tbl.decrypt(pg.cfg.AESKey)
	if err != nil {
		return nil, err
	}
	keys := tbl.toGeneric()
	return keys, nil
}
// SetDNSSecKeys replaces the stored DNSSEC key records with keys, converted to
// the backend's internal format, and encrypts them with the configured AES key.
func (pg *PGBackend) SetDNSSecKeys(keys []DNSSecKey) error {
	tbl := &pg.dnssec
	tbl.fromGeneric(keys)
	err := tbl.encrypt(pg.cfg.AESKey)
	return err
}
// GetURISignKeys decrypts the stored URI signing key records with the
// configured AES key and returns them in the common representation (URISignKey).
func (pg *PGBackend) GetURISignKeys() ([]URISignKey, error) {
	tbl := &pg.uriSigningKeys
	err := tbl.decrypt(pg.cfg.AESKey)
	if err != nil {
		return nil, err
	}
	keys := tbl.toGeneric()
	return keys, nil
}
// SetURISignKeys replaces the stored URI signing key records with keys,
// converted to the backend's internal format, and encrypts them with the
// configured AES key.
func (pg *PGBackend) SetURISignKeys(keys []URISignKey) error {
	tbl := &pg.uriSigningKeys
	tbl.fromGeneric(keys)
	err := tbl.encrypt(pg.cfg.AESKey)
	return err
}
// GetURLSigKeys decrypts the stored URL sig key records with the configured
// AES key and returns them in the common representation (URLSigKey).
func (pg *PGBackend) GetURLSigKeys() ([]URLSigKey, error) {
	tbl := &pg.urlSigKeys
	err := tbl.decrypt(pg.cfg.AESKey)
	if err != nil {
		return nil, err
	}
	keys := tbl.toGeneric()
	return keys, nil
}
// SetURLSigKeys replaces the stored URL sig key records with keys, converted
// to the backend's internal format, and encrypts them with the configured AES key.
func (pg *PGBackend) SetURLSigKeys(keys []URLSigKey) error {
	tbl := &pg.urlSigKeys
	tbl.fromGeneric(keys)
	err := tbl.encrypt(pg.cfg.AESKey)
	return err
}
// pgCommonRecord holds the part shared by every PG record type.
type pgCommonRecord struct {
// DataEncrypted is the encrypted payload read from or written to the table's
// data column.
DataEncrypted []byte
}
// pgDNSSecRecord is one row of the dnssec table.
type pgDNSSecRecord struct {
// Key is the decrypted form of the row's data column.
Key tc.DNSSECKeysTrafficVault
// CDN is the row's cdn column.
CDN string
pgCommonRecord
}
// pgDNSSecTable is the in-memory copy of the dnssec table.
type pgDNSSecTable struct {
Records []pgDNSSecRecord
}
// gatherKeys replaces tbl.Records with every row of the dnssec table. The data
// column is stored still encrypted; decrypt turns it into keys.
//
// The row count from getSize is only a capacity hint: a failure to obtain it is
// logged and otherwise ignored, and rows are appended as they are read, so a
// count that disagrees with the result set neither aborts the fetch nor leaves
// empty placeholder records behind.
//
// It returns an error when the query fails, a row cannot be scanned, or row
// iteration ends with an error.
func (tbl *pgDNSSecTable) gatherKeys(db *sql.DB) error {
	sz, err := getSize(db, "dnssec")
	if err != nil {
		log.Errorln("PGDNSSec gatherKeys: unable to determine size of dnssec table: " + err.Error())
		sz = 0
	}
	tbl.Records = make([]pgDNSSecRecord, 0, sz)

	query := "SELECT cdn, data from dnssec"
	rows, err := db.Query(query)
	if err != nil {
		return fmt.Errorf("PGDNSSec gatherKeys: unable to run query '%s': %w", query, err)
	}
	defer log.Close(rows, "closing dnssec query")

	for rows.Next() {
		var record pgDNSSecRecord
		if err := rows.Scan(&record.CDN, &record.DataEncrypted); err != nil {
			return fmt.Errorf("PGDNSSec gatherKeys unable to scan row: %w", err)
		}
		tbl.Records = append(tbl.Records, record)
	}
	// Next returning false can mean an error rather than end of data.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("PGDNSSec gatherKeys: error iterating rows: %w", err)
	}
	return nil
}
// decrypt decrypts each record's DataEncrypted with aesKey and unmarshals the
// result into that record's Key. It stops at the first failure.
func (tbl *pgDNSSecTable) decrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		err := decryptInto(aesKey, rec.DataEncrypted, &rec.Key)
		if err != nil {
			return fmt.Errorf("unable to decrypt into keys: %w", err)
		}
	}
	return nil
}
// encrypt marshals each record's Key to JSON, encrypts it with aesKey and
// stores the result in the record's DataEncrypted. It stops at the first failure.
func (tbl *pgDNSSecTable) encrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]

		plain, err := json.Marshal(&rec.Key)
		if err != nil {
			return fmt.Errorf("encrypt issue marshalling keys: %w", err)
		}

		sealed, err := encrypt(plain, aesKey)
		if err != nil {
			return fmt.Errorf("encrypt error: %w", err)
		}
		rec.DataEncrypted = sealed
	}
	return nil
}
// toGeneric converts the records into the common DNSSecKey representation,
// preserving order.
func (tbl *pgDNSSecTable) toGeneric() []DNSSecKey {
	generic := make([]DNSSecKey, 0, len(tbl.Records))
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		generic = append(generic, DNSSecKey{
			CDN:                    rec.CDN,
			DNSSECKeysTrafficVault: rec.Key,
		})
	}
	return generic
}
// fromGeneric replaces the records with ones built from keys. DataEncrypted is
// left nil until encrypt is called.
func (tbl *pgDNSSecTable) fromGeneric(keys []DNSSecKey) {
	records := make([]pgDNSSecRecord, 0, len(keys))
	for _, k := range keys {
		var rec pgDNSSecRecord
		rec.Key = k.DNSSECKeysTrafficVault
		rec.CDN = k.CDN
		records = append(records, rec)
	}
	tbl.Records = records
}
// validate reports every record that has key material but no encrypted data.
// It returns one message per offending record, or nil when all records pass.
func (tbl *pgDNSSecTable) validate() []string {
	var errs []string
	for _, record := range tbl.Records {
		if record.DataEncrypted == nil && len(record.Key) > 0 {
			errs = append(errs, fmt.Sprintf("DNSSEC Key CDN '%s': DataEncrypted is blank!", record.CDN))
		}
	}
	return errs
}
// insertKeys upserts every record into the dnssec table: on a cdn conflict the
// existing row's data is overwritten.
func (tbl *pgDNSSecTable) insertKeys(db *sql.DB) error {
	const (
		columnsPerRow = 2
		insertQuery   = "INSERT INTO dnssec (cdn, data) VALUES %s ON CONFLICT (cdn) DO UPDATE SET data = EXCLUDED.data"
	)
	args := make([]interface{}, 0, columnsPerRow*len(tbl.Records))
	for _, rec := range tbl.Records {
		args = append(args, rec.CDN, rec.DataEncrypted)
	}
	return insertIntoTable(db, insertQuery, columnsPerRow, args)
}
// pgSSLKeyRecord is one row of the sslkey table.
type pgSSLKeyRecord struct {
// Keys is the decrypted form of the row's data column.
Keys tc.DeliveryServiceSSLKeys
pgCommonRecord

// The fields below are stored as their own columns on the table, duplicating
// values that are also carried inside Keys.
DeliveryService string
CDN string
Version string
}
// pgSSLKeyTable is the in-memory copy of the sslkey table.
type pgSSLKeyTable struct {
Records []pgSSLKeyRecord
}
// insertKeys upserts every record into the sslkey table: on a
// (deliveryservice, cdn, version) conflict the existing row's data is
// overwritten. The provider column is always written as an empty string.
func (tbl *pgSSLKeyTable) insertKeys(db *sql.DB) error {
	const (
		columnsPerRow = 5
		insertQuery   = "INSERT INTO sslkey (deliveryservice, data, cdn, version, provider) VALUES %s ON CONFLICT (deliveryservice,cdn,version) DO UPDATE SET data = EXCLUDED.data"
	)
	args := make([]interface{}, 0, columnsPerRow*len(tbl.Records))
	for _, rec := range tbl.Records {
		args = append(args, rec.DeliveryService, rec.DataEncrypted, rec.CDN, rec.Version, "")
	}
	return insertIntoTable(db, insertQuery, columnsPerRow, args)
}
// gatherKeys replaces tbl.Records with every row of the sslkey table. The data
// column is stored still encrypted; decrypt turns it into keys.
//
// The row count from getSize is used only to pre-size the slice; rows are
// appended as they are read, so a count that disagrees with the result set
// neither aborts the fetch nor leaves empty placeholder records behind.
//
// It returns an error when the size lookup or the query fails, a row cannot be
// scanned, or row iteration ends with an error.
func (tbl *pgSSLKeyTable) gatherKeys(db *sql.DB) error {
	sz, err := getSize(db, "sslkey")
	if err != nil {
		return fmt.Errorf("PGSSLKey gatherKeys unable to determine size of sslkey table: %w", err)
	}
	tbl.Records = make([]pgSSLKeyRecord, 0, sz)

	query := "SELECT data, deliveryservice, cdn, version from sslkey"
	rows, err := db.Query(query)
	if err != nil {
		return fmt.Errorf("PGSSLKey gatherKeys unable to run query '%s': %w", query, err)
	}
	defer log.Close(rows, "closing sslkey query")

	for rows.Next() {
		var record pgSSLKeyRecord
		if err := rows.Scan(&record.DataEncrypted, &record.DeliveryService, &record.CDN, &record.Version); err != nil {
			// len(tbl.Records) is the zero-based index of the row being read.
			return fmt.Errorf("PGSSLKey gatherKeys unable to scan %d row: %w", len(tbl.Records), err)
		}
		tbl.Records = append(tbl.Records, record)
	}
	// Next returning false can mean an error rather than end of data.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("PGSSLKey gatherKeys: error iterating rows: %w", err)
	}
	return nil
}
// decrypt decrypts each record's DataEncrypted with aesKey and unmarshals the
// result into that record's Keys. It stops at the first failure.
func (tbl *pgSSLKeyTable) decrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		err := decryptInto(aesKey, rec.DataEncrypted, &rec.Keys)
		if err != nil {
			return fmt.Errorf("unable to decrypt into keys: %w", err)
		}
	}
	return nil
}
// encrypt marshals each record's Keys to JSON, encrypts it with aesKey and
// stores the result in the record's DataEncrypted. It stops at the first failure.
func (tbl *pgSSLKeyTable) encrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]

		plain, err := json.Marshal(rec.Keys)
		if err != nil {
			return fmt.Errorf("encrypt issue marshalling keys: %w", err)
		}

		sealed, err := encrypt(plain, aesKey)
		if err != nil {
			return fmt.Errorf("encrypt error: %w", err)
		}
		rec.DataEncrypted = sealed
	}
	return nil
}
// toGeneric converts the records into the common SSLKey representation,
// preserving order. Only the decrypted keys and the version are carried over.
func (tbl *pgSSLKeyTable) toGeneric() []SSLKey {
	generic := make([]SSLKey, 0, len(tbl.Records))
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		generic = append(generic, SSLKey{
			DeliveryServiceSSLKeys: rec.Keys,
			Version:                rec.Version,
		})
	}
	return generic
}
// fromGeneric replaces the records with ones built from keys. DataEncrypted is
// left nil until encrypt is called.
func (tbl *pgSSLKeyTable) fromGeneric(keys []SSLKey) {
	records := make([]pgSSLKeyRecord, 0, len(keys))
	for _, k := range keys {
		var rec pgSSLKeyRecord
		rec.Keys = k.DeliveryServiceSSLKeys
		rec.DeliveryService = k.DeliveryService
		rec.CDN = k.CDN
		rec.Version = k.Version
		records = append(records, rec)
	}
	tbl.Records = records
}
// validate checks every SSL key record and returns one message for each record
// that fails, naming the first problem found for that record. It returns nil
// when all records pass.
//
// The all-default check comes first: a default key also has a blank delivery
// service, so testing the delivery service first would make the default check
// unreachable.
func (tbl *pgSSLKeyTable) validate() []string {
	const fmtStr = "SSL Key DS '%s': %s"
	defaultKey := tc.DeliveryServiceSSLKeys{}

	var errs []string
	for _, record := range tbl.Records {
		var problem string
		switch {
		case record.Keys == defaultKey:
			problem = "DS SSL Keys are default!"
		case record.Keys.DeliveryService == "":
			problem = "DS is blank!"
		case record.Keys.Key == "":
			problem = "Key is blank!"
		case record.Keys.CDN == "":
			problem = "CDN is blank!"
		case record.DataEncrypted == nil:
			problem = "DataEncrypted is blank!"
		case record.Version == "":
			problem = "Version is blank!"
		default:
			continue
		}
		errs = append(errs, fmt.Sprintf(fmtStr, record.DeliveryService, problem))
	}
	return errs
}
// pgURLSigKeyRecord is one row of the url_sig_key table.
type pgURLSigKeyRecord struct {
// Keys is the decrypted form of the row's data column.
Keys tc.URLSigKeys
// DeliveryService is the row's deliveryservice column.
DeliveryService string
pgCommonRecord
}
// pgURLSigKeyTable is the in-memory copy of the url_sig_key table.
type pgURLSigKeyTable struct {
Records []pgURLSigKeyRecord
}
// insertKeys upserts every record into the url_sig_key table: on a
// deliveryservice conflict the existing row's data is overwritten.
func (tbl *pgURLSigKeyTable) insertKeys(db *sql.DB) error {
	const (
		columnsPerRow = 2
		insertQuery   = "INSERT INTO url_sig_key (deliveryservice, data) VALUES %s ON CONFLICT (deliveryservice) DO UPDATE set data = EXCLUDED.data"
	)
	args := make([]interface{}, 0, columnsPerRow*len(tbl.Records))
	for _, rec := range tbl.Records {
		args = append(args, rec.DeliveryService, rec.DataEncrypted)
	}
	return insertIntoTable(db, insertQuery, columnsPerRow, args)
}
// gatherKeys replaces tbl.Records with every row of the url_sig_key table. The
// data column is stored still encrypted; decrypt turns it into keys.
//
// The row count from getSize is only a capacity hint: a failure to obtain it is
// logged and otherwise ignored, and rows are appended as they are read, so a
// count that disagrees with the result set neither aborts the fetch nor leaves
// empty placeholder records behind.
//
// It returns an error when the query fails, a row cannot be scanned, or row
// iteration ends with an error.
func (tbl *pgURLSigKeyTable) gatherKeys(db *sql.DB) error {
	sz, err := getSize(db, "url_sig_key")
	if err != nil {
		log.Errorln("PGURLSigKey gatherKeys: unable to determine url_sig_key table size: " + err.Error())
		sz = 0
	}
	tbl.Records = make([]pgURLSigKeyRecord, 0, sz)

	query := "SELECT deliveryservice, data from url_sig_key"
	rows, err := db.Query(query)
	if err != nil {
		return fmt.Errorf("PGURLSigKey gatherKeys error running query '%s': %w", query, err)
	}
	defer log.Close(rows, "closing url_sig_key query")

	for rows.Next() {
		var record pgURLSigKeyRecord
		if err := rows.Scan(&record.DeliveryService, &record.DataEncrypted); err != nil {
			return fmt.Errorf("PGURLSigKey gatherKeys: unable to scan row: %w", err)
		}
		tbl.Records = append(tbl.Records, record)
	}
	// Next returning false can mean an error rather than end of data.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("PGURLSigKey gatherKeys: error iterating rows: %w", err)
	}
	return nil
}
// decrypt decrypts each record's DataEncrypted with aesKey and unmarshals the
// result into that record's Keys. It stops at the first failure.
func (tbl *pgURLSigKeyTable) decrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		err := decryptInto(aesKey, rec.DataEncrypted, &rec.Keys)
		if err != nil {
			return fmt.Errorf("unable to decrypt into keys: %w", err)
		}
	}
	return nil
}
// encrypt marshals each record's Keys to JSON, encrypts it with aesKey and
// stores the result in the record's DataEncrypted. It stops at the first failure.
func (tbl *pgURLSigKeyTable) encrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]

		plain, err := json.Marshal(&rec.Keys)
		if err != nil {
			return fmt.Errorf("encrypt issue marshalling keys: %w", err)
		}

		sealed, err := encrypt(plain, aesKey)
		if err != nil {
			return fmt.Errorf("encrypt error: %w", err)
		}
		rec.DataEncrypted = sealed
	}
	return nil
}
// toGeneric converts the records into the common URLSigKey representation,
// preserving order.
func (tbl *pgURLSigKeyTable) toGeneric() []URLSigKey {
	generic := make([]URLSigKey, 0, len(tbl.Records))
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		generic = append(generic, URLSigKey{
			DeliveryService: rec.DeliveryService,
			URLSigKeys:      rec.Keys,
		})
	}
	return generic
}
// fromGeneric replaces the records with ones built from keys. DataEncrypted is
// left nil until encrypt is called.
func (tbl *pgURLSigKeyTable) fromGeneric(keys []URLSigKey) {
	records := make([]pgURLSigKeyRecord, 0, len(keys))
	for _, k := range keys {
		var rec pgURLSigKeyRecord
		rec.Keys = k.URLSigKeys
		rec.DeliveryService = k.DeliveryService
		records = append(records, rec)
	}
	tbl.Records = records
}
// validate reports every record that has key material but no encrypted data.
// It returns one message per offending record, or nil when all records pass.
func (tbl *pgURLSigKeyTable) validate() []string {
	var errs []string
	for _, record := range tbl.Records {
		if record.DataEncrypted == nil && len(record.Keys) > 0 {
			errs = append(errs, fmt.Sprintf("URL Sig Key DS '%s': DataEncrypted is blank!", record.DeliveryService))
		}
	}
	return errs
}
// pgURISignKeyRecord is one row of the uri_signing_key table.
type pgURISignKeyRecord struct {
// Keys is the decrypted form of the row's data column.
Keys tc.JWKSMap
// DeliveryService is the row's deliveryservice column.
DeliveryService string
pgCommonRecord
}
// pgURISignKeyTable is the in-memory copy of the uri_signing_key table.
type pgURISignKeyTable struct {
Records []pgURISignKeyRecord
}
// insertKeys upserts every record into the uri_signing_key table: on a
// deliveryservice conflict the existing row's data is overwritten.
func (tbl *pgURISignKeyTable) insertKeys(db *sql.DB) error {
	const (
		columnsPerRow = 2
		insertQuery   = "INSERT INTO uri_signing_key (deliveryservice, data) VALUES %s ON CONFLICT (deliveryservice) DO UPDATE SET data = EXCLUDED.data"
	)
	args := make([]interface{}, 0, columnsPerRow*len(tbl.Records))
	for _, rec := range tbl.Records {
		args = append(args, rec.DeliveryService, rec.DataEncrypted)
	}
	return insertIntoTable(db, insertQuery, columnsPerRow, args)
}
// gatherKeys replaces tbl.Records with every row of the uri_signing_key table.
// The data column is stored still encrypted; decrypt turns it into keys.
//
// The row count from getSize is only a capacity hint: a failure to obtain it is
// logged and otherwise ignored, and rows are appended as they are read, so a
// count that disagrees with the result set neither aborts the fetch nor leaves
// empty placeholder records behind.
//
// It returns an error when the query fails, a row cannot be scanned, or row
// iteration ends with an error.
func (tbl *pgURISignKeyTable) gatherKeys(db *sql.DB) error {
	sz, err := getSize(db, "uri_signing_key")
	if err != nil {
		log.Errorln("PGURISignKey gatherKeys: unable to determine size of uri_signing_key table: " + err.Error())
		sz = 0
	}
	tbl.Records = make([]pgURISignKeyRecord, 0, sz)

	query := "SELECT deliveryservice, data from uri_signing_key"
	rows, err := db.Query(query)
	if err != nil {
		return fmt.Errorf("PGURISignKey gatherKeys error while running query '%s': %w", query, err)
	}
	defer log.Close(rows, "closing uri_signing_key table")

	for rows.Next() {
		var record pgURISignKeyRecord
		if err := rows.Scan(&record.DeliveryService, &record.DataEncrypted); err != nil {
			return fmt.Errorf("PGURISignKey gatherKeys: unable to scan row: %w", err)
		}
		tbl.Records = append(tbl.Records, record)
	}
	// Next returning false can mean an error rather than end of data.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("PGURISignKey gatherKeys: error iterating rows: %w", err)
	}
	return nil
}
// decrypt decrypts each record's DataEncrypted with aesKey and unmarshals the
// result into that record's Keys. It stops at the first failure.
func (tbl *pgURISignKeyTable) decrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		err := decryptInto(aesKey, rec.DataEncrypted, &rec.Keys)
		if err != nil {
			return fmt.Errorf("unable to decrypt into keys: %w", err)
		}
	}
	return nil
}
// encrypt marshals each record's Keys to JSON, encrypts it with aesKey and
// stores the result in the record's DataEncrypted. It stops at the first failure.
func (tbl *pgURISignKeyTable) encrypt(aesKey []byte) error {
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]

		plain, err := json.Marshal(rec.Keys)
		if err != nil {
			return fmt.Errorf("encrypt issue marshalling keys: %w", err)
		}

		sealed, err := encrypt(plain, aesKey)
		if err != nil {
			return fmt.Errorf("encrypt error: %w", err)
		}
		rec.DataEncrypted = sealed
	}
	return nil
}
// toGeneric converts the records into the common URISignKey representation,
// preserving order.
func (tbl *pgURISignKeyTable) toGeneric() []URISignKey {
	generic := make([]URISignKey, 0, len(tbl.Records))
	for idx := range tbl.Records {
		rec := &tbl.Records[idx]
		generic = append(generic, URISignKey{
			DeliveryService: rec.DeliveryService,
			Keys:            rec.Keys,
		})
	}
	return generic
}
// fromGeneric replaces the records with ones built from keys. DataEncrypted is
// left nil until encrypt is called.
func (tbl *pgURISignKeyTable) fromGeneric(keys []URISignKey) {
	records := make([]pgURISignKeyRecord, 0, len(keys))
	for _, k := range keys {
		var rec pgURISignKeyRecord
		rec.Keys = k.Keys
		rec.DeliveryService = k.DeliveryService
		records = append(records, rec)
	}
	tbl.Records = records
}
// validate reports every record that has key material but no encrypted data.
// It returns one message per offending record, or nil when all records pass.
func (tbl *pgURISignKeyTable) validate() []string {
	var errs []string
	for _, record := range tbl.Records {
		if record.DataEncrypted == nil && len(record.Keys) > 0 {
			errs = append(errs, fmt.Sprintf("URI Signing Key DS '%s': DataEncrypted is blank!", record.DeliveryService))
		}
	}
	return errs
}
// getSize returns the number of rows in table, using SELECT COUNT(*).
//
// table is concatenated into the query text, so it must be a trusted
// identifier; every caller in this file passes a string literal.
//
// It returns an error when the query fails, yields no row, or the count cannot
// be read.
func getSize(db *sql.DB, table string) (int64, error) {
	rows, err := db.Query("SELECT COUNT(*) FROM " + table)
	if err != nil {
		return 0, err
	}
	// Only the first row is read, so the result set must be closed explicitly
	// to release its connection. A close error cannot affect the count already
	// read, so it is deliberately ignored.
	defer rows.Close()

	if !rows.Next() {
		// Distinguish a failed iteration from a genuinely empty result.
		if err := rows.Err(); err != nil {
			return 0, fmt.Errorf("error reading number of results for %s: %w", table, err)
		}
		return 0, errors.New("no results returned for: " + table)
	}

	var numRows int64
	if err := rows.Scan(&numRows); err != nil {
		return 0, fmt.Errorf("error reading number of results for %s: %w", table, err)
	}
	return numRows, nil
}
// decrypt AES-decrypts record with aesKey via util.AESDecrypt. On failure it
// returns nil and a wrapped error.
func decrypt(record []byte, aesKey []byte) ([]byte, error) {
	plain, err := util.AESDecrypt(record, aesKey)
	if err == nil {
		return plain, nil
	}
	return nil, fmt.Errorf("unable to decrypt: %w", err)
}
// encrypt AES-encrypts record with aesKey via util.AESEncrypt. On failure it
// returns nil and the underlying error unchanged.
func encrypt(record []byte, aesKey []byte) ([]byte, error) {
	sealed, err := util.AESEncrypt(record, aesKey)
	if err == nil {
		return sealed, nil
	}
	return nil, err
}
// decryptInto decrypts encData with aesKey and unmarshals the resulting JSON
// into value, which must be a non-nil pointer.
//
// value is handed to json.Unmarshal as-is rather than by address, so a
// non-pointer or nil destination is reported as an error instead of the
// decoded data being silently discarded.
//
// It returns an error when decryption fails, the decrypted data is empty, or
// the JSON cannot be unmarshalled into value.
func decryptInto(aesKey []byte, encData []byte, value interface{}) error {
	data, err := decrypt(encData, aesKey)
	if err != nil {
		return err
	}
	if len(data) == 0 {
		return errors.New("decrypted data is empty")
	}
	if err := json.Unmarshal(data, value); err != nil {
		return fmt.Errorf("unable to unmarshal decrypted data: %w", err)
	}
	return nil
}
// insertIntoTable runs one multi-row INSERT inside a transaction.
//
// queryBase is a format string with a single %s that receives the VALUES list.
// queryArgs holds the row values flattened row by row, stride values per row;
// they are bound to numbered placeholders, e.g. "($1,$2),($3,$4)" for two rows
// of stride 2.
//
// It does nothing and returns nil when queryArgs is empty. It returns an error
// when stride is not positive or does not evenly divide len(queryArgs), when
// the transaction cannot be opened or committed, when the statement fails, or
// when the number of affected rows differs from the number of rows supplied.
// In the latter two cases the transaction is rolled back.
func insertIntoTable(db *sql.DB, queryBase string, stride int, queryArgs []interface{}) error {
	if len(queryArgs) == 0 {
		return nil
	}
	// A bad stride would otherwise divide by zero or silently build a
	// statement that does not match the arguments.
	if stride <= 0 || len(queryArgs)%stride != 0 {
		return fmt.Errorf("cannot group %d query arguments into rows of %d", len(queryArgs), stride)
	}
	rowCount := len(queryArgs) / stride

	var values strings.Builder
	for i := range queryArgs {
		switch {
		case i == 0:
			values.WriteString("($")
		case i%stride == 0:
			// First value of a new row: close the previous group.
			values.WriteString("),($")
		default:
			values.WriteString(",$")
		}
		values.WriteString(strconv.Itoa(i + 1))
	}
	values.WriteString(")")
	query := fmt.Sprintf(queryBase, values.String())

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("unable to open db transaction: %w", err)
	}

	res, err := tx.Exec(query, queryArgs...)
	if err != nil {
		return rollback(tx, fmt.Errorf("error executing query '%s': %w", query, err))
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return rollback(tx, fmt.Errorf("error getting rows affected: %w", err))
	}
	if affected != int64(rowCount) {
		return rollback(tx, fmt.Errorf("wanted to insert %d rows, but inserted %d", rowCount, affected))
	}
	return tx.Commit()
}
// rollback rolls tx back and returns addError. If the rollback itself fails,
// it returns an error that wraps the rollback failure and appends the text of
// addError.
func rollback(tx *sql.Tx, addError error) error {
	rbErr := tx.Rollback()
	if rbErr == nil {
		return addError
	}
	return fmt.Errorf("encountered error rolling back transaction %w %s", rbErr, addError.Error())
}