pkg/ec2pricing/odpricing.go (261 lines of code) (raw):

// Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ec2pricing import ( "context" "encoding/json" "errors" "fmt" "io" "log" "os" "path/filepath" "strconv" "sync" "time" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/pricing" pricingtypes "github.com/aws/aws-sdk-go-v2/service/pricing/types" "github.com/mitchellh/go-homedir" "github.com/patrickmn/go-cache" "go.uber.org/multierr" ) const ( ODCacheFileName = "on-demand-pricing-cache.json" ) type OnDemandPricing struct { Region string FullRefreshTTL time.Duration DirectoryPath string cache *cache.Cache pricingClient pricing.GetProductsAPIClient logger *log.Logger sync.RWMutex } type PricingList struct { Product PricingListProduct `json:"product"` ServiceCode string `json:"serviceCode"` Terms ProductTerms `json:"terms"` Version string `json:"version"` PublicationDate string `json:"publicationDate"` } type PricingListProduct struct { ProductFamily string `json:"productFamily"` ProductAttributes map[string]string `json:"attributes"` SKU string `json:"sku"` } type ProductTerms struct { OnDemand map[string]ProductPricingInfo `json:"OnDemand"` Reserved map[string]ProductPricingInfo `json:"Reserved"` } type ProductPricingInfo struct { PriceDimensions map[string]PriceDimensionInfo `json:"priceDimensions"` SKU string `json:"sku"` EffectiveDate string `json:"effectiveDate"` OfferTermCode string `json:"offerTermCode"` TermAttributes map[string]string `json:"termAttributes"` } type PriceDimensionInfo struct { Unit string `json:"unit"` EndRange string `json:"endRange"` Description string `json:"description"` AppliesTo []string `json:"appliesTo"` RateCode string `json:"rateCode"` BeginRange string `json:"beginRange"` PricePerUnit map[string]string `json:"pricePerUnit"` } func LoadODCacheOrNew(ctx context.Context, pricingClient pricing.GetProductsAPIClient, region string, fullRefreshTTL time.Duration, directoryPath string) (*OnDemandPricing, error) { expandedDirPath, err := homedir.Expand(directoryPath) if err != nil { return nil, fmt.Errorf("unable to load on-demand pricing cache directory %s: %w", expandedDirPath, err) } odPricing := &OnDemandPricing{ Region: region, FullRefreshTTL: fullRefreshTTL, DirectoryPath: expandedDirPath, pricingClient: pricingClient, cache: cache.New(fullRefreshTTL, fullRefreshTTL), logger: log.New(io.Discard, "", 0), } if fullRefreshTTL <= 0 { if err := odPricing.Clear(); err != nil { return nil, fmt.Errorf("unable to clear od pricing cache due to ttl <= 0 %w", err) } return odPricing, nil } // Start the cache refresh job go odPricing.odCacheRefreshJob(ctx) odCache, err := loadODCacheFrom(fullRefreshTTL, region, expandedDirPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("an on-demand pricing cache file could not be loaded: %v", err) } if err != nil { odCache = cache.New(0, 0) } odPricing.cache = odCache return odPricing, nil } func loadODCacheFrom(itemTTL time.Duration, region string, expandedDirPath string) (*cache.Cache, error) { cacheBytes, err := os.ReadFile(getODCacheFilePath(region, expandedDirPath)) if err != nil { return nil, err } odCache := &map[string]cache.Item{} if err := json.Unmarshal(cacheBytes, odCache); err != nil { return nil, err } c := cache.NewFrom(itemTTL, itemTTL, *odCache) c.DeleteExpired() return c, nil } func getODCacheFilePath(region string, directoryPath string) string { return filepath.Join(directoryPath, fmt.Sprintf("%s-%s", region, ODCacheFileName)) } func (c *OnDemandPricing) odCacheRefreshJob(ctx context.Context) { if c.FullRefreshTTL <= 0 { return } refreshTicker := time.NewTicker(c.FullRefreshTTL) for range refreshTicker.C { if err := c.Refresh(ctx); err != nil { c.logger.Printf("Periodic OD Cache Refresh Error: %v", err) } } } func (c *OnDemandPricing) SetLogger(logger *log.Logger) { c.logger = logger } func (c *OnDemandPricing) Refresh(ctx context.Context) error { c.Lock() defer c.Unlock() odInstanceTypeCosts, err := c.fetchOnDemandPricing(ctx, "") if err != nil { return fmt.Errorf("there was a problem refreshing the on-demand instance type pricing cache: %v", err) } for instanceType, cost := range odInstanceTypeCosts { c.cache.SetDefault(instanceType, cost) } if err := c.Save(); err != nil { return fmt.Errorf("unable to save the refreshed on-demand instance type pricing cache file: %v", err) } return nil } func (c *OnDemandPricing) Get(ctx context.Context, instanceType ec2types.InstanceType) (float64, error) { if cost, ok := c.cache.Get(string(instanceType)); ok { return cost.(float64), nil } c.RLock() defer c.RUnlock() costs, err := c.fetchOnDemandPricing(ctx, instanceType) if err != nil { return 0, fmt.Errorf("there was a problem fetching on-demand instance type pricing for %s: %v", instanceType, err) } c.cache.SetDefault(string(instanceType), costs[string(instanceType)]) return costs[string(instanceType)], nil } // Count of items in the cache. func (c *OnDemandPricing) Count() int { return c.cache.ItemCount() } func (c *OnDemandPricing) Save() error { if c.FullRefreshTTL == 0 || c.Count() == 0 { return nil } cacheBytes, err := json.Marshal(c.cache.Items()) if err != nil { return err } if err := os.Mkdir(c.DirectoryPath, 0o755); err != nil && !errors.Is(err, os.ErrExist) { return err } return os.WriteFile(getODCacheFilePath(c.Region, c.DirectoryPath), cacheBytes, 0600) } func (c *OnDemandPricing) Clear() error { c.Lock() defer c.Unlock() c.cache.Flush() if err := os.Remove(getODCacheFilePath(c.Region, c.DirectoryPath)); err != nil && !os.IsNotExist(err) { return err } return nil } // fetchOnDemandPricing makes a bulk request to the pricing api to retrieve all instance type pricing if the instanceType is the empty string // // or, if instanceType is specified, it can request a specific instance type pricing func (c *OnDemandPricing) fetchOnDemandPricing(ctx context.Context, instanceType ec2types.InstanceType) (map[string]float64, error) { start := time.Now() calls := 0 defer func() { c.logger.Printf("Took %s and %d calls to collect OD pricing", time.Since(start), calls) }() odPricing := map[string]float64{} productInput := pricing.GetProductsInput{ ServiceCode: c.StringMe(serviceCode), Filters: c.getProductsInputFilters(instanceType), } var processingErr error p := pricing.NewGetProductsPaginator(c.pricingClient, &productInput) for p.HasMorePages() { calls++ pricingOutput, err := p.NextPage(ctx) if err != nil { return nil, fmt.Errorf("failed to get next OD pricing page, %w", err) } for _, priceDoc := range pricingOutput.PriceList { instanceTypeName, price, errParse := c.parseOndemandUnitPrice(priceDoc) if errParse != nil { processingErr = multierr.Append(processingErr, errParse) continue } odPricing[instanceTypeName] = price } } return odPricing, processingErr } // StringMe takes an interface and returns a pointer to a string value // If the underlying interface kind is not string or *string then nil is returned. func (c *OnDemandPricing) StringMe(i interface{}) *string { if i == nil { return nil } switch v := i.(type) { case *string: return v case string: return &v default: c.logger.Printf("%s cannot be converted to a string", i) return nil } } func (c *OnDemandPricing) getProductsInputFilters(instanceType ec2types.InstanceType) []pricingtypes.Filter { filters := []pricingtypes.Filter{ {Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("ServiceCode"), Value: c.StringMe(serviceCode)}, {Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("operatingSystem"), Value: c.StringMe("linux")}, {Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("regionCode"), Value: c.StringMe(c.Region)}, {Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("capacitystatus"), Value: c.StringMe("used")}, {Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("preInstalledSw"), Value: c.StringMe("NA")}, {Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("tenancy"), Value: c.StringMe("shared")}, } if instanceType != "" { filters = append(filters, pricingtypes.Filter{Type: pricingtypes.FilterTypeTermMatch, Field: c.StringMe("instanceType"), Value: c.StringMe(string(instanceType))}) } return filters } // parseOndemandUnitPrice takes a priceList from the pricing API and parses its weirdness. func (c *OnDemandPricing) parseOndemandUnitPrice(priceList string) (string, float64, error) { var productPriceList PricingList err := json.Unmarshal([]byte(priceList), &productPriceList) if err != nil { return "", float64(-1.0), fmt.Errorf("unable to parse pricing doc: %w", err) } attributes := productPriceList.Product.ProductAttributes instanceTypeName := attributes["instanceType"] for _, priceDimensions := range productPriceList.Terms.OnDemand { dim := priceDimensions.PriceDimensions for _, dimension := range dim { pricePerUnit := dimension.PricePerUnit pricePerUnitInUSDStr, ok := pricePerUnit["USD"] if !ok { return instanceTypeName, float64(-1.0), fmt.Errorf("unable to find on-demand price per unit in USD") } var err error pricePerUnitInUSD, err := strconv.ParseFloat(pricePerUnitInUSDStr, 64) if err != nil { return instanceTypeName, float64(-1.0), fmt.Errorf("could not convert price per unit in USD to a float64") } return instanceTypeName, pricePerUnitInUSD, nil } } return instanceTypeName, float64(-1.0), fmt.Errorf("unable to parse pricing doc") }