coverage/expand.go (289 lines of code) (raw):

package coverage

import (
	"fmt"
	"path/filepath"
	"strings"

	"github.com/go-openapi/loads"
	openapiSpec "github.com/go-openapi/spec"
	lru "github.com/hashicorp/golang-lru/v2"
	"github.com/sirupsen/logrus"
)

// http://azure.github.io/autorest/extensions/#x-ms-discriminator-value
const msExtensionDiscriminator = "x-ms-discriminator-value"
const msExtensionSecret = "x-ms-secret"

var (
	// {swaggerPath: doc Object}
	// The error from lru.New is discarded; the size argument is a positive constant.
	swaggerCache, _ = lru.New[string, *loads.Document](30)

	// {swaggerPath: {parentModelName: {childModelName: nil}}}
	// The error from lru.New is discarded; the size argument is a positive constant.
	allOfTableCache, _ = lru.New[string, map[string]map[string]interface{}](10)
)

// loadSwagger returns the parsed swagger document for swaggerPath, reading it
// with loads.JSONSpec on a cache miss and storing the result in swaggerCache.
func loadSwagger(swaggerPath string) (*loads.Document, error) {
	if doc, ok := swaggerCache.Get(swaggerPath); ok {
		return doc, nil
	}

	doc, err := loads.JSONSpec(swaggerPath)
	if err != nil {
		return nil, err
	}

	swaggerCache.Add(swaggerPath, doc)
	return doc, nil
}

// getAllOfTable returns, for the swagger file at swaggerPath, a table mapping
// each parent model name to the set of definitions in the same file that list
// that parent in their allOf via a $ref. Parents that resolve to a different
// file are skipped. The result is cached per swaggerPath in allOfTableCache.
func getAllOfTable(swaggerPath string) (map[string]map[string]interface{}, error) {
	if vt, ok := allOfTableCache.Get(swaggerPath); ok {
		return vt, nil
	}

	doc, err := loadSwagger(swaggerPath)
	if err != nil {
		return nil, err
	}

	spec := doc.Spec()
	allOfTable := map[string]map[string]interface{}{}
	for k, v := range spec.Definitions {
		for _, allOf := range v.AllOf {
			// Inline (non-$ref) allOf entries cannot name a parent model.
			if allOf.Ref.String() == "" {
				continue
			}
			modelName, absPath := SchemaNamePathFromRef(swaggerPath, allOf.Ref)
			if absPath != swaggerPath {
				continue
			}
			if _, ok := allOfTable[modelName]; !ok {
				allOfTable[modelName] = map[string]interface{}{}
			}
			allOfTable[modelName][k] = nil
		}
	}

	allOfTableCache.Add(swaggerPath, allOfTable)
	return allOfTable, nil
}

// trimPath collapses "current directory" segments in path: `\.\` becomes `\`
// when the path contains a backslash, otherwise "/./" becomes "/".
func trimPath(path string) string {
	if strings.Contains(path, "\\") {
		return strings.ReplaceAll(path, "\\.\\", "\\")
	}
	return strings.ReplaceAll(path, "/./", "/")
}

// Expand loads the swagger file at swaggerPath and expands the definition
// named modelName into a Model tree whose root has IsRoot set.
//
// It returns an error when modelName is empty, when the swagger file cannot be
// loaded, or when the definition does not exist in that file. Failures while
// resolving nested references are not returned as errors: expandSchema reports
// them through logrus.Panicf.
func Expand(modelName, swaggerPath string) (*Model, error) {
	if modelName == "" {
		return nil, fmt.Errorf("modelName is empty")
	}

	swaggerPath = trimPath(swaggerPath)

	doc, err := loadSwagger(swaggerPath)
	if err != nil {
		return nil, err
	}

	spec := doc.Spec()
	modelSchema, ok := spec.Definitions[modelName]
	if !ok {
		return nil, fmt.Errorf("%s not found in the definition of %s", modelName, swaggerPath)
	}

	output := expandSchema(modelSchema, swaggerPath, modelName, "#", spec, map[string]interface{}{}, map[string]interface{}{})
	output.IsRoot = true

	return output, nil
}

