pkg/api/api.go (264 lines of code) (raw):
package api
import (
"bufio"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"sync"
"time"
"github.com/mileusna/useragent"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
// BulkResponse is the body of an Elasticsearch _bulk reply in the reduced form a
// client gets when it asks for filter_path=errors,items.*.error,items.*.status:
// a top-level failure flag plus one small object per reported item.
type BulkResponse struct {
	// Errors is true when at least one reported item was given an error status.
	Errors bool `json:"errors"`
	// Items holds one single-key object per reported item; it is left out of the
	// JSON entirely when empty.
	Items []map[string]any `json:"items,omitempty"`
}
// APIHandler is a mock Elasticsearch HTTP handler. Build it with NewAPIHandler,
// which creates the metric instruments, sizes the request history and fills in
// the odds tables; the zero value has no metrics and is not ready for use.
type APIHandler struct {
	// ActionOdds and MethodOdds are 100-slot tables of HTTP status codes filled
	// in by UpdateOdds. Bulk draws one ActionOdds entry per "create" action and
	// one MethodOdds entry per request, so each code's share of the slots is its
	// percentage chance.
	ActionOdds, MethodOdds [100]int

	UUID        fmt.Stringer  // license uid reported by License
	ClusterUUID string        // cluster_uuid reported by Root
	Expire      time.Time     // license expiry reported by License
	Delay       time.Duration // sleep added before every request in ServeHTTP

	metrics *metrics

	// history is a fixed-size ring buffer of recorded requests and historyIndex
	// is the next slot to overwrite; both are guarded by historyMu.
	history      []*RequestRecord
	historyIndex int
	historyMu    sync.Mutex

	// configMu guards ActionOdds and MethodOdds: UpdateOdds holds it exclusively
	// while rewriting them and Bulk holds a read lock while drawing from them.
	configMu sync.RWMutex
}
// RequestRecord captures one request received by the handler, as stored in the
// request history and served back by the /_history endpoint.
type RequestRecord struct {
	// Method is the HTTP method, e.g. "GET" or "POST".
	Method string `json:"method"`
	// URI is the request URI: path plus any query string.
	URI string `json:"uri"`
	// Body is the request body as recorded by the handler (empty when not kept).
	Body string `json:"body"`
}
// NewAPIHandler returns an APIHandler that is ready to serve requests.
//
// uuid and expire are reported by the /_license endpoint, clusterUUID by "/",
// and delay is slept before every request. When meterProvider is nil the
// global OpenTelemetry meter provider is used. The four percent arguments are
// passed to UpdateOdds, and historyCap sets how many requests the history ring
// buffer keeps (0 disables recording).
//
// It panics if the metric instruments cannot be created or if UpdateOdds
// rejects the percentages.
func NewAPIHandler(
	uuid fmt.Stringer,
	clusterUUID string,
	meterProvider metric.MeterProvider,
	expire time.Time,
	delay time.Duration,
	percentDuplicate,
	percentTooMany,
	percentNonIndex,
	percentTooLarge,
	historyCap uint,
) *APIHandler {
	h := &APIHandler{UUID: uuid, Expire: expire, ClusterUUID: clusterUUID, Delay: delay}
	if meterProvider == nil {
		meterProvider = otel.GetMeterProvider()
	}
	m, err := newMetrics(meterProvider)
	if err != nil {
		// Keep the cause; previously the underlying error was discarded.
		panic(fmt.Errorf("failed to create metrics: %w", err))
	}
	h.metrics = m
	h.history = make([]*RequestRecord, historyCap)
	if err := h.UpdateOdds(percentDuplicate, percentTooMany, percentNonIndex, percentTooLarge); err != nil {
		panic(err)
	}
	return h
}
// UpdateOdds rebuilds the ActionOdds and MethodOdds tables from percentages.
//
// percentDuplicate, percentTooMany and percentNonIndex are the number of
// ActionOdds slots (out of len(ActionOdds)) holding 409 Conflict, 429 Too Many
// Requests and 406 Not Acceptable; every remaining slot holds 200 OK.
// percentTooLarge is the number of MethodOdds slots holding 413 Request Entity
// Too Large; the rest hold 200 OK.
//
// It returns an error, leaving both tables untouched, when the three action
// percentages add up to more than len(ActionOdds) or percentTooLarge exceeds
// len(MethodOdds). The tables are rewritten under configMu's write lock.
func (h *APIHandler) UpdateOdds(
	percentDuplicate,
	percentTooMany,
	percentNonIndex,
	percentTooLarge uint,
) error {
	h.configMu.Lock()
	defer h.configMu.Unlock()

	actionSlots := uint(len(h.ActionOdds))
	methodSlots := uint(len(h.MethodOdds))

	// Validate in uint and bound each term before summing. Summing three
	// arbitrary uints can wrap around, and converting a huge uint to int can
	// yield a negative number; either would slip past the check and then index
	// the tables out of range below.
	if percentDuplicate > actionSlots || percentTooMany > actionSlots || percentNonIndex > actionSlots ||
		percentDuplicate+percentTooMany+percentNonIndex > actionSlots {
		return fmt.Errorf("Total of percents can't be greater than %d", len(h.ActionOdds))
	}
	if percentTooLarge > methodSlots {
		return fmt.Errorf("percent TooLarge cannot be greater than %d", len(h.MethodOdds))
	}

	// Fill in ActionOdds: the error codes first, then 200 OK for the rest.
	n := 0
	for i := uint(0); i < percentDuplicate; i++ {
		h.ActionOdds[n] = http.StatusConflict
		n++
	}
	for i := uint(0); i < percentTooMany; i++ {
		h.ActionOdds[n] = http.StatusTooManyRequests
		n++
	}
	for i := uint(0); i < percentNonIndex; i++ {
		h.ActionOdds[n] = http.StatusNotAcceptable
		n++
	}
	for ; n < len(h.ActionOdds); n++ {
		h.ActionOdds[n] = http.StatusOK
	}

	// Fill in MethodOdds the same way.
	n = 0
	for i := uint(0); i < percentTooLarge; i++ {
		h.MethodOdds[n] = http.StatusRequestEntityTooLarge
		n++
	}
	for ; n < len(h.MethodOdds); n++ {
		h.MethodOdds[n] = http.StatusOK
	}
	return nil
}
// ServeHTTP implements http.Handler. It first sleeps for h.Delay and marks the
// reply as coming from Elasticsearch, then dispatches on method and path:
//
//	GET  /          -> Root
//	POST /_bulk     -> Bulk
//	GET  /_license  -> License
//	GET  /_history  -> History
//
// Any other request is answered with a fixed tagline document.
func (h *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	time.Sleep(h.Delay)

	// Official Elasticsearch clients only accept an endpoint that sends this header.
	w.Header().Set("X-Elastic-Product", "Elasticsearch")

	type route struct{ method, path string }

	var handle http.HandlerFunc
	switch (route{r.Method, r.URL.Path}) {
	case route{http.MethodGet, "/"}:
		handle = h.Root
	case route{http.MethodPost, "/_bulk"}:
		handle = h.Bulk
	case route{http.MethodGet, "/_license"}:
		handle = h.License
	case route{http.MethodGet, "/_history"}:
		handle = h.History
	default:
		w.Write([]byte("{\"tagline\": \"You Know, for Testing\"}"))
		return
	}
	handle(w, r)
}
// Bulk handles POST /_bulk requests.
//
// With the probability configured in MethodOdds the whole request is rejected
// with 413 Request Entity Too Large. Otherwise the newline-delimited body
// (gunzipped first when Content-Encoding is "gzip") is scanned action by
// action. Every "create" action receives a status drawn from ActionOdds and is
// reported in the reply. index/update/delete actions are only counted in
// metrics. The request is then recorded in the history and answered with a JSON
// BulkResponse. A body that cannot be decompressed or read is answered with 400.
func (h *APIHandler) Bulk(w http.ResponseWriter, r *http.Request) {
	// Largest single line accepted from the bulk body. By default
	// bufio.Scanner stops at 64 KiB lines, which silently cut off requests
	// carrying larger documents. 100 MiB matches Elasticsearch's default
	// http.max_content_length. The buffer only grows this far when needed.
	const maxBulkLineSize = 100 << 20

	h.configMu.RLock()
	defer h.configMu.RUnlock()

	attrs := metric.WithAttributeSet(requestAttributes(r))
	h.metrics.bulkCreateTotalMetrics.Add(context.Background(), 1, attrs)

	methodStatus := h.MethodOdds[rand.Intn(len(h.MethodOdds))]
	if methodStatus == http.StatusRequestEntityTooLarge {
		h.metrics.bulkCreateTooLargeMetrics.Add(context.Background(), 1, attrs)
		w.WriteHeader(methodStatus)
		return
	}

	var scanner *bufio.Scanner
	if r.Header.Get("Content-Encoding") == "gzip" {
		zr, err := gzip.NewReader(r.Body)
		if err != nil {
			log.Printf("error new gzip reader failed: %s", err)
			// Reply explicitly: returning without writing anything would look
			// like an empty 200 OK to the client.
			http.Error(w, "invalid gzip body", http.StatusBadRequest)
			return
		}
		defer zr.Close()
		scanner = bufio.NewScanner(zr)
	} else {
		scanner = bufio.NewScanner(r.Body)
	}
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxBulkLineSize)

	br := BulkResponse{}

	// Bulk bodies alternate an action line with a document line, e.g.
	//   { "update": {"_id": "5", "_index": "index1"} }
	//   { "doc": {"my_field": "baz"} }
	// Only action lines are inspected. skipNextLine steps over the document
	// that follows index/create/update (delete carries no document line).
	var skipNextLine bool
	var body []byte
	for scanner.Scan() {
		b := scanner.Bytes()
		// NOTE(review): lines are appended without their newline separators, so
		// the recorded body is not byte-identical to what the client sent.
		body = append(body, b...)
		if skipNextLine || len(b) == 0 {
			skipNextLine = false
			continue
		}
		var j map[string]any
		if err := json.Unmarshal(b, &j); err != nil {
			log.Printf("error unmarshal: %s", err)
			continue
		}
		if len(j) != 1 {
			log.Printf("error, number of keys off: %d should be 1", len(j))
			continue
		}
		for k := range j {
			switch k {
			case "index":
				h.metrics.bulkIndexTotalMetrics.Add(context.Background(), 1, attrs)
				skipNextLine = true
			case "create":
				skipNextLine = true
				actionStatus := h.ActionOdds[rand.Intn(len(h.ActionOdds))]
				switch actionStatus {
				case http.StatusOK:
					h.metrics.bulkCreateOkMetrics.Add(context.Background(), 1, attrs)
				case http.StatusConflict:
					br.Errors = true
					h.metrics.bulkCreateDuplicateMetrics.Add(context.Background(), 1, attrs)
				case http.StatusTooManyRequests:
					br.Errors = true
					h.metrics.bulkCreateTooManyMetrics.Add(context.Background(), 1, attrs)
				case http.StatusNotAcceptable:
					br.Errors = true
					h.metrics.bulkCreateNonIndexMetrics.Add(context.Background(), 1, attrs)
				}
				br.Items = append(br.Items, map[string]any{"created": map[string]any{"status": actionStatus}})
			case "update":
				h.metrics.bulkUpdateTotalMetrics.Add(context.Background(), 1, attrs)
				skipNextLine = true
			case "delete":
				h.metrics.bulkDeleteTotalMetrics.Add(context.Background(), 1, attrs)
				skipNextLine = false
			}
		}
	}
	// Scan stops on the first read error (corrupt gzip, oversized line, ...).
	// Previously that was ignored, and a truncated body was answered as if complete.
	if err := scanner.Err(); err != nil {
		log.Printf("error reading bulk body: %s", err)
		http.Error(w, "error reading bulk body", http.StatusBadRequest)
		return
	}

	h.recordRequest(r, body)
	brBytes, err := json.Marshal(br)
	if err != nil {
		log.Printf("error marshal bulk reply: %s", err)
		http.Error(w, "error marshal bulk reply", http.StatusInternalServerError)
		return
	}
	w.Header().Set(http.CanonicalHeaderKey("Content-Type"), "application/json")
	w.Write(brBytes)
}
// Root handles GET / requests. It records the request, counts it in the root
// metric, and replies with a minimal Elasticsearch info document carrying
// h.ClusterUUID. The version number comes from the caller's User-Agent header
// via useragent.Parse(...).VersionNoFull().
func (h *APIHandler) Root(w http.ResponseWriter, r *http.Request) {
	h.recordRequest(r, nil)
	attrs := requestAttributes(r)
	h.metrics.rootTotalMetrics.Add(context.Background(), 1, metric.WithAttributeSet(attrs))

	version := useragent.Parse(r.Header.Get("User-Agent")).VersionNoFull()
	info := fmt.Sprintf("{\"name\" : \"mock\", \"cluster_uuid\" : \"%s\", \"version\" : { \"number\" : \"%s\", \"build_flavor\" : \"default\"}}", h.ClusterUUID, version)

	header := w.Header()
	header.Set(http.CanonicalHeaderKey("Content-Type"), "application/json")
	w.Write([]byte(info))
}
// License handles GET /_license requests. It records the request, counts it in
// the license metric, and reports an active trial license. The uid is
// h.UUID.String() and the expiry is h.Expire in Unix milliseconds.
func (h *APIHandler) License(w http.ResponseWriter, r *http.Request) {
	h.recordRequest(r, nil)
	attrs := requestAttributes(r)
	h.metrics.licenseTotalMetrics.Add(context.Background(), 1, metric.WithAttributeSet(attrs))

	uid, expiryMillis := h.UUID.String(), h.Expire.UnixMilli()
	payload := fmt.Sprintf("{\"license\" : {\"status\" : \"active\", \"uid\" : \"%s\", \"type\" : \"trial\", \"expiry_date_in_millis\" : %d}}", uid, expiryMillis)

	header := w.Header()
	header.Set(http.CanonicalHeaderKey("Content-Type"), "application/json")
	w.Write([]byte(payload))
}
// History handles GET /_history requests. It replies with the recorded requests
// as a JSON array, oldest first. Ring-buffer slots that have not been written
// yet are skipped, so the reply is "[]" when nothing has been recorded.
func (h *APIHandler) History(w http.ResponseWriter, r *http.Request) {
	w.Header().Set(http.CanonicalHeaderKey("Content-Type"), "application/json")
	w.WriteHeader(http.StatusOK)

	h.historyMu.Lock()
	// historyIndex is the next slot recordRequest will overwrite. Once the ring
	// has wrapped, that slot holds the oldest entry. Walking the ring from there
	// yields arrival order. Iterating from slot 0 would put the newest records
	// first after a wrap.
	size := len(h.history)
	records := make([]RequestRecord, 0, size)
	for i := 0; i < size; i++ {
		if rec := h.history[(h.historyIndex+i)%size]; rec != nil {
			records = append(records, *rec)
		}
	}
	// records holds copies, so the lock need not be held while writing the reply.
	h.historyMu.Unlock()

	if err := json.NewEncoder(w).Encode(records); err != nil {
		log.Printf("error encoding history reply: %s", err)
	}
}
// RequestHistory returns a snapshot of the request history ring buffer. The
// snapshot has one entry per slot (its length is the configured history
// capacity) in slot order, with nil for slots not written yet. Once the buffer
// has wrapped, slot order is not arrival order.
//
// The returned slice is a copy. Previously the internal slice was returned
// as-is, so callers read it without holding historyMu while recordRequest could
// be writing to it concurrently.
func (h *APIHandler) RequestHistory() []*RequestRecord {
	h.historyMu.Lock()
	defer h.historyMu.Unlock()
	snapshot := make([]*RequestRecord, len(h.history))
	copy(snapshot, h.history)
	return snapshot
}
// recordRequest stores r's method, request URI and the given body in the next
// slot of the history ring buffer. Once the buffer is full it overwrites the
// oldest entry. It does nothing when the history has zero capacity.
func (h *APIHandler) recordRequest(r *http.Request, body []byte) {
	size := cap(h.history)
	if size == 0 {
		return
	}
	entry := &RequestRecord{Method: r.Method, URI: r.URL.RequestURI(), Body: string(body)}

	h.historyMu.Lock()
	defer h.historyMu.Unlock()
	h.history[h.historyIndex] = entry
	h.historyIndex++
	if h.historyIndex == size {
		h.historyIndex = 0
	}
}
// requestAttributes builds the attribute set the handlers attach to their
// request metrics. It contains the caller's User-Agent under "user_agent" and
// the request path under "path".
func requestAttributes(r *http.Request) attribute.Set {
	userAgent := attribute.String("user_agent", r.UserAgent())
	path := attribute.String("path", r.URL.Path)
	return attribute.NewSet(userAgent, path)
}