api/internal/filter/schema.go (194 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package filter import ( "bytes" "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "errors" "fmt" "io/ioutil" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/apisix/manager-api/internal/core/entity" "github.com/apisix/manager-api/internal/core/store" "github.com/apisix/manager-api/internal/log" "github.com/apisix/manager-api/internal/utils/consts" ) var resources = map[string]string{ "routes": "route", "upstreams": "upstream", "services": "service", "consumers": "consumer", "ssl": "ssl", "global_rules": "global_rule", "proto": "proto", } const ( StatusDisable entity.Status = iota StatusEnable ) func parseCert(crt, key string) ([]string, error) { if crt == "" || key == "" { return nil, errors.New("empty certificate or private key") } certDERBlock, _ := pem.Decode([]byte(crt)) if certDERBlock == nil { return nil, errors.New("Certificate resolution failed") } // match _, err := tls.X509KeyPair([]byte(crt), []byte(key)) if err != nil { return nil, errors.New("key and cert don't match") } x509Cert, err := x509.ParseCertificate(certDERBlock.Bytes) if err != nil { return nil, errors.New("Certificate resolution failed") } //domain var snis []string if x509Cert.DNSNames != nil && len(x509Cert.DNSNames) > 0 { snis = x509Cert.DNSNames } else if x509Cert.IPAddresses != nil && len(x509Cert.IPAddresses) > 0 { for _, ip := range x509Cert.IPAddresses { snis = append(snis, ip.String()) } } else { if x509Cert.Subject.Names != nil && len(x509Cert.Subject.Names) > 0 { var attributeTypeNames = map[string]string{ "2.5.4.6": "C", "2.5.4.10": "O", "2.5.4.11": "OU", "2.5.4.3": "CN", "2.5.4.5": "SERIALNUMBER", "2.5.4.7": "L", "2.5.4.8": "ST", "2.5.4.9": "STREET", "2.5.4.17": "POSTALCODE", } for _, tv := range x509Cert.Subject.Names { oidString := tv.Type.String() typeName, ok := attributeTypeNames[oidString] if ok && typeName == "CN" { valueString := fmt.Sprint(tv.Value) snis = append(snis, valueString) } } } } return snis, nil } func handleSpecialField(resource string, reqBody []byte) ([]byte, error) { var bodyMap map[string]interface{} err := json.Unmarshal(reqBody, &bodyMap) if err != nil { return reqBody, fmt.Errorf("read request body failed: %s", err) } if _, ok := bodyMap["create_time"]; ok { return reqBody, errors.New("we don't accept create_time from client") } if _, ok := bodyMap["update_time"]; ok { return reqBody, errors.New("we don't accept update_time from client") } // remove script, because it's a map, and need to be parsed into lua code if resource == "routes" { var route map[string]interface{} err := json.Unmarshal(reqBody, &route) if err != nil { return nil, fmt.Errorf("read request body failed: %s", err) } if _, ok := route["script"]; ok { delete(route, "script") reqBody, err = json.Marshal(route) if err != nil { return nil, fmt.Errorf("read request body failed: %s", err) } } } // SSL does not need to pass sni, we need to parse the SSL to get sni if resource == "ssl" { var ssl map[string]interface{} err := json.Unmarshal(reqBody, &ssl) if err != nil { return nil, fmt.Errorf("read request body failed: %s", err) } ssl["snis"], err = parseCert(ssl["cert"].(string), ssl["key"].(string)) if err != nil { return nil, fmt.Errorf("SSL parse failed: %s", err) } reqBody, err = json.Marshal(ssl) if err != nil { return nil, fmt.Errorf("read request body failed: %s", err) } } return reqBody, nil } func handleDefaultValue(resource string, reqBody []byte) ([]byte, error) { // go jsonschema lib doesn't support setting default values, so we need to set for some fields necessary if resource == "routes" { var route map[string]interface{} err := json.Unmarshal(reqBody, &route) if err != nil { return reqBody, fmt.Errorf("read request body failed: %s", err) } if _, ok := route["status"]; !ok { route["status"] = StatusEnable reqBody, err = json.Marshal(route) if err != nil { return nil, fmt.Errorf("read request body failed: %s", err) } } } return reqBody, nil } func SchemaCheck() gin.HandlerFunc { return func(c *gin.Context) { pathPrefix := "/apisix/admin/" resource := strings.TrimPrefix(c.Request.URL.Path, pathPrefix) idx := strings.LastIndex(resource, "/") if idx > 1 { resource = resource[:idx] } method := strings.ToUpper(c.Request.Method) if method != "PUT" && method != "POST" { c.Next() return } schemaKey, ok := resources[resource] if !ok { c.Next() return } reqBody, err := c.GetRawData() if err != nil { log.Errorf("read request body failed: %s", err) c.AbortWithStatusJSON(http.StatusBadRequest, consts.ErrInvalidRequest) return } // set default value reqBody, err = handleDefaultValue(resource, reqBody) if err != nil { errMsg := err.Error() c.AbortWithStatusJSON(http.StatusBadRequest, consts.InvalidParam(errMsg)) log.Error(errMsg) return } // other filter need it c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) validator, err := store.NewAPISIXSchemaValidator("main." + schemaKey) if err != nil { errMsg := err.Error() c.AbortWithStatusJSON(http.StatusBadRequest, consts.InvalidParam(errMsg)) log.Error(errMsg) return } reqBody, err = handleSpecialField(resource, reqBody) if err != nil { errMsg := err.Error() c.AbortWithStatusJSON(http.StatusBadRequest, consts.InvalidParam(errMsg)) log.Error(errMsg) return } if err := validator.Validate(reqBody); err != nil { errMsg := err.Error() c.AbortWithStatusJSON(http.StatusBadRequest, consts.InvalidParam(errMsg)) log.Warn(errMsg) return } c.Next() } }