util/util.go (194 lines of code) (raw):

// Copyright (c) 2017 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package util

import (
	"bytes"
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"sort"
	"time"

	gomysql "github.com/go-mysql-org/go-mysql/mysql"
	"github.com/go-sql-driver/mysql"
	"github.com/uber/storagetapper/log"
	"github.com/uber/storagetapper/types"
)

// h is the HTTP client shared by all HTTP helpers in this file.
var h = &http.Client{}

// Timeout is a http request wait timeout. It is applied to every request
// issued by the HTTP helpers below, covering the whole exchange including
// reading the response body. A non-positive value disables the limit.
var Timeout = 15 * time.Second

// httpRequestContext derives the context for one outgoing HTTP request: a nil
// parent is replaced by context.Background() and, when Timeout is positive, a
// deadline of Timeout from now is attached. The returned cancel function must
// always be called by the caller.
func httpRequestContext(parent context.Context) (context.Context, context.CancelFunc) {
	if parent == nil {
		parent = context.Background()
	}
	if Timeout <= 0 {
		return parent, func() {}
	}
	return context.WithTimeout(parent, Timeout)
}

// HTTPGetWithHeaders issues a GET request to url with the given headers and
// returns the response body as a byte array.
//
// The request is bound to ctx, so cancelling ctx aborts it, and is
// additionally limited by Timeout. Any response status other than 200 OK is
// reported as an error. The returned body is nil whenever err is non-nil.
func HTTPGetWithHeaders(ctx context.Context, url string, headers map[string]string) (body []byte, err error) {
	log.Debugf("GetURL: %v", url)

	ctx, cancel := httpRequestContext(ctx)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}

	resp, err := h.Do(req)
	if err != nil {
		return nil, err
	}
	// Close the body on every path so the connection is never leaked; a close
	// failure is reported only when nothing else went wrong.
	defer func() {
		if cerr := resp.Body.Close(); cerr != nil && err == nil {
			body, err = nil, cerr
		}
	}()

	body, err = ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP Status: %v %v", resp.StatusCode, resp.Status)
	}
	return body, nil
}

// HTTPGet is HTTPGetWithHeaders without any extra request headers.
func HTTPGet(ctx context.Context, url string) (body []byte, err error) {
	return HTTPGetWithHeaders(ctx, url, nil)
}

