package software.amazon.ai.mms.plugins.endpoint;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Properties;
import software.amazon.ai.mms.servingsdk.Context;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;
import software.amazon.ai.mms.servingsdk.http.Request;
import software.amazon.ai.mms.servingsdk.http.Response;

/**
The modified endpoint source code for the jar used in this container.
You can build this endpoint by first cloning the MMS repo:
> git clone https://github.com/awslabs/mxnet-model-server.git

Copy this file into plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoint/
and then from the plugins directory, run:

> ./gradlew fJ

Modify the file in plugins/endpoint/resources/META-INF/services/* to specify this file's location

Then build the JAR:

> ./gradlew build

The jar should be available in plugins/endpoints/build/libs as endpoints-1.0.jar
**/
@Endpoint(
        urlPattern = "execution-parameters",
        endpointType = EndpointTypes.INFERENCE,
        description = "Execution parameters endpoint")
public class ExecutionParameters extends ModelServerEndpoint {

    /** Number of bytes in one megabyte; used to convert the request-size limit to MB. */
    private static final int BYTES_PER_MB = 1024 * 1024;

    /** Configuration key holding the maximum request size in bytes. */
    private static final String MAX_REQUEST_SIZE_KEY = "max_request_size";

    /** Request-size limit in bytes used when {@code max_request_size} is absent or invalid. */
    private static final int DEFAULT_MAX_REQUEST_SIZE = 6 * BYTES_PER_MB;

    /** Configuration key holding the number of workers. */
    private static final String NUM_WORKERS_KEY = "NUM_WORKERS";

    /** Worker count used when {@code NUM_WORKERS} is absent or invalid. */
    private static final int DEFAULT_NUM_WORKERS = 1;

    /** Batch strategy reported in every response. */
    private static final String BATCH_STRATEGY = "MULTI_RECORD";

    /**
     * Shared pretty-printing serializer, built once instead of on every request. Gson documents
     * its instances as thread-safe, so sharing one across requests is fine.
     */
    private static final Gson GSON = new GsonBuilder().setPrettyPrinting().create();

    /**
     * Writes the execution parameters as pretty-printed, UTF-8 encoded JSON to the response.
     *
     * <p>{@code MaxConcurrentTransforms} comes from the {@code NUM_WORKERS} property, {@code
     * MaxPayloadInMB} from the {@code max_request_size} property (bytes, converted to whole
     * megabytes), and {@code BatchStrategy} is always {@code MULTI_RECORD}. A property that is
     * missing or not a valid integer is replaced by its default.
     *
     * @param req the incoming request (not inspected)
     * @param rsp the response whose output stream receives the JSON document
     * @param ctx the context supplying the server configuration properties
     * @throws IOException if writing to the response output stream fails
     */
    @Override
    public void doGet(Request req, Response rsp, Context ctx) throws IOException {
        Properties prop = ctx.getConfig();
        int maxRequestSize = readInt(prop, MAX_REQUEST_SIZE_KEY, DEFAULT_MAX_REQUEST_SIZE);

        SagemakerXgboostResponse response = new SagemakerXgboostResponse();
        response.setMaxConcurrentTransforms(readInt(prop, NUM_WORKERS_KEY, DEFAULT_NUM_WORKERS));
        response.setBatchStrategy(BATCH_STRATEGY);
        // Integer division: any fractional megabyte is truncated.
        response.setMaxPayloadInMB(maxRequestSize / BYTES_PER_MB);

        rsp.getOutputStream().write(GSON.toJson(response).getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Reads an integer property, tolerating absent, padded or malformed values.
     *
     * @param prop the configuration to read from; may be {@code null}
     * @param key the property name
     * @param fallback the value returned when the property is missing or not a valid integer
     * @return the parsed property value, or {@code fallback}
     */
    private static int readInt(Properties prop, String key, int fallback) {
        String raw = prop == null ? null : prop.getProperty(key);
        if (raw == null) {
            return fallback;
        }
        try {
            return Integer.parseInt(raw.trim());
        } catch (NumberFormatException ignored) {
            // A malformed setting should not turn this informational endpoint into a server
            // error; report the default instead.
            return fallback;
        }
    }

    /** JSON payload returned by the execution-parameters endpoint. */
    public static class SagemakerXgboostResponse {
        @SerializedName("MaxConcurrentTransforms")
        private int maxConcurrentTransforms;

        @SerializedName("BatchStrategy")
        private String batchStrategy;

        @SerializedName("MaxPayloadInMB")
        private int maxPayloadInMB;

        /**
         * Creates a response with 4 concurrent transforms, the {@code MULTI_RECORD} batch
         * strategy and a 6 MB maximum payload.
         */
        public SagemakerXgboostResponse() {
            maxConcurrentTransforms = 4;
            batchStrategy = "MULTI_RECORD";
            maxPayloadInMB = 6;
        }

        /** @return the value serialized as {@code MaxConcurrentTransforms} */
        public int getMaxConcurrentTransforms() {
            return maxConcurrentTransforms;
        }

        /** @return the value serialized as {@code BatchStrategy} */
        public String getBatchStrategy() {
            return batchStrategy;
        }

        /** @return the value serialized as {@code MaxPayloadInMB} */
        public int getMaxPayloadInMB() {
            return maxPayloadInMB;
        }

        /** @param newMaxConcurrentTransforms the value to serialize as {@code MaxConcurrentTransforms} */
        public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) {
            maxConcurrentTransforms = newMaxConcurrentTransforms;
        }

        /** @param newBatchStrategy the value to serialize as {@code BatchStrategy} */
        public void setBatchStrategy(String newBatchStrategy) {
            batchStrategy = newBatchStrategy;
        }

        /** @param newMaxPayloadInMB the value to serialize as {@code MaxPayloadInMB} */
        public void setMaxPayloadInMB(int newMaxPayloadInMB) {
            maxPayloadInMB = newMaxPayloadInMB;
        }
    }
}
