tooling/image-sync/internal/repository.go (401 lines of code) (raw):
// Copyright 2025 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
)
// Registry abstracts an image registry whose repositories can be queried for
// their tags.
type Registry interface {
	// GetTags returns the tag names found in the given repository.
	GetTags(ctx context.Context, repository string) ([]string, error)
}
// AuthedTransport is an http.RoundTripper that attaches a fixed Authorization
// header to every request before handing it to the wrapped transport.
type AuthedTransport struct {
	// Key is the complete Authorization header value, e.g. "Bearer <token>".
	Key string
	// Wrapped performs the actual round trip. http.DefaultTransport is used
	// when it is nil.
	Wrapped http.RoundTripper
}

// RoundTrip implements http.RoundTripper. It sends a copy of req whose
// Authorization header is set to t.Key and returns whatever the wrapped
// transport returns.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the header is set on a clone instead of on req itself.
func (t *AuthedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	authed := req.Clone(req.Context())
	if authed.Header == nil {
		// Clone keeps a nil header map nil; Set would panic on it.
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", t.Key)

	next := t.Wrapped
	if next == nil {
		next = http.DefaultTransport
	}
	return next.RoundTrip(authed)
}
// QuayRegistry implements Registry for repositories hosted on Quay. Tags are
// listed through the Quay HTTP API.
type QuayRegistry struct {
// httpclient sends the API requests. NewQuayRegistry configures it with a
// timeout and a transport that adds the Authorization header.
httpclient *http.Client
// baseUrl is the scheme and host prepended to every API path.
baseUrl string
// numberOftags is the maximum number of tags GetTags returns.
numberOftags int
}
// NewQuayRegistry builds a QuayRegistry client for https://quay.io.
//
// Every request carries "Bearer <bearerToken>" as its Authorization header.
// The request timeout is cfg.RequestTimeout seconds and GetTags is capped at
// cfg.NumberOfTags entries.
func NewQuayRegistry(cfg *SyncConfig, bearerToken string) *QuayRegistry {
	transport := &AuthedTransport{
		Key:     "Bearer " + bearerToken,
		Wrapped: http.DefaultTransport,
	}
	client := &http.Client{
		Timeout:   time.Duration(cfg.RequestTimeout) * time.Second,
		Transport: transport,
	}
	return &QuayRegistry{
		httpclient:   client,
		baseUrl:      "https://quay.io",
		numberOftags: cfg.NumberOfTags,
	}
}
// TagsResponse is one page of the Quay tag listing as decoded by getTagPage.
type TagsResponse struct {
// Tags holds the tag entries contained in this page.
Tags []Tags
// Page is filled from the response's "page" key (encoding/json matches
// untagged field names case-insensitively).
Page int
// HasAdditional is filled from "has_additional"; GetTags stops paging once
// it is false.
HasAdditional bool `json:"has_additional"`
}
// Tags is a single tag entry of a TagsResponse.
type Tags struct {
// Name is the tag name.
Name string
}
// getTagPage fetches a single page (limit=100) of the tag listing for image
// from the Quay API.
//
// It returns an error when the request cannot be built or sent, when the
// server answers with anything other than 200 OK, or when the body cannot be
// read or decoded as a TagsResponse.
func (q *QuayRegistry) getTagPage(ctx context.Context, image string, page int) (*TagsResponse, error) {
	path := fmt.Sprintf("%s/api/v1/repository/%s/tag/?limit=100&page=%d", q.baseUrl, image, page)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	Log().Debugw("Sending request", "path", path)
	resp, err := q.httpclient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	// Release the connection on every path, including the non-200 return.
	defer resp.Body.Close()

	Log().Debugw("Got response", "statuscode", resp.StatusCode)
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	var tagsResponse TagsResponse
	if err := json.Unmarshal(body, &tagsResponse); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return &tagsResponse, nil
}
// GetTags returns up to q.numberOftags tag names of image, in the order the
// Quay API lists them. The tag "latest" is skipped and does not count
// towards the limit.
func (q *QuayRegistry) GetTags(ctx context.Context, image string) ([]string, error) {
	Log().Debugw("Getting tags for image", "image", image)

	// Hard upper bound on the page counter so the loop cannot run forever.
	const pageLimit = 100

	var collected []string
	page := 1
	for page < pageLimit && len(collected) < q.numberOftags {
		current, err := q.getTagPage(ctx, image, page)
		if err != nil {
			return nil, fmt.Errorf("failed to get tags: %v", err)
		}

		for i := range current.Tags {
			name := current.Tags[i].Name
			if name != "latest" {
				collected = append(collected, name)
				// A page can hold far more entries than requested, so the
				// limit is checked after every appended tag.
				if len(collected) >= q.numberOftags {
					return collected, nil
				}
			}
		}

		if !current.HasAdditional {
			return collected, nil
		}
		page++
	}
	return collected, nil
}
// getAccessToken obtains an access token string from the given credential.
type getAccessToken func(context.Context, azcore.TokenCredential) (string, error)
// getACRUrl maps a registry name to the base URL used to reach that registry.
type getACRUrl func(string) string
// AzureContainerRegistry implements Registry for an Azure Container Registry
// and additionally exposes GetPullSecret and RepositoryExists.
type AzureContainerRegistry struct {
// acrName is the registry name; it is sent as the "service" form value of
// the OAuth exchange and passed to getACRUrlImpl to build the base URL.
acrName string
// credential is handed to getAccessTokenImpl to obtain an access token.
credential azcore.TokenCredential
// acrClient is the SDK client used to list repositories and tags.
acrClient *azcontainerregistry.Client
// httpClient sends the OAuth exchange request issued by GetPullSecret.
httpClient *http.Client
// numberOfTags is the tag count after which GetTags stops requesting
// further pages.
numberOfTags int
// tenantId is sent as the "tenant" form value of the OAuth exchange.
tenantId string
// getAccessTokenImpl and getACRUrlImpl are held as function fields
// (presumably so tests can substitute them; verify against the tests).
getAccessTokenImpl getAccessToken
getACRUrlImpl getACRUrl
}
// NewAzureContainerRegistry builds a client for the registry named by
// cfg.AcrTargetRegistry.
//
// A managed identity credential is used when cfg.ManagedIdentityClientID is
// set, otherwise the default Azure credential chain. Any failure to obtain
// the credential or to create the SDK client is reported through
// Log().Fatalf.
func NewAzureContainerRegistry(cfg *SyncConfig) *AzureContainerRegistry {
	var (
		cred azcore.TokenCredential
		err  error
	)
	if cfg.ManagedIdentityClientID == "" {
		cred, err = azidentity.NewDefaultAzureCredential(nil)
		if err != nil {
			Log().Fatalf("failed to obtain default credentials: %v", err)
		}
	} else {
		identityOptions := &azidentity.ManagedIdentityCredentialOptions{
			ID: azidentity.ClientID(cfg.ManagedIdentityClientID),
		}
		cred, err = azidentity.NewManagedIdentityCredential(identityOptions)
		if err != nil {
			Log().Fatalf("failed to obtain a credentials for managed identity %s: %v", cfg.ManagedIdentityClientID, err)
		}
	}

	endpoint := fmt.Sprintf("https://%s", cfg.AcrTargetRegistry)
	client, err := azcontainerregistry.NewClient(endpoint, cred, nil)
	if err != nil {
		Log().Fatalf("failed to create client: %v", err)
	}

	// Default token source: request a token for the management scope from
	// the supplied credential.
	fetchToken := func(ctx context.Context, dac azcore.TokenCredential) (string, error) {
		request := policy.TokenRequestOptions{
			Scopes: []string{"https://management.core.windows.net//.default"},
		}
		accessToken, err := dac.GetToken(ctx, request)
		if err != nil {
			return "", err
		}
		return accessToken.Token, nil
	}

	// Default URL builder: the registry is reached over HTTPS by its name.
	registryURL := func(acrName string) string {
		return fmt.Sprintf("https://%s", acrName)
	}

	return &AzureContainerRegistry{
		acrName:            cfg.AcrTargetRegistry,
		acrClient:          client,
		credential:         cred,
		httpClient:         &http.Client{Timeout: time.Duration(cfg.RequestTimeout) * time.Second},
		numberOfTags:       cfg.NumberOfTags,
		tenantId:           cfg.TenantId,
		getAccessTokenImpl: fetchToken,
		getACRUrlImpl:      registryURL,
	}
}
// AuthSecret is the decoded response of the registry's OAuth exchange
// endpoint as returned by GetPullSecret.
type AuthSecret struct {
// RefreshToken is filled from the "refresh_token" field of the response.
RefreshToken string `json:"refresh_token"`
}
// createOauthRequest builds the POST request that exchanges accessToken at
// the registry's /oauth2/exchange/ endpoint. The form body carries the grant
// type, the registry name as service, the tenant and the access token.
func (a *AzureContainerRegistry) createOauthRequest(ctx context.Context, accessToken string) (*http.Request, error) {
	endpoint := a.getACRUrlImpl(a.acrName) + "/oauth2/exchange/"
	payload := url.Values{
		"grant_type":   {"access_token"},
		"service":      {a.acrName},
		"tenant":       {a.tenantId},
		"access_token": {accessToken},
	}.Encode()

	Log().Debugw("Creating request", "path", endpoint)
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %v", err)
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	return req, nil
}
// GetPullSecret obtains an access token through getAccessTokenImpl and
// exchanges it at the registry's OAuth endpoint.
//
// It returns the decoded AuthSecret, or an error when the access token cannot
// be obtained, the request cannot be built or sent, the registry answers with
// a status other than 200 OK, or the body cannot be read or decoded.
func (a *AzureContainerRegistry) GetPullSecret(ctx context.Context) (*AuthSecret, error) {
	accessToken, err := a.getAccessTokenImpl(ctx, a.credential)
	if err != nil {
		return nil, fmt.Errorf("failed to get access token: %w", err)
	}

	req, err := a.createOauthRequest(ctx, accessToken)
	if err != nil {
		return nil, fmt.Errorf("failed to create OAuth request: %w", err)
	}

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	// Release the connection on every path, including the non-200 return.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	var authSecret AuthSecret
	if err := json.Unmarshal(body, &authSecret); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return &authSecret, nil
}
// RepositoryExists reports whether the registry lists a repository whose name
// equals repository. It only checks; nothing is created.
//
// Repository pages are walked until a match is found or the listing is
// exhausted. An error is returned when a page cannot be fetched.
func (a *AzureContainerRegistry) RepositoryExists(ctx context.Context, repository string) (bool, error) {
	pager := a.acrClient.NewListRepositoriesPager(nil)
	for pager.More() {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return false, fmt.Errorf("failed to advance page: %w", err)
		}
		for _, name := range page.Names {
			// Names holds pointers; skip nil entries instead of panicking.
			if name != nil && *name == repository {
				return true, nil
			}
		}
	}
	return false, nil
}
// ptr returns a pointer to a copy of v. It is handy for filling optional
// pointer-typed fields from literals and constants.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// GetTags returns the tags of repository ordered by last update, newest
// first, with the tag "latest" left out.
//
// Pages are requested until at least a.numberOfTags tags are collected or the
// listing is exhausted. Whole pages are appended, so the result may hold more
// than a.numberOfTags entries. An error is returned when a page cannot be
// fetched.
func (a *AzureContainerRegistry) GetTags(ctx context.Context, repository string) ([]string, error) {
	var tags []string

	options := &azcontainerregistry.ClientListTagsOptions{
		OrderBy: ptr(azcontainerregistry.ArtifactTagOrderByLastUpdatedOnDescending),
	}
	pager := a.acrClient.NewListTagsPager(repository, options)
	for pager.More() {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to advance page: %w", err)
		}
		for _, v := range page.Tags {
			// Entries and their Name are pointers; skip incomplete ones
			// instead of panicking on a nil dereference.
			if v == nil || v.Name == nil {
				continue
			}
			if name := *v.Name; name != "latest" {
				tags = append(tags, name)
			}
		}
		if len(tags) >= a.numberOfTags {
			break
		}
	}
	return tags, nil
}
// ACRWithTokenAuth implements Registry for an Azure Container Registry that
// is accessed with a pre-obtained bearer token.
type ACRWithTokenAuth struct {
// httpclient sends the tag listing requests.
httpclient *http.Client
// acrName is the registry host used to build the request URL.
acrName string
// numberOftags is sent as the "n" query parameter of the tag listing.
numberOftags int
// bearerToken is sent as "Bearer <token>" in the Authorization header.
bearerToken string
}
// AccessSecret is the decoded response of the registry's /oauth2/token
// endpoint as read by getACRBearerToken.
type AccessSecret struct {
// AccessToken is filled from the "access_token" field of the response.
AccessToken string `json:"access_token"`
}
// rawACRTagResponse is the part of the ACR tag listing response that
// ACRWithTokenAuth.GetTags decodes.
type rawACRTagResponse struct {
// Tags holds the listed tag entries.
Tags []rawACRTags
}
// rawACRTags is a single tag entry of a rawACRTagResponse.
type rawACRTags struct {
// Name is the tag name.
Name string
}
// getACRBearerToken requests an access token with the scope "repository:*:*"
// from the /oauth2/token endpoint of the registry acrName, authenticating
// with the basic-auth credentials provided by secret.
//
// It returns the access_token field of the response, or an error when the
// request cannot be built or sent, the registry answers with a status other
// than 200 OK, or the body cannot be read or decoded.
func getACRBearerToken(ctx context.Context, secret AzureSecretFile, acrName string) (string, error) {
	scope := "repository:*:*"
	path := fmt.Sprintf("https://%s/oauth2/token?service=%s&scope=%s", acrName, acrName, scope)

	Log().Debugw("Creating request", "path", path)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	// req is nil when construction fails, so the header is only added once
	// the error has been ruled out.
	req.Header.Add("Authorization", fmt.Sprintf("Basic %s", secret.BasicAuthEncoded()))

	// todo replace with timeout enabled client
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	// Release the connection on every path, including the non-200 return.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	var accessSecret AccessSecret
	if err := json.Unmarshal(body, &accessSecret); err != nil {
		return "", fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return accessSecret.AccessToken, nil
}
// NewACRWithTokenAuth builds a client for the registry acrName that
// authenticates with bearerToken. The request timeout is cfg.RequestTimeout
// seconds and tag listings ask for cfg.NumberOfTags entries.
func NewACRWithTokenAuth(cfg *SyncConfig, acrName string, bearerToken string) *ACRWithTokenAuth {
	timeout := time.Duration(cfg.RequestTimeout) * time.Second

	registry := &ACRWithTokenAuth{
		acrName:     acrName,
		bearerToken: bearerToken,
	}
	registry.httpclient = &http.Client{Timeout: timeout}
	registry.numberOftags = cfg.NumberOfTags
	return registry
}
// GetTags returns the tag names of image as listed by the registry's
// /acr/v1/<image>/_tags endpoint, ordered by last update (newest first) and
// limited to n.numberOftags entries through the "n" query parameter.
//
// The result is never nil. An error is returned when the request cannot be
// built or sent, the registry answers with a status other than 200 OK, or
// the body cannot be read or decoded.
func (n *ACRWithTokenAuth) GetTags(ctx context.Context, image string) ([]string, error) {
	Log().Debugw("Getting tags for image", "image", image)

	path := fmt.Sprintf("https://%s/acr/v1/%s/_tags?orderby=%s&n=%d", n.acrName, image, azcontainerregistry.ArtifactTagOrderByLastUpdatedOnDescending, n.numberOftags)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	// req is nil when construction fails, so the header is only added once
	// the error has been ruled out.
	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", n.bearerToken))

	Log().Debugw("Sending request", "path", path)
	resp, err := n.httpclient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	// Release the connection on every path, including the non-200 return.
	defer resp.Body.Close()

	Log().Debugw("Got response", "statuscode", resp.StatusCode)
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	var acrResponse rawACRTagResponse
	if err := json.Unmarshal(body, &acrResponse); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	tagList := make([]string, 0, len(acrResponse.Tags))
	for _, tag := range acrResponse.Tags {
		tagList = append(tagList, tag.Name)
	}
	return tagList, nil
}
// OCIRegistry implements Registry on top of the /v2/<image>/tags/list
// endpoint of a registry.
type OCIRegistry struct {
// httpclient sends the tag listing requests.
httpclient *http.Client
// baseURL is the scheme and host prepended to the request path.
baseURL string
// numberOftags is passed to getNewestTags as the limit on how many of the
// newest manifests contribute tags.
numberOftags int
// bearerToken, when non-empty, is sent as "Bearer <token>" in the
// Authorization header.
bearerToken string
}
// NewOCIRegistry builds a client for the registry reachable at baseURL.
//
// baseURL may be given with or without the "https://" scheme; the scheme is
// added when it is missing. bearerToken may be empty, in which case requests
// are sent without an Authorization header. The request timeout is
// cfg.RequestTimeout seconds and cfg.NumberOfTags limits how many of the
// newest manifests GetTags reports tags for.
func NewOCIRegistry(cfg *SyncConfig, baseURL, bearerToken string) *OCIRegistry {
	// Inspect the argument, not the struct field: the field is still empty at
	// this point, so checking it would prefix the scheme unconditionally and
	// turn "https://host" into "https://https://host".
	if !strings.HasPrefix(baseURL, "https://") {
		baseURL = "https://" + baseURL
	}

	return &OCIRegistry{
		httpclient:   &http.Client{Timeout: time.Duration(cfg.RequestTimeout) * time.Second},
		baseURL:      baseURL,
		numberOftags: cfg.NumberOfTags,
		bearerToken:  bearerToken,
	}
}
// rawManifest is the per-manifest metadata of a tags/list response that
// getNewestTags evaluates.
type rawManifest struct {
	// TimeUploadedMs is the upload time as a decimal integer string.
	TimeUploadedMs string
	// Tag lists the tags pointing at this manifest; it may be empty.
	Tag []string
}

// rawOCIResponse is the part of a tags/list response that is decoded by
// OCIRegistry.GetTags.
type rawOCIResponse struct {
	// Manifest maps a manifest key to that manifest's metadata.
	Manifest map[string]rawManifest
	// Tags is the flat tag list of the response.
	Tags []string
}

// getNewestTags returns the tags of the numberOfTags most recently uploaded
// manifests in response, newest manifest first. Manifests without tags are
// ignored and do not count towards the limit.
//
// The limit applies to manifests, not tags: a manifest carrying several tags
// contributes all of them. Manifests with the same upload time are each
// reported once, ordered by their map key so the result is deterministic.
// The result is nil when nothing qualifies. An error is returned when an
// upload time is not a valid integer.
func getNewestTags(response *rawOCIResponse, numberOfTags int) ([]string, error) {
	type upload struct {
		key        string
		uploadedAt int64
		tags       []string
	}

	uploads := make([]upload, 0, len(response.Manifest))
	for key, manifest := range response.Manifest {
		if len(manifest.Tag) == 0 {
			continue
		}
		// Parse as 64 bit explicitly: millisecond timestamps do not fit a
		// 32-bit int.
		uploadedAt, err := strconv.ParseInt(manifest.TimeUploadedMs, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse manifest %s time: %w", manifest, err)
		}
		uploads = append(uploads, upload{key: key, uploadedAt: uploadedAt, tags: manifest.Tag})
	}

	// Newest first; the key breaks ties because map iteration order is random.
	sort.Slice(uploads, func(i, j int) bool {
		if uploads[i].uploadedAt != uploads[j].uploadedAt {
			return uploads[i].uploadedAt > uploads[j].uploadedAt
		}
		return uploads[i].key < uploads[j].key
	})

	var newest []string
	for i, u := range uploads {
		if i >= numberOfTags {
			break
		}
		newest = append(newest, u.tags...)
	}
	return newest, nil
}
// GetTags returns the tags of the o.numberOftags most recently uploaded
// manifests of image, as reported by the registry's /v2/<image>/tags/list
// endpoint and selected by getNewestTags.
//
// The Authorization header is only sent when a bearer token is configured.
// An error is returned when the request cannot be built or sent, the registry
// answers with a status other than 200 OK, or the body cannot be read,
// decoded or evaluated.
func (o *OCIRegistry) GetTags(ctx context.Context, image string) ([]string, error) {
	Log().Debugw("Getting tags for image", "image", image)

	path := fmt.Sprintf("%s/v2/%s/tags/list", o.baseURL, image)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	// req is nil when construction fails, so the header is only added once
	// the error has been ruled out.
	if o.bearerToken != "" {
		req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", o.bearerToken))
	}

	Log().Debugw("Sending request", "path", path)
	resp, err := o.httpclient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	// Release the connection on every path, including the non-200 return.
	defer resp.Body.Close()

	Log().Debugw("Got response", "statuscode", resp.StatusCode)
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	// Named so the local does not shadow the rawOCIResponse type.
	var decoded rawOCIResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return getNewestTags(&decoded, o.numberOftags)
}