/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.client;

import java.util.Map;
import java.util.function.Function;
import lombok.Generated;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ml.client.MachineLearningClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.InputHelper;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

/**
 * {@link MachineLearningClient} implementation that forwards every operation to the
 * wrapped {@link Client} as a transport action and adapts the transport response to
 * the type expected by the caller's listener.
 */
public class MachineLearningNodeClient
implements MachineLearningClient {
    private final Client client;

    /**
     * Creates a client that executes all ML actions through the given node client.
     *
     * @param client the client used to execute transport actions
     */
    @Generated
    public MachineLearningNodeClient(Client client) {
        this.client = client;
    }

    /**
     * Runs a prediction with an existing model.
     *
     * @param modelId  id of the model to predict with
     * @param mlInput  algorithm, parameters and input data set; must not be null and must carry a data set
     * @param listener receives the prediction output or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
        this.validateMLInput(mlInput, true);
        MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
            .mlInput(mlInput)
            .modelId(modelId)
            .dispatchTask(true)
            .build();
        this.client.execute(
            MLPredictionTaskAction.INSTANCE,
            predictionRequest,
            this.getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Trains and predicts in a single request.
     *
     * @param mlInput  algorithm, parameters and input data set; must not be null and must carry a data set
     * @param listener receives the output or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
        this.validateMLInput(mlInput, true);
        // The train-and-predict action is sent a training request; only the action type differs from train().
        MLTrainingTaskRequest request = MLTrainingTaskRequest.builder()
            .mlInput(mlInput)
            .dispatchTask(true)
            .build();
        this.client.execute(
            MLTrainAndPredictionTaskAction.INSTANCE,
            request,
            this.getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Trains a model.
     *
     * @param mlInput   algorithm, parameters and input data set; must not be null and must carry a data set
     * @param asyncTask value forwarded to the request's {@code async} flag
     * @param listener  receives the training output or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener) {
        this.validateMLInput(mlInput, true);
        MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder()
            .mlInput(mlInput)
            .async(asyncTask)
            .dispatchTask(true)
            .build();
        this.client.execute(
            MLTrainingTaskAction.INSTANCE,
            trainingTaskRequest,
            this.getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Dispatches to {@link #train}, {@link #predict} or {@link #trainAndPredict} based on the
     * action resolved from {@code args}, after setting the algorithm and parameters derived
     * from {@code args} on {@code mlInput}.
     *
     * <p>Keys read directly here: {@code "async"} (for {@code train}; a missing or null value
     * means {@code false}) and {@code "modelid"} (required for {@code predict}). The action and
     * function name are resolved by {@link InputHelper}.
     *
     * @param mlInput  input to run; its algorithm and parameters are overwritten
     * @param args     arguments describing the action, algorithm and its parameters
     * @param listener receives the output or the failure
     * @throws IllegalArgumentException if the action is missing or unsupported, if
     *         {@code mlInput} is null, or if {@code predict} is requested without a model id
     */
    @Override
    public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutput> listener) {
        String action = InputHelper.getAction(args);
        if (action == null) {
            throw new IllegalArgumentException("The parameter action is required.");
        }
        FunctionName functionName = InputHelper.getFunctionName(args);
        MLAlgoParams mlAlgoParams = InputHelper.convertArgumentToMLParameter(args, functionName);
        // Reject a null input before it is dereferenced below, so callers get the same
        // IllegalArgumentException the other entry points raise instead of an NPE.
        this.validateMLInput(mlInput, false);
        mlInput.setAlgorithm(functionName);
        mlInput.setParameters(mlAlgoParams);
        switch (action) {
            case "train": {
                // A key that is present but mapped to null must not be unboxed.
                Object asyncValue = args.get("async");
                boolean asyncTask = asyncValue != null && (Boolean) asyncValue;
                this.train(mlInput, asyncTask, listener);
                break;
            }
            case "predict": {
                String modelId = (String) args.get("modelid");
                if (modelId == null) {
                    throw new IllegalArgumentException("The model ID is required for prediction.");
                }
                this.predict(modelId, mlInput, listener);
                break;
            }
            case "trainandpredict": {
                this.trainAndPredict(mlInput, listener);
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported action.");
            }
        }
    }

    /**
     * Fetches a model by id.
     *
     * @param modelId  id of the model to fetch
     * @param listener receives the model or the failure
     */
    @Override
    public void getModel(String modelId, ActionListener<MLModel> listener) {
        MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();
        this.client.execute(
            MLModelGetAction.INSTANCE,
            mlModelGetRequest,
            this.getMlGetModelResponseActionListener(listener));
    }

    /**
     * Builds the transport listener for {@link #getModel}: the response is passed through
     * {@code MLModelGetResponse.fromActionResponse} and its model is handed to {@code listener}.
     */
    private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(ActionListener<MLModel> listener) {
        ActionListener<MLModelGetResponse> internalListener = ActionListener.wrap(
            getResponse -> listener.onResponse(getResponse.getMlModel()),
            listener::onFailure);
        return this.wrapActionListener(internalListener, MLModelGetResponse::fromActionResponse);
    }

    /**
     * Deletes a model by id.
     *
     * @param modelId  id of the model to delete
     * @param listener receives the delete response or the failure
     */
    @Override
    public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
        MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
        this.client.execute(
            MLModelDeleteAction.INSTANCE,
            mlModelDeleteRequest,
            ActionListener.wrap(deleteResponse -> listener.onResponse(deleteResponse), listener::onFailure));
    }

    /**
     * Searches models with the given search request.
     *
     * @param searchRequest the search to execute
     * @param listener      receives the search response or the failure
     */
    @Override
    public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
        this.client.execute(
            MLModelSearchAction.INSTANCE,
            searchRequest,
            ActionListener.wrap(searchResponse -> listener.onResponse(searchResponse), listener::onFailure));
    }

    /**
     * Fetches a task by id.
     *
     * @param taskId   id of the task to fetch
     * @param listener receives the task or the failure
     */
    @Override
    public void getTask(String taskId, ActionListener<MLTask> listener) {
        MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
        this.client.execute(
            MLTaskGetAction.INSTANCE,
            mlTaskGetRequest,
            ActionListener.wrap(
                response -> listener.onResponse(MLTaskGetResponse.fromActionResponse(response).getMlTask()),
                listener::onFailure));
    }

    /**
     * Deletes a task by id.
     *
     * @param taskId   id of the task to delete
     * @param listener receives the delete response or the failure
     */
    @Override
    public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
        MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();
        this.client.execute(
            MLTaskDeleteAction.INSTANCE,
            mlTaskDeleteRequest,
            ActionListener.wrap(deleteResponse -> listener.onResponse(deleteResponse), listener::onFailure));
    }

    /**
     * Searches tasks with the given search request.
     *
     * @param searchRequest the search to execute
     * @param listener      receives the search response or the failure
     */
    @Override
    public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
        this.client.execute(
            MLTaskSearchAction.INSTANCE,
            searchRequest,
            ActionListener.wrap(searchResponse -> listener.onResponse(searchResponse), listener::onFailure));
    }

    /**
     * Builds the transport listener shared by train, predict and train-and-predict: the
     * response is passed through {@code MLTaskResponse.fromActionResponse} and its output is
     * handed to {@code listener}.
     */
    private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
        ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(
            taskResponse -> listener.onResponse(taskResponse.getOutput()),
            listener::onFailure);
        return this.wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
    }

    /**
     * Wraps {@code listener} so that each response is first converted by {@code recreate}
     * before being delivered; failures are forwarded unchanged.
     *
     * <p>NOTE(review): presumably this exists so a response object produced elsewhere is
     * rebuilt as this client's own response type; confirm against {@code fromActionResponse}.
     *
     * @param listener the listener that receives the converted response
     * @param recreate converts the raw transport response to {@code T}
     * @param <T>      concrete response type expected by {@code listener}
     */
    private <T extends ActionResponse> ActionListener<T> wrapActionListener(ActionListener<T> listener, Function<ActionResponse, T> recreate) {
        return ActionListener.wrap(
            response -> listener.onResponse(recreate.apply(response)),
            listener::onFailure);
    }

    /**
     * Validates an {@link MLInput} before it is used.
     *
     * @param mlInput      the input to check
     * @param requireInput when true, the input data set must also be non-null
     * @throws IllegalArgumentException if {@code mlInput} is null, or if {@code requireInput}
     *         is true and the input data set is null
     */
    private void validateMLInput(MLInput mlInput, boolean requireInput) {
        if (mlInput == null) {
            throw new IllegalArgumentException("ML Input can't be null");
        }
        if (requireInput && mlInput.getInputDataset() == null) {
            throw new IllegalArgumentException("input data set can't be null");
        }
    }
}