// expandSchema recursively converts input into a Model.
//
// swaggerPath is the file input was read from, modelName the name recorded on
// the Model, and identifier the path-like id recorded on the Model (children
// get ".<property>", "[]" or "{<variant>}" appended). root is handed to
// openapiSpec.ResolveRefWithBase for $ref resolution and is type-asserted to
// *openapiSpec.Swagger when variants are expanded.
//
// resolvedModel holds the model names currently being expanded up the call
// stack; hitting one of them again returns a bare Model so that recursive
// schemas terminate. resolvedDiscriminator plays the same role for variant
// expansion.
//
// A $ref that cannot be resolved, or a referenced file that cannot be loaded,
// is reported with logrus.Panicf.
func expandSchema(input openapiSpec.Schema, swaggerPath, modelName, identifier string, root interface{}, resolvedDiscriminator map[string]interface{}, resolvedModel map[string]interface{}) *Model {
	output := Model{
		Identifier: identifier,
		ModelName:  modelName,
		SourceFile: swaggerPath,
	}

	// Recursion guard: this model is already being expanded further up the stack.
	if _, ok := resolvedModel[modelName]; ok {
		return &output
	}
	resolvedModel[modelName] = nil

	if len(input.Type) > 0 {
		output.Type = &input.Type[0]
		if *output.Type == "boolean" {
			boolMap := make(map[string]bool)
			boolMap["true"] = false
			boolMap["false"] = false
			output.Bool = &boolMap
		}
	}

	if input.AdditionalProperties != nil {
		output.HasAdditionalProperties = true
	}

	if input.Format != "" {
		output.Format = &input.Format
	}

	if input.ReadOnly {
		output.IsReadOnly = input.ReadOnly
	}

	if isSecretRaw, ok := input.Extensions[msExtensionSecret]; ok && isSecretRaw != nil {
		if isSecret, ok := isSecretRaw.(bool); ok {
			output.IsSecret = isSecret
		}
	}

	if input.Enum != nil {
		enumMap := make(map[string]bool)
		for _, v := range input.Enum {
			switch t := v.(type) {
			case string:
				enumMap[t] = false
			case float64:
				enumMap[fmt.Sprintf("%v", t)] = false
			case int:
				enumMap[fmt.Sprintf("%v", t)] = false
			default:
				logrus.Errorf("unknown enum type %T", t)
				enumMap[fmt.Sprintf("%v", t)] = false
			}
		}
		output.Enum = &enumMap
	}

	properties := make(map[string]*Model)

	// expand ref
	if input.Ref.String() != "" {
		resolved, err := openapiSpec.ResolveRefWithBase(root, &input.Ref, &openapiSpec.ExpandOptions{RelativeBase: swaggerPath})
		if err != nil {
			logrus.Panicf("resolve ref %s from %s: %+v", input.Ref.String(), swaggerPath, err)
		}

		// refModelName is only used for the referenced schema; the caller's
		// modelName keeps naming the current schema in the rest of this function.
		refModelName, refSwaggerPath := SchemaNamePathFromRef(swaggerPath, input.Ref)

		// A reference into another file must be expanded against that file's spec.
		refRoot := root
		if refSwaggerPath != swaggerPath {
			doc, err := loadSwagger(refSwaggerPath)
			if err != nil {
				logrus.Panicf("load swagger %s: %+v", refSwaggerPath, err)
			}
			refRoot = doc.Spec()
		}

		referenceModel := expandSchema(*resolved, refSwaggerPath, refModelName, identifier, refRoot, resolvedDiscriminator, resolvedModel)
		if referenceModel.Properties != nil {
			for k, v := range *referenceModel.Properties {
				properties[k] = v
			}
		}

		// Copy over everything the referenced model defines.
		output.ModelName = referenceModel.ModelName
		if referenceModel.Enum != nil {
			output.Enum = referenceModel.Enum
		}
		if referenceModel.Type != nil {
			output.Type = referenceModel.Type
		}
		if referenceModel.Format != nil {
			output.Format = referenceModel.Format
		}
		if referenceModel.HasAdditionalProperties {
			output.HasAdditionalProperties = true
		}
		if referenceModel.Bool != nil {
			output.Bool = referenceModel.Bool
		}
		if referenceModel.IsReadOnly {
			output.IsReadOnly = referenceModel.IsReadOnly
		}
		if referenceModel.IsRequired {
			output.IsRequired = referenceModel.IsRequired
		}
		if referenceModel.Discriminator != nil {
			output.Discriminator = referenceModel.Discriminator
		}
		if referenceModel.Variants != nil {
			output.Variants = referenceModel.Variants
		}
		if referenceModel.Item != nil {
			output.Item = referenceModel.Item
		}
	}

	// expand properties
	for k, v := range input.Properties {
		properties[k] = expandSchema(v, swaggerPath, fmt.Sprintf("%s.%s", modelName, k), fmt.Sprintf("%s.%s", identifier, k), root, resolvedDiscriminator, resolvedModel)
	}

	// expand composition
	for _, v := range input.AllOf {
		allOf := expandSchema(v, swaggerPath, fmt.Sprintf("%s.allOf", modelName), identifier, root, resolvedDiscriminator, resolvedModel)
		if allOf.Properties != nil {
			for k, v := range *allOf.Properties {
				properties[k] = v
			}
		}

		// the model should be a variant if its allOf contains a discriminator
		if allOf.Discriminator != nil {
			output.Discriminator = allOf.Discriminator
			variantName := modelName
			if variantNameRaw, ok := input.Extensions[msExtensionDiscriminator]; ok && variantNameRaw != nil {
				// Fall back to the model name instead of panicking on a non-string value.
				if s, isString := variantNameRaw.(string); isString {
					variantName = s
				} else {
					logrus.Warnf("%s of %s is not a string: %v", msExtensionDiscriminator, modelName, variantNameRaw)
				}
			}
			output.VariantType = &variantName
		}
	}

	if len(properties) > 0 {
		for _, v := range input.Required {
			if p, ok := properties[v]; ok {
				p.IsRequired = true
			} else {
				logrus.Warnf("required property %s not found in %s", v, modelName)
			}
		}

		// check if all properties are readonly
		allReadOnly := true
		for _, v := range properties {
			if !v.IsReadOnly {
				allReadOnly = false
				break
			}
		}
		if allReadOnly {
			output.IsReadOnly = true
		}

		output.Properties = &properties
	}

	// expand items
	if input.Items != nil {
		// Items.Schema is nil when items is declared as an array of schemas;
		// skip expansion in that case rather than dereferencing nil.
		if input.Items.Schema != nil {
			output.Item = expandSchema(*input.Items.Schema, swaggerPath, fmt.Sprintf("%s[]", modelName), fmt.Sprintf("%s[]", identifier), root, resolvedDiscriminator, resolvedModel)
		} else {
			logrus.Warnf("items of %s is not a single schema, skipping", modelName)
		}
	}

	// Release the recursion guard before variants are expanded below.
	delete(resolvedModel, modelName)

	// expand variants
	if input.Discriminator != "" || output.Discriminator != nil {
		if _, hasResolvedDiscriminator := resolvedDiscriminator[modelName]; !hasResolvedDiscriminator {
			allOfTable, err := getAllOfTable(swaggerPath)
			if err != nil {
				logrus.Panicf("get variant table %s: %+v", swaggerPath, err)
			}

			varSet, ok := allOfTable[modelName]
			if ok {
				resolvedDiscriminator[modelName] = nil

				variants := map[string]*Model{}

				// level order traverse to find all variants
				// NOTE(review): a cyclic allOf chain inside one file would keep
				// this loop running — confirm specs cannot contain such cycles.
				for len(varSet) > 0 {
					tempVarSet := make(map[string]interface{})
					for variantModelName := range varSet {
						schema := root.(*openapiSpec.Swagger).Definitions[variantModelName]
						variantName := variantModelName
						if variantNameRaw, ok := schema.Extensions[msExtensionDiscriminator]; ok && variantNameRaw != nil {
							// Fall back to the model name instead of panicking on a non-string value.
							if s, isString := variantNameRaw.(string); isString {
								variantName = s
							} else {
								logrus.Warnf("%s of %s is not a string: %v", msExtensionDiscriminator, variantModelName, variantNameRaw)
							}
						}

						resolved := expandSchema(schema, swaggerPath, variantModelName, fmt.Sprintf("%s{%s}", identifier, variantName), root, resolvedDiscriminator, resolvedModel)
						resolved.VariantType = &variantName

						// in case of https://github.com/Azure/azure-rest-api-specs/issues/25104, use modelName as key
						variants[variantModelName] = resolved

						// Queue the children of this variant for the next level.
						if varVarSet, ok := allOfTable[variantModelName]; ok {
							for v := range varVarSet {
								tempVarSet[v] = nil
							}
						}
					}
					varSet = tempVarSet
				}

				delete(resolvedDiscriminator, modelName)

				if input.Discriminator != "" {
					output.Discriminator = &input.Discriminator
				}
				output.Variants = &variants
			}
		}
	}

	return &output
}

// SchemaNamePathFromRef splits ref into the schema name it points at and the
// path of the file that contains it.
//
// schemaName is the last "/"-separated segment of the ref's fragment.
// schemaPath is swaggerPath itself when the ref URL has no path; otherwise it
// is the ref's path appended to the directory of swaggerPath and passed through
// trimPath. Both results are empty when ref has no URL.
func SchemaNamePathFromRef(swaggerPath string, ref openapiSpec.Ref) (schemaName string, schemaPath string) {
	refURL := ref.GetURL()
	if refURL == nil {
		return "", ""
	}

	schemaPath = refURL.Path
	if schemaPath == "" {
		schemaPath = swaggerPath
	} else {
		// filepath.Split keeps the trailing separator on dir, so plain
		// concatenation yields a path relative to the referencing file.
		dir, _ := filepath.Split(swaggerPath)
		schemaPath = trimPath(dir + schemaPath)
	}

	fragments := strings.Split(refURL.Fragment, "/")
	return fragments[len(fragments)-1], schemaPath
}