/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.cluster;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.utils.MLNodeUtils;

/**
 * Resolves which cluster nodes are eligible to run ML Commons work.
 *
 * <p>Eligibility is driven by cluster settings that are read at construction time and refreshed
 * through cluster-settings update consumers:
 * <ul>
 *   <li>{@code ML_COMMONS_EXCLUDE_NODE_NAMES} - comma-delimited node names that are never eligible;</li>
 *   <li>{@code ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES} - node roles allowed for
 *       {@link FunctionName#REMOTE} and {@link FunctionName#AGENT};</li>
 *   <li>{@code ML_COMMONS_ONLY_RUN_ON_ML_NODE} - when {@code true}, every other function may only run
 *       on nodes accepted by {@code MLNodeUtils.isMLNode};</li>
 *   <li>{@code ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES} - node roles allowed for every other
 *       function when the ML-node-only flag is off.</li>
 * </ul>
 *
 * <p>The {@code "data"} role is special: it only matches data nodes accepted by
 * {@link #isEligibleDataNode(DiscoveryNode)}, which rejects non-hot data nodes.
 */
public class DiscoveryNodeHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(DiscoveryNodeHelper.class);

    /** Role name matched through {@link #isEligibleDataNode} instead of by plain role-name lookup. */
    private static final String DATA_ROLE = "data";

    private final ClusterService clusterService;
    private final HotDataNodePredicate eligibleNodeFilter;
    // The fields below are replaced wholesale by cluster-settings update consumers, hence volatile.
    private volatile boolean onlyRunOnMLNode;
    private volatile Set<String> excludedNodeNames;
    private volatile Set<String> remoteModelEligibleNodeRoles;
    private volatile Set<String> localModelEligibleNodeRoles;

    /**
     * Creates a helper and subscribes to dynamic updates of the node-selection settings.
     *
     * @param clusterService source of the current cluster state and of the cluster settings to watch
     * @param settings       settings used to compute the initial values
     */
    public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
        this.clusterService = clusterService;
        this.eligibleNodeFilter = new HotDataNodePredicate();

        this.onlyRunOnMLNode = MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE, it -> this.onlyRunOnMLNode = it);

        this.excludedNodeNames = Strings.commaDelimitedListToSet(MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES.get(settings));
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES,
                it -> this.excludedNodeNames = Strings.commaDelimitedListToSet(it)
            );

        this.remoteModelEligibleNodeRoles = new HashSet<>(MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES.get(settings));
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES,
                it -> this.remoteModelEligibleNodeRoles = new HashSet<>(it)
            );

        this.localModelEligibleNodeRoles = new HashSet<>(MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES.get(settings));
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
                it -> this.localModelEligibleNodeRoles = new HashSet<>(it)
            );
    }

    /**
     * Returns the ids of all nodes in the current cluster state that may run {@code functionName}.
     *
     * @param functionName the ML function to place
     * @return ids of the eligible nodes (possibly empty, never {@code null})
     */
    public String[] getEligibleNodeIds(FunctionName functionName) {
        return getNodeIds(getEligibleNodes(functionName));
    }

    /**
     * Returns all nodes in the current cluster state that may run {@code functionName}.
     *
     * @param functionName the ML function to place
     * @return the eligible nodes in cluster-state iteration order (possibly empty, never {@code null})
     */
    public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
        Predicate<DiscoveryNode> eligible = eligibleNodePredicate(functionName);
        List<DiscoveryNode> eligibleNodes = new ArrayList<>();
        for (DiscoveryNode node : this.clusterService.state().nodes()) {
            if (eligible.test(node)) {
                eligibleNodes.add(node);
            }
        }
        return eligibleNodes.toArray(new DiscoveryNode[0]);
    }

    /**
     * Narrows {@code nodeIds} to the nodes that may run {@code functionName}.
     *
     * <p>The rules are the same as {@link #getEligibleNodes(FunctionName)}. Ids that are not present
     * in the current cluster state are dropped.
     *
     * @param functionName the ML function to place
     * @param nodeIds      candidate node ids; returned unchanged when {@code null} or empty
     * @return the eligible subset of {@code nodeIds}
     */
    public String[] filterEligibleNodes(FunctionName functionName, String[] nodeIds) {
        if (nodeIds == null || nodeIds.length == 0) {
            return nodeIds;
        }
        Predicate<DiscoveryNode> eligible = eligibleNodePredicate(functionName);
        List<String> eligibleNodeIds = new ArrayList<>();
        for (DiscoveryNode node : getNodes(nodeIds)) {
            if (eligible.test(node)) {
                eligibleNodeIds.add(node.getId());
            }
        }
        return eligibleNodeIds.toArray(new String[0]);
    }

    /**
     * Returns every node in the current cluster state.
     *
     * @return all nodes in cluster-state iteration order
     */
    public DiscoveryNode[] getAllNodes() {
        List<DiscoveryNode> nodes = new ArrayList<>();
        for (DiscoveryNode node : this.clusterService.state().nodes()) {
            nodes.add(node);
        }
        return nodes.toArray(new DiscoveryNode[0]);
    }

    /**
     * Returns the id of every node in the current cluster state.
     *
     * @return all node ids in cluster-state iteration order
     */
    public String[] getAllNodeIds() {
        return getNodeIds(getAllNodes());
    }

    /**
     * Resolves node ids against the current cluster state.
     *
     * @param nodeIds ids to resolve; duplicates are collapsed, and {@code null} is treated as empty
     * @return the matching nodes in cluster-state iteration order; unknown ids are skipped
     */
    public DiscoveryNode[] getNodes(String[] nodeIds) {
        if (nodeIds == null || nodeIds.length == 0) {
            return new DiscoveryNode[0];
        }
        Set<String> wanted = new HashSet<>(Arrays.asList(nodeIds));
        List<DiscoveryNode> matches = new ArrayList<>();
        for (DiscoveryNode node : this.clusterService.state().nodes()) {
            if (wanted.contains(node.getId())) {
                matches.add(node);
            }
        }
        return matches.toArray(new DiscoveryNode[0]);
    }

    /**
     * Extracts the id of each node, preserving order.
     *
     * @param nodes nodes to read; {@code null} is treated as empty
     * @return the node ids, positionally aligned with {@code nodes}
     */
    public String[] getNodeIds(DiscoveryNode[] nodes) {
        if (nodes == null) {
            return new String[0];
        }
        String[] nodeIds = new String[nodes.length];
        for (int i = 0; i < nodes.length; i++) {
            nodeIds[i] = nodes[i].getId();
        }
        return nodeIds;
    }

    /**
     * Tells whether {@code node} is a data node that qualifies for the {@code "data"} role.
     * A data node qualifies when its {@code box_type} attribute is absent or equals
     * {@code CommonValue.HOT_BOX_TYPE}.
     *
     * @param node node to test
     * @return {@code true} if the node is an eligible (hot) data node
     */
    public boolean isEligibleDataNode(DiscoveryNode node) {
        return this.eligibleNodeFilter.test(node);
    }

    /**
     * Looks up a node by id in the current cluster state.
     *
     * @param nodeId id of the node; may be {@code null}
     * @return the node, or {@code null} if no node with that id exists
     */
    public DiscoveryNode getNode(String nodeId) {
        if (nodeId == null) {
            return null;
        }
        // DiscoveryNodes is keyed by node id, so this avoids scanning every node.
        return this.clusterService.state().nodes().get(nodeId);
    }

    /** Remote models and agents share the remote-model role configuration. */
    private static boolean isRemoteFunction(FunctionName functionName) {
        return functionName == FunctionName.REMOTE || functionName == FunctionName.AGENT;
    }

    /**
     * Builds the eligibility test for {@code functionName}, shared by the public lookups.
     *
     * <p>The volatile settings are read exactly once here. A concurrent settings update therefore
     * cannot yield a mix of old and new values within a single lookup.
     */
    private Predicate<DiscoveryNode> eligibleNodePredicate(FunctionName functionName) {
        final Set<String> excluded = this.excludedNodeNames;
        final Predicate<DiscoveryNode> roleCheck;
        if (isRemoteFunction(functionName)) {
            final Set<String> roles = this.remoteModelEligibleNodeRoles;
            roleCheck = node -> hasEligibleRole(roles, node);
        } else if (this.onlyRunOnMLNode) {
            roleCheck = node -> MLNodeUtils.isMLNode(node);
        } else {
            final Set<String> roles = this.localModelEligibleNodeRoles;
            roleCheck = node -> hasEligibleRole(roles, node);
        }
        return node -> (excluded == null || !excluded.contains(node.getName())) && roleCheck.test(node);
    }

    /**
     * Tells whether {@code node} carries at least one of {@code allowedNodeRoles}.
     * The {@code "data"} role only matches hot data nodes; see {@link #isEligibleDataNode}.
     */
    private boolean hasEligibleRole(Set<String> allowedNodeRoles, DiscoveryNode node) {
        if (allowedNodeRoles.contains(DATA_ROLE) && isEligibleDataNode(node)) {
            return true;
        }
        // "data" is excluded here so that a plain data role never bypasses the hot-node check above.
        return node
            .getRoles()
            .stream()
            .map(role -> role.roleName())
            .anyMatch(roleName -> !DATA_ROLE.equals(roleName) && allowedNodeRoles.contains(roleName));
    }

    /**
     * Accepts data nodes whose {@code box_type} attribute is absent or equals
     * {@code CommonValue.HOT_BOX_TYPE}.
     */
    static class HotDataNodePredicate implements Predicate<DiscoveryNode> {
        private static final String BOX_TYPE_ATTRIBUTE = "box_type";

        @Override
        public boolean test(DiscoveryNode discoveryNode) {
            if (!discoveryNode.isDataNode()) {
                return false;
            }
            // A missing attribute defaults to hot. Comparing from the constant side avoids an NPE
            // when the attribute is present with a null value.
            return CommonValue.HOT_BOX_TYPE.equals(
                discoveryNode.getAttributes().getOrDefault(BOX_TYPE_ATTRIBUTE, CommonValue.HOT_BOX_TYPE)
            );
        }
    }
}

