tools/ipam-autopilot/container/data_access.go (332 lines of code) (raw):

// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "database/sql" "fmt" "log" "net" "strings" "github.com/apparentlymart/go-cidr/cidr" "github.com/jackc/pgtype" _ "github.com/golang-migrate/migrate/v4/source/file" ) type RoutingDomain struct { Id int `db:"routing_domain_id"` Name string `db:"name"` Vpcs string `db:"vpcs"` // associated VPCs that should be tracked for subnet creation } type Range struct { Subnet_id int `db:"subnet_id"` Parent_id int `db:"parent_id"` Routing_domain_id int `db:"routing_domain_id"` Name string `db:"name"` Cidr string `db:"cidr"` } func GetRangesFromDB() ([]Range, error) { var ranges []Range rows, err := db.Query("SELECT subnet_id, parent_id, routing_domain_id, name, cidr FROM subnets") if err != nil { return nil, err } for rows.Next() { var subnet_id int var routing_domain_id int tmp := pgtype.Int4{} var name string var cidr string err := rows.Scan(&subnet_id, &tmp, &routing_domain_id, &name, &cidr) if err != nil { return nil, err } parent_id := -1 if tmp.Status == pgtype.Present { tmp.AssignTo(&parent_id) } ranges = append(ranges, Range{ Subnet_id: subnet_id, Parent_id: parent_id, Routing_domain_id: routing_domain_id, Name: name, Cidr: cidr, }) } return ranges, nil } func GetRangesForParentFromDB(tx *sql.Tx, parent_id int64) ([]Range, error) { var ranges []Range rows, err := tx.Query("SELECT subnet_id, parent_id, routing_domain_id, name, cidr FROM subnets WHERE parent_id = ? FOR UPDATE", parent_id) if err != nil { return nil, err } for rows.Next() { var subnet_id int var routing_domain_id int tmp := pgtype.Int4{} var name string var cidr string err := rows.Scan(&subnet_id, &tmp, &routing_domain_id, &name, &cidr) if err != nil { return nil, err } parent_id := -1 if tmp.Status == pgtype.Present { tmp.AssignTo(&parent_id) } ranges = append(ranges, Range{ Subnet_id: subnet_id, Parent_id: parent_id, Routing_domain_id: routing_domain_id, Name: name, Cidr: cidr, }) } return ranges, nil } func GetRangeFromDB(id int64) (*Range, error) { var subnet_id int var routing_domain_id int tmp := pgtype.Int4{} var name string var cidr string err := db.QueryRow("SELECT subnet_id, parent_id, routing_domain_id, name, cidr FROM subnets WHERE subnet_id = ?", id).Scan(&subnet_id, &tmp, &routing_domain_id, &name, &cidr) if err != nil { return nil, err } parent_id := -1 if tmp.Status == pgtype.Present { tmp.AssignTo(&parent_id) } return &Range{ Subnet_id: subnet_id, Parent_id: parent_id, Routing_domain_id: routing_domain_id, Name: name, Cidr: cidr, }, nil } func GetRangeFromDBWithTx(tx *sql.Tx, id int64) (*Range, error) { var subnet_id int var routing_domain_id int tmp := pgtype.Int4{} var name string var cidr string err := tx.QueryRow("SELECT subnet_id, parent_id, routing_domain_id, name, cidr FROM subnets WHERE subnet_id = ? FOR UPDATE", id).Scan(&subnet_id, &tmp, &routing_domain_id, &name, &cidr) if err != nil { return nil, err } parent_id := -1 if tmp.Status == pgtype.Present { tmp.AssignTo(&parent_id) } return &Range{ Subnet_id: subnet_id, Parent_id: parent_id, Routing_domain_id: routing_domain_id, Name: name, Cidr: cidr, }, nil } func getRangeByCidrAndRoutingDomain(tx *sql.Tx, request_cidr string, routing_domain_id int) (*Range, error) { var subnet_id int tmp := pgtype.Int4{} var name string var cidr string err := tx.QueryRow("SELECT subnet_id, parent_id, name, cidr FROM subnets WHERE cidr = ? and routing_domain_id = ? FOR UPDATE", request_cidr, routing_domain_id).Scan(&subnet_id, &tmp, &name, &cidr) if err != nil { return nil, err } parent_id := -1 if tmp.Status == pgtype.Present { tmp.AssignTo(&parent_id) } return &Range{ Subnet_id: subnet_id, Parent_id: parent_id, Routing_domain_id: routing_domain_id, Name: name, Cidr: cidr, }, nil } func GetRangeByCidrFromDB(tx *sql.Tx, routing_domain_id int, cidr_request string) (*Range, error) { var subnet_id int tmp := pgtype.Int4{} var name string var cidr string if cidr_request != "" { err := tx.QueryRow("SELECT subnet_id, parent_id, name, cidr FROM subnets WHERE cidr = ? and routing_domain_id = ? FOR UPDATE", cidr_request, routing_domain_id).Scan(&subnet_id, &tmp, &name, &cidr) if err != nil { return nil, err } } else { err := tx.QueryRow("SELECT subnet_id, parent_id, name, cidr FROM subnets WHERE routing_domain_id = ? LIMIT 1 FOR UPDATE", routing_domain_id).Scan(&subnet_id, &tmp, &name, &cidr) if err != nil { return nil, err } } parent_id := -1 if tmp.Status == pgtype.Present { tmp.AssignTo(&parent_id) } return &Range{ Subnet_id: subnet_id, Parent_id: parent_id, Routing_domain_id: routing_domain_id, Name: name, Cidr: cidr, }, nil } func DeleteRangeFromDb(id int64) error { _, err := db.Query("DELETE FROM subnets WHERE subnet_id = ?", id) if err != nil { return err } return nil } func DeleteRoutingDomainFromDB(id int64) error { _, err := db.Query("DELETE FROM routing_domains WHERE routing_domain_id = ?", id) if err != nil { return err } return nil } func CreateRangeInDb(tx *sql.Tx, parent_id int64, routing_domain_id int, name string, cidr string) (int64, error) { if parent_id == -1 { res, err := tx.Exec("INSERT INTO subnets (routing_domain_id, name, cidr) VALUES (?,?,?);", routing_domain_id, name, cidr) if err != nil { return -1, err } subnet_id, err := res.LastInsertId() if err != nil { return -1, err } return subnet_id, nil } else { res, err := tx.Exec("INSERT INTO subnets (parent_id, routing_domain_id, name, cidr) VALUES (?,?,?,?);", parent_id, routing_domain_id, name, cidr) if err != nil { return -1, err } subnet_id, err := res.LastInsertId() if err != nil { return -1, err } return subnet_id, nil } } func createNewSubnetLease(prevCidr string, range_size int, subnetIndex int) (*net.IPNet, int, error) { _, network, err := net.ParseCIDR(prevCidr) if err != nil { return nil, -1, fmt.Errorf("unable to calculate subnet %v", err) } ones, size := network.Mask.Size() subnet, err := cidr.Subnet(network, int(range_size)-ones, subnetIndex) if err != nil { return nil, -1, fmt.Errorf("unable to calculate subnet %v", err) } subnet.Mask = net.CIDRMask(range_size, size) return subnet, range_size, nil } func verifyNoOverlap(parentCidr string, subnetRanges []Range, newSubnet *net.IPNet) error { _, parentNetwork, err := net.ParseCIDR(parentCidr) if err != nil { return fmt.Errorf("can't parse CIDR %v", err) } log.Printf("Checking Overlap\nparentCidr:\t%s", parentCidr) log.Printf("newSubnet:\t%s/%d", newSubnet.IP.String(), netMask(newSubnet.Mask)) for i := 0; i < len(subnetRanges); i++ { subnetRange := subnetRanges[i] netAddr, subnetCidr, err := net.ParseCIDR(subnetRange.Cidr) if err != nil { return fmt.Errorf("can't parse CIDR %v", err) } if parentNetwork.Contains(netAddr) { err = cidr.VerifyNoOverlap([]*net.IPNet{subnetCidr, newSubnet}, parentNetwork) if err != nil { return err } } } return nil } func netMask(mask net.IPMask) int { ones, _ := mask.Size() return ones } func GetDefaultRoutingDomainFromDB(tx *sql.Tx) (*RoutingDomain, error) { var routing_domain_id int var name string var vpcs sql.NullString err := tx.QueryRow("SELECT routing_domain_id, name, vpcs FROM routing_domains LIMIT 1 FOR UPDATE").Scan(&routing_domain_id, &name, &vpcs) if err != nil { return nil, err } return &RoutingDomain{ Id: routing_domain_id, Name: name, Vpcs: vpcs.String, }, nil } func GetRoutingDomainsFromDB() ([]RoutingDomain, error) { var domains []RoutingDomain rows, err := db.Query("SELECT routing_domain_id, name, vpcs FROM routing_domains") if err != nil { return nil, err } for rows.Next() { var routing_domain_id int var name string var vpcs sql.NullString err := rows.Scan(&routing_domain_id, &name, &vpcs) if err != nil { return nil, err } domains = append(domains, RoutingDomain{ Id: routing_domain_id, Name: name, Vpcs: vpcs.String, }) } return domains, nil } func GetRoutingDomainFromDB(id int64) (*RoutingDomain, error) { var routing_domain_id int var name string var vpcs sql.NullString err := db.QueryRow("SELECT routing_domain_id, name, vpcs FROM routing_domains WHERE routing_domain_id = ?", id).Scan(&routing_domain_id, &name, &vpcs) if err != nil { return nil, err } return &RoutingDomain{ Id: routing_domain_id, Name: name, Vpcs: vpcs.String, }, nil } func UpdateRoutingDomainOnDb(id int64, name JSONString, vpcs JSONStringArray) error { if name.Set && vpcs.Set { _, err := db.Query("UPDATE routing_domains SET name = ?, vpcs = ? WHERE routing_domain_id = ?", name.Value, strings.Join(vpcs.Value, ","), id) if err != nil { return err } } else if vpcs.Set { _, err := db.Query("UPDATE routing_domains SET vpcs = ? WHERE routing_domain_id = ?", strings.Join(vpcs.Value, ","), id) if err != nil { return err } } else if name.Set { _, err := db.Query("UPDATE routing_domains SET name = ? WHERE routing_domain_id = ?", name.Value, id) if err != nil { return err } } return nil } func CreateRoutingDomainOnDb(name string, vpcs []string) (int64, error) { res, err := db.Exec("INSERT INTO routing_domains (name, vpcs) VALUES (?,?);", name, strings.Join(vpcs, ",")) if err != nil { return -1, err } domain_id, err := res.LastInsertId() if err != nil { return -1, err } return domain_id, nil }