/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.dex.visitors;

import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.BaseInvokeNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.ModVisitor;
import jadx.core.dex.visitors.SimplifyVisitor;
import jadx.core.dex.visitors.methods.MutableMethodDetails;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.dex.visitors.typeinference.TypeCompare;
import jadx.core.dex.visitors.typeinference.TypeCompareEnum;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Post-processes method invocations.
 *
 * <p>For each invoke instruction this visitor:
 * <ul>
 *   <li>flags vararg calls ({@link AFlag#VARARG_CALL}) when the last declared argument is an array;</li>
 *   <li>inserts explicit argument casts, or explicit primitive literal types, so that resolving the
 *       call against the overloads collected for the call class selects the invoked method.</li>
 * </ul>
 */
@JadxVisitor(name="MethodInvokeVisitor", desc="Process additional info for method invocation (overload, vararg)", runAfter={CodeShrinkVisitor.class, ModVisitor.class}, runBefore={SimplifyVisitor.class})
public class MethodInvokeVisitor extends AbstractVisitor {
    private RootNode root;

    @Override
    public void init(RootNode root) {
        this.root = root;
    }

    /**
     * Processes every instruction of {@code mth}, including instructions nested inside arguments.
     * Blocks and instructions flagged {@link AFlag#DONT_GENERATE} are skipped.
     */
    @Override
    public void visit(MethodNode mth) {
        if (mth.isNoCode()) {
            return;
        }
        for (BlockNode block : mth.getBasicBlocks()) {
            if (block.contains(AFlag.DONT_GENERATE)) {
                continue;
            }
            for (InsnNode insn : block.getInstructions()) {
                if (!insn.contains(AFlag.DONT_GENERATE)) {
                    processInsn(mth, insn);
                }
            }
        }
    }

    /**
     * Handles {@code insn} if it is an invoke, then recurses into instructions wrapped in its arguments.
     * The invoke is processed first, so any cast wrappers it inserts are also walked.
     */
    private void processInsn(MethodNode mth, InsnNode insn) {
        if (insn instanceof BaseInvokeNode) {
            processInvoke(mth, (BaseInvokeNode) insn);
        }
        for (InsnArg insnArg : insn.getArguments()) {
            if (insnArg instanceof InsnWrapArg) {
                processInsn(mth, ((InsnWrapArg) insnArg).getWrapInsn());
            }
        }
    }

    /**
     * Processes a single invoke that has at least one argument.
     *
     * <p>If method details cannot be found, only unknown argument types are fixed. Otherwise the call
     * is flagged as vararg when applicable, and overload disambiguation is performed.
     */
    private void processInvoke(MethodNode parentMth, BaseInvokeNode invokeInsn) {
        MethodInfo callMth = invokeInsn.getCallMth();
        if (callMth.getArgsCount() == 0) {
            return;
        }
        IMethodDetails mthDetails = root.getMethodUtils().getMethodDetails(invokeInsn);
        if (mthDetails == null) {
            processUnknown(invokeInsn);
            return;
        }
        if (mthDetails.isVarArg()) {
            ArgType last = Utils.last(mthDetails.getArgTypes());
            if (last != null && last.isArray()) {
                invokeInsn.add(AFlag.VARARG_CALL);
            }
        }
        processOverloaded(parentMth, invokeInsn, mthDetails);
    }

    /**
     * Adds argument casts when overloads of the called method exist.
     *
     * <p>Type variables in the invoked method and its overloads are first substituted using the
     * call-site class type. Casts are then chosen so that the invoked method is the one selected.
     */
    private void processOverloaded(MethodNode parentMth, BaseInvokeNode invokeInsn, IMethodDetails mthDetails) {
        MethodInfo callMth = invokeInsn.getCallMth();
        ArgType callCls = getCallClassFromInvoke(parentMth, invokeInsn, callMth);
        List<IMethodDetails> overloadMethods = root.getMethodUtils().collectOverloadedMethods(callCls, callMth);
        if (overloadMethods.isEmpty()) {
            return;
        }
        Map<ArgType, ArgType> typeVarsMapping = getTypeVarsMapping(invokeInsn);
        IMethodDetails effectiveMthDetails = resolveTypeVars(mthDetails, typeVarsMapping);
        List<IMethodDetails> effectiveOverloadMethods = new ArrayList<>(overloadMethods.size() + 1);
        for (IMethodDetails overloadMethod : overloadMethods) {
            effectiveOverloadMethods.add(resolveTypeVars(overloadMethod, typeVarsMapping));
        }
        // The invoked method itself is a candidate too; isOverloadResolved checks it is the unique winner.
        effectiveOverloadMethods.add(effectiveMthDetails);

        int argsOffset = invokeInsn.getFirstArgOffset();
        List<ArgType> compilerVarTypes = collectCompilerVarTypes(invokeInsn, argsOffset);
        List<ArgType> castTypes = searchCastTypes(effectiveMthDetails, effectiveOverloadMethods, compilerVarTypes);
        applyArgsCast(invokeInsn, argsOffset, compilerVarTypes, castTypes);
    }

    /**
     * Handles invokes with no method details available.
     *
     * <p>Arguments whose type is unknown get the type from the call's {@link MethodInfo}. Casts are
     * applied only if at least one such replacement happened.
     */
    private void processUnknown(BaseInvokeNode invokeInsn) {
        int argsOffset = invokeInsn.getFirstArgOffset();
        List<ArgType> compilerVarTypes = collectCompilerVarTypes(invokeInsn, argsOffset);
        List<ArgType> castTypes = new ArrayList<>(compilerVarTypes);
        if (replaceUnknownTypes(castTypes, invokeInsn.getCallMth().getArgumentsTypes())) {
            applyArgsCast(invokeInsn, argsOffset, compilerVarTypes, castTypes);
        }
    }

    /**
     * Returns the class type used to look up overloads.
     *
     * <ul>
     *   <li>For a {@code super(...)} constructor call: the parent class's superclass.</li>
     *   <li>For an instance call: the type of the instance argument.</li>
     *   <li>Otherwise: the declaring class of the called method.</li>
     * </ul>
     */
    private ArgType getCallClassFromInvoke(MethodNode parentMth, BaseInvokeNode invokeInsn, MethodInfo callMth) {
        if (invokeInsn instanceof ConstructorInsn && ((ConstructorInsn) invokeInsn).isSuper()) {
            return parentMth.getParentClass().getSuperClass();
        }
        InsnArg instanceArg = invokeInsn.getInstanceArg();
        if (instanceArg != null) {
            return instanceArg.getType();
        }
        return callMth.getDeclClass().getType();
    }

    /** Builds the type-variable substitution map from the call-site class type. */
    private Map<ArgType, ArgType> getTypeVarsMapping(BaseInvokeNode invokeInsn) {
        ArgType declClsType = invokeInsn.getCallMth().getDeclClass().getType();
        ArgType callClsType = getClsCallType(invokeInsn, declClsType);
        return root.getTypeUtils().getTypeVariablesMapping(callClsType);
    }

    /**
     * Returns the most specific class type visible at the call site.
     *
     * <p>This is the instance argument type, or the result type of a constructor call, falling back
     * to {@code declClsType}.
     */
    private ArgType getClsCallType(BaseInvokeNode invokeInsn, ArgType declClsType) {
        InsnArg instanceArg = invokeInsn.getInstanceArg();
        if (instanceArg != null) {
            return instanceArg.getType();
        }
        if (invokeInsn.getType() == InsnType.CONSTRUCTOR && invokeInsn.getResult() != null) {
            return invokeInsn.getResult().getType();
        }
        return declClsType;
    }

    /**
     * Rewrites invoke arguments so their types match {@code castTypes}.
     *
     * <p>A {@code null} cast type means "leave as is". Primitive literals are retyped in place.
     * Other mismatches are wrapped in an explicit {@link InsnType#CAST}. If the type already
     * matches and the argument is a check-cast, that check-cast is flagged explicit.
     *
     * @param argsOffset index of the first non-instance argument in {@code invokeInsn}
     * @param compilerVarTypes current argument types, indexed from {@code argsOffset}
     * @param castTypes desired argument types, indexed from {@code argsOffset}
     */
    private void applyArgsCast(BaseInvokeNode invokeInsn, int argsOffset, List<ArgType> compilerVarTypes, List<ArgType> castTypes) {
        int argsCount = invokeInsn.getArgsCount();
        for (int i = argsOffset; i < argsCount; i++) {
            InsnArg arg = invokeInsn.getArg(i);
            int origPos = i - argsOffset;
            ArgType compilerType = compilerVarTypes.get(origPos);
            ArgType castType = castTypes.get(origPos);
            if (castType == null) {
                continue;
            }
            if (castType.equals(compilerType)) {
                // No cast needed. An existing check-cast is flagged explicit
                // (presumably so later passes keep it).
                if (arg.isInsnWrap()) {
                    InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
                    if (wrapInsn.getType() == InsnType.CHECK_CAST) {
                        wrapInsn.add(AFlag.EXPLICIT_CAST);
                    }
                }
                continue;
            }
            if (arg.isLiteral() && compilerType.isPrimitive() && castType.isPrimitive()) {
                // Retype the literal itself instead of emitting a cast expression.
                arg.setType(castType);
                arg.add(AFlag.EXPLICIT_PRIMITIVE_TYPE);
                continue;
            }
            IndexInsnNode castInsn = new IndexInsnNode(InsnType.CAST, castType, 1);
            castInsn.addArg(arg);
            castInsn.add(AFlag.EXPLICIT_CAST);
            InsnArg wrapCast = InsnArg.wrapArg(castInsn);
            wrapCast.setType(castType);
            invokeInsn.setArg(i, wrapCast);
        }
    }

    /**
     * Substitutes type variables in the argument and return types of {@code mthDetails}.
     *
     * <p>Argument types that cannot be resolved (null or unchanged result) fall back to the
     * corresponding type from the method's {@link MethodInfo}. A return type that still contains
     * type variables after substitution falls back to {@link MethodInfo#getReturnType()}.
     *
     * <p>NOTE(review): a return type that resolves successfully is not substituted, so the
     * original, unresolved return type is kept. Only argument types take part in overload checks
     * here.
     *
     * @return {@code mthDetails} itself if nothing changed, otherwise a {@link MutableMethodDetails} copy
     * @throws JadxRuntimeException if an argument type is {@code null}
     */
    private IMethodDetails resolveTypeVars(IMethodDetails mthDetails, Map<ArgType, ArgType> typeVarsMapping) {
        List<ArgType> argTypes = mthDetails.getArgTypes();
        int argsCount = argTypes.size();
        boolean fixed = false;
        List<ArgType> fixedArgTypes = new ArrayList<>(argsCount);
        for (int argNum = 0; argNum < argsCount; argNum++) {
            ArgType argType = argTypes.get(argNum);
            if (argType == null) {
                throw new JadxRuntimeException("Null arg type in " + mthDetails + " at: " + argNum + " in: " + argTypes);
            }
            if (!argType.containsTypeVariable()) {
                fixedArgTypes.add(argType);
                continue;
            }
            ArgType resolvedArgType = root.getTypeUtils().replaceTypeVariablesUsingMap(argType, typeVarsMapping);
            if (resolvedArgType == null || resolvedArgType.equals(argType)) {
                resolvedArgType = mthDetails.getMethodInfo().getArgumentsTypes().get(argNum);
            }
            fixedArgTypes.add(resolvedArgType);
            fixed = true;
        }

        ArgType returnType = mthDetails.getReturnType();
        if (returnType.containsTypeVariable()) {
            ArgType resolvedRetType = root.getTypeUtils().replaceTypeVariablesUsingMap(returnType, typeVarsMapping);
            if (resolvedRetType == null || resolvedRetType.containsTypeVariable()) {
                returnType = mthDetails.getMethodInfo().getReturnType();
                fixed = true;
            }
        }
        if (!fixed) {
            return mthDetails;
        }
        MutableMethodDetails mutableMethodDetails = new MutableMethodDetails(mthDetails);
        mutableMethodDetails.setArgTypes(fixedArgTypes);
        mutableMethodDetails.setRetType(returnType);
        return mutableMethodDetails;
    }

    /**
     * Finds argument types that make {@code mthDetails} the unique overload choice.
     *
     * <p>Candidates are tried in order:
     * <ol>
     *   <li>the current argument types;</li>
     *   <li>those types with unknown entries replaced by the method's declared types;</li>
     *   <li>additionally, non-generic entries replaced by the method's generic declared types.</li>
     * </ol>
     * If none of these resolves the overload, the method's declared argument types are returned.
     */
    private List<ArgType> searchCastTypes(IMethodDetails mthDetails, List<IMethodDetails> overloadedMethods, List<ArgType> compilerVarTypes) {
        if (isOverloadResolved(mthDetails, overloadedMethods, compilerVarTypes)) {
            return compilerVarTypes;
        }
        List<ArgType> mthArgTypes = mthDetails.getArgTypes();
        List<ArgType> castTypes = new ArrayList<>(compilerVarTypes);
        if (replaceUnknownTypes(castTypes, mthArgTypes) && isOverloadResolved(mthDetails, overloadedMethods, castTypes)) {
            return castTypes;
        }
        boolean changed = false;
        int argsCount = castTypes.size();
        for (int i = 0; i < argsCount; i++) {
            ArgType mthType = mthArgTypes.get(i);
            if (!castTypes.get(i).isGeneric() && mthType.isGeneric()) {
                castTypes.set(i, mthType);
                changed = true;
            }
        }
        if (changed && isOverloadResolved(mthDetails, overloadedMethods, castTypes)) {
            return castTypes;
        }
        // Fallback: cast every argument to the declared type.
        return mthArgTypes;
    }

    /**
     * Replaces entries of {@code castTypes} whose type is not known with the entry at the same
     * index in {@code mthArgTypes}.
     *
     * @return {@code true} if at least one entry was replaced
     */
    private boolean replaceUnknownTypes(List<ArgType> castTypes, List<ArgType> mthArgTypes) {
        int argsCount = castTypes.size();
        boolean changed = false;
        for (int i = 0; i < argsCount; i++) {
            if (!castTypes.get(i).isTypeKnown()) {
                castTypes.set(i, mthArgTypes.get(i));
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Returns {@code true} if {@code castTypes} select exactly one method from
     * {@code overloadedMethods}, and that method is {@code expectedMthDetails}.
     *
     * <p>Strict (equal) type matches are checked first, then the looser applicability check.
     */
    private boolean isOverloadResolved(IMethodDetails expectedMthDetails, List<IMethodDetails> overloadedMethods, List<ArgType> castTypes) {
        if (overloadedMethods.isEmpty()) {
            return false;
        }
        List<IMethodDetails> strictMethods = filterApplicableMethods(overloadedMethods, castTypes, MethodInvokeVisitor::isStrictTypes);
        if (strictMethods.size() == 1) {
            return strictMethods.get(0).equals(expectedMthDetails);
        }
        List<IMethodDetails> resolvedMethods = filterApplicableMethods(overloadedMethods, castTypes, MethodInvokeVisitor::isTypeApplicable);
        if (resolvedMethods.size() == 1) {
            return resolvedMethods.get(0).equals(expectedMthDetails);
        }
        return false;
    }

    /** Accepts only equal types. */
    private static boolean isStrictTypes(TypeCompareEnum result) {
        return result.isEqual();
    }

    /** Accepts narrower or equal types, or types wider only by generic parameters. */
    private static boolean isTypeApplicable(TypeCompareEnum result) {
        return result.isNarrowOrEqual() || result == TypeCompareEnum.WIDER_BY_GENERIC;
    }

    /** Returns the methods whose argument types all satisfy {@code acceptFunction} against {@code types}. */
    private List<IMethodDetails> filterApplicableMethods(List<IMethodDetails> methods, List<ArgType> types, Predicate<TypeCompareEnum> acceptFunction) {
        List<IMethodDetails> list = new ArrayList<>(methods.size());
        for (IMethodDetails m : methods) {
            if (isMethodAcceptable(m, types, acceptFunction)) {
                list.add(m);
            }
        }
        return list;
    }

    /**
     * Checks that {@code methodDetails} has the same arity as {@code types}.
     *
     * <p>It also checks that comparing each {@code types} entry to the declared argument type gives
     * a result accepted by {@code acceptFunction}.
     */
    private boolean isMethodAcceptable(IMethodDetails methodDetails, List<ArgType> types, Predicate<TypeCompareEnum> acceptFunction) {
        List<ArgType> mthTypes = methodDetails.getArgTypes();
        int argCount = mthTypes.size();
        if (argCount != types.size()) {
            return false;
        }
        TypeCompare typeCompare = root.getTypeUpdate().getTypeCompare();
        for (int i = 0; i < argCount; i++) {
            TypeCompareEnum result = typeCompare.compareTypes(types.get(i), mthTypes.get(i));
            if (!acceptFunction.test(result)) {
                return false;
            }
        }
        return true;
    }

    /** Collects the compile-time type of each invoke argument, starting at {@code argOffset}. */
    private List<ArgType> collectCompilerVarTypes(BaseInvokeNode insn, int argOffset) {
        int argsCount = insn.getArgsCount();
        List<ArgType> result = new ArrayList<>(argsCount);
        for (int i = argOffset; i < argsCount; i++) {
            result.add(getCompilerVarType(insn.getArg(i)));
        }
        return result;
    }

    /**
     * Returns the type a Java compiler would see for {@code arg}.
     *
     * <ul>
     *   <li>A zero literal of object/array type (i.e. {@code null}) is reported as
     *       {@link ArgType#UNKNOWN_OBJECT}.</li>
     *   <li>A primitive literal without {@link AFlag#EXPLICIT_PRIMITIVE_TYPE} is reported as
     *       {@link ArgType#INT}, presumably mirroring how an unsuffixed literal is typed in source.</li>
     *   <li>For a wrapped instruction, its result type is used when it has a result.</li>
     * </ul>
     *
     * @throws JadxRuntimeException for unsupported argument kinds
     */
    private ArgType getCompilerVarType(InsnArg arg) {
        if (arg instanceof LiteralArg) {
            LiteralArg literalArg = (LiteralArg) arg;
            ArgType type = literalArg.getType();
            if (literalArg.getLiteral() == 0L && (type.isObject() || type.isArray())) {
                return ArgType.UNKNOWN_OBJECT;
            }
            if (type.isPrimitive() && !arg.contains(AFlag.EXPLICIT_PRIMITIVE_TYPE)) {
                return ArgType.INT;
            }
            return type;
        }
        if (arg instanceof RegisterArg) {
            return arg.getType();
        }
        if (arg instanceof InsnWrapArg) {
            InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
            if (wrapInsn.getResult() != null) {
                return wrapInsn.getResult().getType();
            }
            return arg.getType();
        }
        throw new JadxRuntimeException("Unknown var type for: " + arg);
    }
}