// HTTPPostJSONWithHeaders posts the given JSON message to the given URL.
//
// The caller supplied headers are applied first and Content-Type is then
// forced to application/json. The request is limited by Timeout. Response
// statuses outside of the 200 (OK) .. 202 (Accepted) range are reported as an
// error which carries the response status and body.
func HTTPPostJSONWithHeaders(url string, body string, headers map[string]string) (err error) {
	log.Debugf("URL: %v, BODY: %v", url, body)

	ctx, cancel := httpRequestContext(context.Background())
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(body))
	if err != nil {
		return err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	// Set last so that the payload is always declared as JSON.
	req.Header.Set("Content-Type", "application/json")

	resp, err := h.Do(req)
	if err != nil {
		return err
	}
	// Close the body on every path (including read errors and bad statuses)
	// so the connection is never leaked; a close failure is reported only
	// when nothing else went wrong.
	defer func() {
		if cerr := resp.Body.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusAccepted {
		return fmt.Errorf("%+v: %+v", resp.Status, string(b))
	}
	log.Debugf("POST response: %v", string(b))
	return nil
}

// HTTPPostJSON is HTTPPostJSONWithHeaders without any extra request headers.
func HTTPPostJSON(url string, body string) error {
	return HTTPPostJSONWithHeaders(url, body, nil)
}

// BytesToString converts zero terminated byte array to string.
// Whole array is used when there is no zero byte in it.
func BytesToString(b []byte) string {
	if n := bytes.IndexByte(b, 0); n >= 0 {
		return string(b[:n])
	}
	return string(b)
}

// ExecTxSQL logs and executes SQL query in given transaction, discarding the
// sql.Result.
func ExecTxSQL(tx *sql.Tx, query string, param ...interface{}) error {
	log.Debugf("SQLTX: %v %v", query, param)
	_, err := tx.Exec(query, param...)
	return err
}

// ExecSQL logs and executes SQL query, discarding the sql.Result.
//
// When the statement fails with MySQL error 1213 it is re-executed, up to
// three more times; the error of the last attempt is returned.
func ExecSQL(d *sql.DB, query string, param ...interface{}) error {
	const (
		deadlockErrCode = 1213 // the code this function treats as a deadlock
		maxRetries      = 3
	)

	log.Debugf("SQL: %v %v", query, param)
	_, err := d.Exec(query, param...)
	for i := 0; i < maxRetries && MySQLError(err, deadlockErrCode); i++ {
		log.Debugf("SQL(retrying after deadlock): %v %v", query, param)
		_, err = d.Exec(query, param...)
	}
	return err
}

// QuerySQL logs and executes SQL query which returns rows.
func QuerySQL(d *sql.DB, query string, param ...interface{}) (*sql.Rows, error) {
	log.Debugf("SQL: %v %v", query, param)
	return d.Query(query, param...)
}

// QueryTxSQL logs and executes SQL query which returns rows, in given
// transaction.
func QueryTxSQL(tx *sql.Tx, query string, param ...interface{}) (*sql.Rows, error) {
	log.Debugf("SQLTX: %v %v", query, param)
	return tx.Query(query, param...)
}

// QueryRowSQL logs and executes SQL query which returns a single row.
func QueryRowSQL(d *sql.DB, query string, param ...interface{}) *sql.Row {
	log.Debugf("SQLROW: %v %v", query, param)
	return d.QueryRow(query, param...)
}

// QueryTxRowSQL logs and executes SQL query which returns a single row, in
// given transaction.
func QueryTxRowSQL(tx *sql.Tx, query string, param ...interface{}) *sql.Row {
	log.Debugf("SQLTXROW: %v %v", query, param)
	return tx.QueryRow(query, param...)
}

// CheckTxIsolation returns nil if the isolation level reported by the server
// for the session of tx is exactly "level", and an error otherwise.
//
// NOTE(review): the tx_isolation system variable is believed to have been
// replaced by transaction_isolation and dropped in MySQL 8.0 — confirm which
// server versions have to be supported before relying on this check.
func CheckTxIsolation(tx *sql.Tx, level string) error {
	var txLevel string
	if err := tx.QueryRow("select @@session.tx_isolation").Scan(&txLevel); err != nil {
		return err
	}
	if txLevel != level {
		return fmt.Errorf("transaction isolation level must be: %v, got: %v", level, txLevel)
	}
	return nil
}

// MySQLError checks if given error is, or wraps, a MySQL driver error with
// given code. It returns false for a nil error.
func MySQLError(err error, code uint16) bool {
	var merr *mysql.MySQLError
	// errors.As also matches errors wrapped with %w, which a plain type
	// assertion would miss.
	return errors.As(err, &merr) && merr.Number == code
}

// SortedGTIDString converts GTID set into string, where UUIDs come in
// lexicographically sorted order.
// The order is the same as in output of select @@global.gtid_executed
// A nil set yields an empty string.
func SortedGTIDString(set *gomysql.MysqlGTIDSet) string {
	if set == nil {
		return ""
	}

	uuids := make([]string, 0, len(set.Sets))
	for u := range set.Sets {
		uuids = append(uuids, u)
	}
	sort.Strings(uuids)

	var b bytes.Buffer
	for i, u := range uuids {
		if i > 0 {
			b.WriteString(",")
		}
		b.WriteString(set.Sets[u].String())
	}
	return b.String()
}

// MySQLToDriverType converts mysql type names to sql driver type suitable for
// scan. mtype is the MySQL type name; ftype is consulted only for the small
// integer types, which are scanned as booleans when ftype equals
// types.MySQLBoolean. Unlisted types are scanned as raw bytes.
/*FIXME: Use sql.ColumnType.DatabaseType instead of this function if go1.8 is
 * used */
func MySQLToDriverType(mtype string, ftype string) interface{} {
	switch mtype {
	case "int", "integer", "tinyint", "smallint", "mediumint":
		if ftype == types.MySQLBoolean {
			return new(sql.NullBool)
		}
		return new(sql.NullInt64)
	case "timestamp", "datetime":
		return new(sql.NullTime)
	case "bigint", "bit", "year":
		return new(sql.NullInt64)
	case "float", "double", "decimal", "numeric":
		return new(sql.NullFloat64)
	case "char", "varchar", "json":
		return new(sql.NullString)
	case "blob", "tinyblob", "mediumblob", "longblob":
		return new(sql.RawBytes)
	case "text", "tinytext", "mediumtext", "longtext", "date", "time", "enum", "set":
		return new(sql.NullString)
	default: // "binary", "varbinary" and others
		return new(sql.RawBytes)
	}
}

// PostgresToDriverType converts Postgres type names to sql driver type
// suitable for scan. Unlisted types are scanned as raw bytes.
func PostgresToDriverType(psql string) interface{} {
	switch psql {
	case "int2", "int4", "int8":
		return new(sql.NullInt64)
	case "float4", "float8", "numeric":
		return new(sql.NullFloat64)
	case "text", "varchar":
		return new(sql.NullString)
	case "bool":
		return new(sql.NullBool)
	default:
		return new(sql.RawBytes)
	}
}

// ClickHouseToDriverType converts ClickHouse type names to sql driver type
// suitable for scan. Only the lower-case spellings listed below are
// recognized; unlisted types are scanned as raw bytes.
func ClickHouseToDriverType(psql string) interface{} {
	switch psql {
	case "int8", "int16", "int32", "int64":
		return new(sql.NullInt64)
	case "uint8", "uint16", "uint32": // FIXME: uint64
		return new(sql.NullInt64)
	case "float32", "float64":
		return new(sql.NullFloat64)
	case "string", "fixedstring":
		return new(sql.NullString)
	default: //FIXME: Date, DateTime, Enum, Array
		return new(sql.RawBytes)
	}
}