Context.java

/*
 * SPDX-FileCopyrightText: 2025 kaumei.io
 * SPDX-License-Identifier: Apache-2.0
 */
package io.kaumei.jdbc.anno.ctx;

import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.util.Trees;
import io.kaumei.jdbc.anno.ProcessorException;
import io.kaumei.jdbc.anno.gen.GenerateService;
import io.kaumei.jdbc.anno.java2jdbc.Java2JdbcService;
import io.kaumei.jdbc.anno.jdbc2java.Jdbc2JavaService;
import io.kaumei.jdbc.anno.model.*;
import io.kaumei.jdbc.anno.msg.JdbcMsg;
import io.kaumei.jdbc.anno.msg.Msg;
import io.kaumei.jdbc.anno.store.SearchKey;
import io.kaumei.jdbc.annotation.*;
import io.kaumei.jdbc.annotation.config.JdbcConfig;
import io.kaumei.jdbc.annotation.config.JdbcLogLevel;
import org.jspecify.annotations.Nullable;

import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.*;
import javax.lang.model.type.ArrayType;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import javax.lang.model.util.Types;
import java.nio.file.Paths;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.*;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import static io.kaumei.jdbc.anno.ctx.JavaModelUtils.component;
import static io.kaumei.jdbc.anno.ctx.JavaModelUtils.componentOpt;
import static io.kaumei.jdbc.anno.model.OptionalFlag.OPTIONAL_TYPE;
import static java.util.Objects.requireNonNull;

public class Context {

    /** Shared, reusable zero-length {@link TypeMirror} array. */
    public static final TypeMirror[] EMPTY_TYPE_MIRROR_ARRAY = new TypeMirror[0];

    // ----- well-known JDK types, resolved once from the processing environment
    public final JavaType JAVA_Boolean;
    public final JavaType JAVA_Byte;
    public final JavaType JAVA_Character;
    public final JavaType JAVA_Double;
    public final JavaType JAVA_Float;
    public final JavaType JAVA_Integer;
    public final JavaType JAVA_List;
    public final JavaType JAVA_Long;
    public final JavaType JAVA_Optional;
    public final JavaType JAVA_RuntimeException;
    public final JavaType JAVA_SQL_Connection;
    public final JavaType JAVA_SQL_PreparedStatement;
    public final JavaType JAVA_SQL_ResultSet;
    public final JavaType JAVA_SQL_SQLException;
    public final JavaType JAVA_Short;
    public final JavaType JAVA_Stream;
    public final JavaType JAVA_String;
    public final JavaType JAVA_Void;

    // ----- Kaumei annotation types, each bound to its value type and factory
    public final JavaAnnoType<ConverterNameDV> JAVA_TO_JDBC;
    public final JavaAnnoType<JavaAnnoType.JdbcBatchUpdateRecord> JDBC_BATCH_UPDATE;
    public final JavaAnnoType<ConverterNameDV> JDBC_CONVERTER_NAME;
    public final JavaAnnoType<JavaAnnoType.JdbcGenerationRecord> JDBC_GENERATION;
    public final JavaAnnoType<JavaAnnoType.JdbcDebugRecord> JDBC_DEBUG;
    public final JavaAnnoType<SQLNameDV> JDBC_NAME;
    public final JavaAnnoType<JavaAnnoType.JdbcNativeRecord> JDBC_NATIVE;
    public final JavaAnnoType<JavaAnnoType.JdbcSelectRecord> JDBC_SELECT;
    public final JavaAnnoType<ConverterNameDV> JDBC_TO_JAVA;
    public final JavaAnnoType<JavaAnnoType.JdbcUpdateRecord> JDBC_UPDATE;

    // ----- Kaumei configuration annotations
    public final JavaAnnoType<JavaAnnoType.JdbcConfigRecord> JDBC_CONFIG_PROPS;
    public final JavaAnnoType<JavaAnnoType.JdbcLogLevelRecord> JDBC_LOG_LEVEL;

    /**
     * Registered Kaumei annotation types keyed by the qualified name of the annotation type.
     * The map is unmodifiable.
     */
    public final Map<Name, JavaAnnoType<?>> KAUMEI_ANNO;

    // ----- runtime types of the Kaumei-JDBC core library (required on the classpath)
    public final JavaType KAUMEI_JDBC_JdbcBatch;
    public final JavaType KAUMEI_JDBC_JdbcIterable;
    public final JavaType KAUMEI_JDBC_JdbcResultSet;
    public final JavaType KAUMEI_JDBC_JdbcRow;

    // ----- service
    public final Types types;
    public final Elements elements;
    public final @Nullable Trees trees;
    public final JavaMessenger logger;
    public final JavaFiler filer;
    public final ConfigService kaumeiConfig;
    public final SourceMethodService sourceMethodService;
    public final Jdbc2JavaService kaumeiJdbc2Java;
    public final Java2JdbcService kaumeiJava2Jdbc;
    public final GenerateService kaumeiJdbcGenerator;

    // ------------------------------------------------------------------------

    /** Computes the nullness flag of a reference type from its JSpecify annotations. */
    private final BiFunction<Element, TypeMirror, OptionalFlag> jspecifyCheck;

    // ------------------------------------------------------------------------

    /**
     * Creates the processing context and resolves all well-known types and annotations.
     *
     * @param env the annotation-processing environment
     * @throws IllegalStateException if the Kaumei-JDBC runtime library is not on the classpath
     */
    public Context(ProcessingEnvironment env) {
        this.types = requireNonNull(env.getTypeUtils());
        this.elements = requireNonNull(env.getElementUtils());
        this.trees = treesInstance(env);
        // NOTE: the services below receive `this` before the JAVA_* / annotation / KAUMEI_*
        // fields are assigned; they would observe those fields as null if they read them
        // from their own constructors.
        this.logger = new JavaMessenger(this, env);
        this.filer = new JavaFiler(this, env.getFiler());
        this.kaumeiConfig = new ConfigService(this, env.getOptions());
        this.sourceMethodService = new SourceMethodService(this);
        this.kaumeiJdbc2Java = new Jdbc2JavaService(this);
        this.kaumeiJava2Jdbc = new Java2JdbcService(this);
        this.kaumeiJdbcGenerator = new GenerateService(this);

        JAVA_Boolean = javaType(Boolean.class);
        JAVA_Byte = javaType(Byte.class);
        JAVA_Character = javaType(Character.class);
        JAVA_Double = javaType(Double.class);
        JAVA_Float = javaType(Float.class);
        JAVA_Integer = javaType(Integer.class);
        JAVA_List = javaType(List.class);
        JAVA_Long = javaType(Long.class);
        JAVA_Optional = javaType(Optional.class);
        JAVA_RuntimeException = javaType(RuntimeException.class);
        JAVA_SQL_Connection = javaType(Connection.class);
        JAVA_SQL_PreparedStatement = javaType(PreparedStatement.class);
        JAVA_SQL_ResultSet = javaType(java.sql.ResultSet.class);
        JAVA_SQL_SQLException = javaType(SQLException.class);
        JAVA_Short = javaType(Short.class);
        JAVA_Stream = javaType(Stream.class);
        JAVA_String = javaType(String.class);
        JAVA_Void = javaType(Void.class);

        // @formatter:off
        JAVA_TO_JDBC        = new JavaAnnoType<>(this,JavaToJdbc.class,        ConverterNameDV.class,ConverterNameDV::of);
        JDBC_BATCH_UPDATE   = new JavaAnnoType<>(this,JdbcBatchUpdate.class,   JavaAnnoType.JdbcBatchUpdateRecord.class, JavaAnnoType.JdbcBatchUpdateRecord::of);
        JDBC_CONVERTER_NAME = new JavaAnnoType<>(this,JdbcConverterName.class, ConverterNameDV.class, ConverterNameDV::of);
        JDBC_GENERATION     = new JavaAnnoType<>(this,JdbcGeneration.class,    JavaAnnoType.JdbcGenerationRecord.class, JavaAnnoType.JdbcGenerationRecord::of);
        JDBC_DEBUG          = new JavaAnnoType<>(this,JdbcDebug.class,         JavaAnnoType.JdbcDebugRecord.class,JavaAnnoType.JdbcDebugRecord::of);
        JDBC_NAME           = new JavaAnnoType<>(this,JdbcName.class,          SQLNameDV.class, SQLNameDV::of);
        JDBC_NATIVE         = new JavaAnnoType<>(this,JdbcNative.class,        JavaAnnoType.JdbcNativeRecord.class, JavaAnnoType.JdbcNativeRecord::of);
        JDBC_SELECT         = new JavaAnnoType<>(this,JdbcSelect.class,        JavaAnnoType.JdbcSelectRecord.class, JavaAnnoType.JdbcSelectRecord::of);
        JDBC_TO_JAVA        = new JavaAnnoType<>(this,JdbcToJava.class,        ConverterNameDV.class, ConverterNameDV::of);
        JDBC_UPDATE         = new JavaAnnoType<>(this,JdbcUpdate.class,        JavaAnnoType.JdbcUpdateRecord.class, JavaAnnoType.JdbcUpdateRecord::of);

        JDBC_CONFIG_PROPS            = new JavaAnnoType<>(this,JdbcConfig.class,               JavaAnnoType.JdbcConfigRecord.class, JavaAnnoType.JdbcConfigRecord::of);
        JDBC_LOG_LEVEL               = new JavaAnnoType<>(this,JdbcLogLevel.class,             JavaAnnoType.JdbcLogLevelRecord.class, JavaAnnoType.JdbcLogLevelRecord::of);
        // @formatter:on

        // NOTE(review): JDBC_DEBUG is not registered here; confirm that this is intentional.
        KAUMEI_ANNO = toMap(JAVA_TO_JDBC, JDBC_GENERATION, JDBC_CONVERTER_NAME,
                JDBC_NAME, JDBC_NATIVE, JDBC_SELECT, JDBC_TO_JAVA, JDBC_UPDATE, JDBC_BATCH_UPDATE,
                //
                JDBC_CONFIG_PROPS, JDBC_LOG_LEVEL,
                this.kaumeiConfig.BATCH_SIZE.anno(), this.kaumeiConfig.FETCH_DIRECTION.anno(),
                this.kaumeiConfig.FETCH_SIZE.anno(), this.kaumeiConfig.MAX_ROWS.anno(),
                this.kaumeiConfig.NO_MORE_ROWS.anno(), this.kaumeiConfig.NO_ROWS.anno(),
                this.kaumeiConfig.QUERY_TIMEOUT.anno(),
                this.kaumeiConfig.RESULT_SET_CONCURRENCY.anno(), this.kaumeiConfig.RESULT_SET_TYPE.anno()
        );

        // The core runtime types are looked up by name so that a missing library
        // produces a clear error instead of a NullPointerException later on.
        var jdbcBatch = javaTypeOpt("io.kaumei.jdbc.JdbcBatch");
        var jdbcIterable = javaTypeOpt("io.kaumei.jdbc.JdbcIterable");
        var jdbcResultSet = javaTypeOpt("io.kaumei.jdbc.JdbcResultSet");
        var jdbcRow = javaTypeOpt("io.kaumei.jdbc.JdbcRow");
        if (jdbcBatch == null || jdbcIterable == null || jdbcResultSet == null || jdbcRow == null) {
            throw new IllegalStateException("""
                    Could not find Kaumei-JDBC lib on the classpath.
                    Add the following dependency io.kaumei.jdbc:jdbc-annotation and io.kaumei.jdbc:jdbc-core
                    to your classpath or disable the io.kaumei.jdbc:jdbc-processor annotation processor.
                    """);
        }
        this.KAUMEI_JDBC_JdbcBatch = jdbcBatch;
        this.KAUMEI_JDBC_JdbcIterable = jdbcIterable;
        this.KAUMEI_JDBC_JdbcResultSet = jdbcResultSet;
        this.KAUMEI_JDBC_JdbcRow = jdbcRow;

        this.jspecifyCheck = new JspecifyChecker(this);
    }

    // ------------------------------------------------------------------------

    /**
     * Returns the javac {@link Trees} utility, or {@code null} if the environment does not
     * support it (for example, a non-javac compiler).
     */
    private static @Nullable Trees treesInstance(ProcessingEnvironment env) {
        try {
            return Trees.instance(env);
        } catch (Exception e) {
            // best effort: source positions are optional
            return null;
        }
    }

    /**
     * Formats the source location of an element as {@code "(File.java:line)"}.
     *
     * @param element the element to locate
     * @return the location string. Returns {@code "(File.java)"} if the line is unknown and
     *         {@code "(unknown)"} if the lookup fails. Returns {@code null} if {@link Trees} is
     *         unavailable or the element has no source path.
     */
    public @Nullable String sourcePosition(Element element) {
        if (this.trees == null) {
            return null;
        }
        try {
            var path = this.trees.getPath(element);
            if (path == null) {
                return null;
            }
            CompilationUnitTree cu = path.getCompilationUnit();
            var filename = Paths.get(cu.getSourceFile().toUri().getPath()).getFileName();
            long pos = this.trees.getSourcePositions().getStartPosition(cu, path.getLeaf());
            var lineMap = cu.getLineMap();
            if (pos < 0 || lineMap == null) {
                // NOPOS (-1) or no line map available: the file is known but the line is not
                return "(" + filename + ')';
            }
            return "(" + filename + ':' + lineMap.getLineNumber(pos) + ')';
        } catch (Exception e) {
            // fallback if something went wrong
            return "(unknown)";
        }
    }

    // ------------------------------------------------------------------------

    /** Wraps the type element of a class that is expected to be resolvable. */
    private JavaType javaType(Class<?> cls) {
        return new JavaType(this, elements.getTypeElement(cls.getCanonicalName()));
    }

    /** Wraps the type element with the given qualified name, or returns {@code null} if it is not found. */
    private @Nullable JavaType javaTypeOpt(CharSequence fcq) {
        var element = elements.getTypeElement(fcq);
        if (element != null) {
            return new JavaType(this, element);
        }
        return null;
    }

    // ------------------------------------------------------------------------

    /** Looks up a type element by canonical name; returns {@code null} if it is not found. */
    public @Nullable TypeElement getTypeElementOpt(CharSequence name) {
        return this.elements.getTypeElement(name);
    }

    /**
     * Returns {@code true} if the element is declared in {@code io.kaumei.jdbc.annotation} or
     * {@code io.kaumei.jdbc.annotation.config}.
     */
    public boolean hasKaumeiPkg(Element e) {
        var pkgName = this.elements.getPackageOf(e).getQualifiedName();
        return pkgName.contentEquals("io.kaumei.jdbc.annotation")
                || pkgName.contentEquals("io.kaumei.jdbc.annotation.config");
    }

    /** Returns the package that contains the given element. */
    public PackageElement getPackageOf(Element e) {
        return this.elements.getPackageOf(e);
    }

    /**
     * Checks that every declared thrown type of the method is an {@link SQLException} or a
     * {@link RuntimeException}, as decided by {@code JavaType.isSupertypeOf}.
     *
     * @return {@code true} if the throws clause contains nothing else
     */
    @SuppressWarnings("BooleanMethodIsAlwaysInverted")
    public boolean hasValidSqlExceptions(ExecutableElement method) {
        for (TypeMirror methodThrows : method.getThrownTypes()) {
            if (!this.JAVA_SQL_SQLException.isSupertypeOf(methodThrows)
                    && !this.JAVA_RuntimeException.isSupertypeOf(methodThrows)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the SQL name of the element. The {@code @JdbcName} value wins; otherwise the
     * name is derived from the element's simple name.
     */
    public SQLNameDV jdbcName(Element elem) {
        var name = this.JDBC_NAME.getAnnoOpt(elem);
        if (name != null) {
            return name;
        }
        return SQLNameDV.ofJavaIdentifier(elem.getSimpleName());
    }

    // ------------------------------------------------------------------------

    /**
     * Result of {@link #checkExceptions}.
     *
     * @param containsSqlException whether the target declares an {@link SQLException}
     * @param notCovered           checked exceptions of the target that the source does not declare,
     *                             in declaration order
     */
    private record CheckExceptionResult(boolean containsSqlException, Set<TypeMirror> notCovered) {
    }

    /**
     * Compares the throws clauses of {@code source} and {@code target}.
     * {@link SQLException} and {@link RuntimeException} subtypes thrown by the target are
     * accepted. Every other thrown type must be a subtype of a type declared by the source.
     */
    private CheckExceptionResult checkExceptions(ExecutableElement source, ExecutableElement target) {
        var containsSqlException = false;
        // LinkedHashSet: TypeMirror hashing is identity-based, so a HashSet would emit
        // diagnostics in a non-deterministic order
        Set<TypeMirror> notCovered = new LinkedHashSet<>();
        for (var targetExp : target.getThrownTypes()) {
            if (this.JAVA_SQL_SQLException.isSupertypeOf(targetExp)) {
                containsSqlException = true;
            } else if (this.JAVA_RuntimeException.isSupertypeOf(targetExp)) {
                // ok
            } else if (!this.isSubtype(targetExp, source.getThrownTypes())) {
                notCovered.add(targetExp);
            }
        }
        return new CheckExceptionResult(containsSqlException, Collections.unmodifiableSet(notCovered));
    }

    // ------------------------------------------------------------------------

    /**
     * Checks whether {@code target} can be called from within {@code source}, passing a
     * {@link Connection} as the extra first argument.
     * <ul>
     *     <li>Return type: skipped when source returns {@code void}. Otherwise the target's
     *     return type must be assignable to the source's, with compatible nullness.</li>
     *     <li>Parameters: target must take one more parameter than source, and its first
     *     parameter must be a {@link Connection}. The remaining parameters must have the same
     *     types as the source's and compatible nullness, and neither side may use
     *     {@link Optional}.</li>
     *     <li>Exceptions: every thrown type of the target that is neither an
     *     {@link SQLException} nor a {@link RuntimeException} must be covered by the source's
     *     throws clause.</li>
     * </ul>
     *
     * @return all problems found, collected as messages
     */
    public Msg.Messages isCallable(ExecutableElement source, ExecutableElement target) {
        var messages = Msg.builder();

        if (source.getReturnType().getKind() == TypeKind.VOID) {
            // ok
        } else if (!this.types.isAssignable(target.getReturnType(), source.getReturnType())) {
            messages.add(JdbcMsg.returnTypeNotAssignable(target.getReturnType().toString(), source.getReturnType().toString()));
        } else {
            var sourceFlag = this.optionalFlag(source, source.getReturnType());
            var targetFlag = this.optionalFlag(target, target.getReturnType());
            if (!targetFlag.isAssignableTo(sourceFlag)) {
                messages.add(JdbcMsg.returnTypeNullnessMismatch(targetFlag, sourceFlag));
            }
        }

        var parameters1 = source.getParameters();
        var parameters2 = target.getParameters();
        var size1 = parameters1.size();
        var size2 = parameters2.size();
        if (size1 + 1 != size2) {
            messages.add(JdbcMsg.javaMethodsHaveDifferentParameter());
        } else {
            if (!this.JAVA_SQL_Connection.isSameType(parameters2.get(0).asType())) {
                messages.add(JdbcMsg.firstParameterMustBeConnection());
            }
            for (int i = 0; i < size1; i++) {
                // target parameters are shifted by one because of the leading Connection
                var p1 = parameters1.get(i).asType();
                var p2 = parameters2.get(i + 1).asType();
                var sourceFlag = this.optionalFlag(source, p1);
                var targetFlag = this.optionalFlag(target, p2);

                if (sourceFlag.isOptionalType()) {
                    messages.add(JdbcMsg.sourceParamOptionalTypeNotSupported(i));
                } else if (targetFlag.isOptionalType()) {
                    messages.add(JdbcMsg.targetParamOptionalTypeNotSupported(i + 1));
                } else if (!sourceFlag.isAssignableTo(targetFlag)) {
                    messages.add(JdbcMsg.paramNullnessMismatch(i, sourceFlag, targetFlag));
                } else if (!this.types.isSameType(p1, p2)) {
                    messages.add(JdbcMsg.paramTypeMismatch(i, p1.toString(), p2.toString()));
                }
            }
        }
        var checkResult = checkExceptions(source, target);
        for (var exp : checkResult.notCovered()) {
            messages.add(JdbcMsg.exceptionNotCompatible(exp.toString()));
        }
        return messages.build();
    }

    // ------------------------------------------------------------------------

    /**
     * Determines the nullness of a type. Primitives are always non-null, and
     * {@link Optional} yields {@code OPTIONAL_TYPE}. All other types are resolved from their
     * JSpecify annotations; JSpecify is the only supported nullness annotation library.
     *
     * @param context the element the type belongs to
     * @param type    the type to check
     * @return result of the check, not null
     */
    public OptionalFlag optionalFlag(Element context, TypeMirror type) {
        if (type.getKind().isPrimitive()) {
            return OptionalFlag.NON_NULL;
        } else if (this.JAVA_Optional.isSameType(type)) {
            return OPTIONAL_TYPE;
        }
        return jspecifyCheck.apply(context, type);
    }

    /** Nullness of a variable, using its enclosing element as the context. */
    public OptionalFlag optionalFlag(VariableElement varElem) {
        return optionalFlag(varElem.getEnclosingElement(), varElem.asType());
    }

    /** Nullness of a record component, using the enclosing record as the context. */
    public OptionalFlag optionalFlag(RecordComponentElement varElem) {
        return optionalFlag(varElem.getEnclosingElement(), varElem.asType());
    }

    // ------------------------------------------------------------------------

    /** Returns the primitive type for the given kind. */
    public TypeMirror typeMirror(TypeKind kind) {
        return this.types.getPrimitiveType(kind);
    }

    /**
     * Returns the type mirror of a class.
     *
     * @throws NullPointerException if the class cannot be resolved in this environment
     */
    public TypeMirror typeMirror(Class<?> cls) {
        var element = requireNonNull(this.elements.getTypeElement(cls.getCanonicalName()),
                () -> "Type not found: " + cls.getCanonicalName());
        return element.asType();
    }

    /** Returns the array type whose component type is {@code componentType}. */
    public ArrayType getArrayType(TypeMirror componentType) {
        return this.types.getArrayType(componentType);
    }

    /**
     * Tests whether {@code type} is a subtype of any type in {@code list}.
     * Any type is considered to be a subtype of itself.
     */
    public boolean isSubtype(TypeMirror type, List<? extends TypeMirror> list) {
        for (TypeMirror methodThrows : list) {
            if (this.types.isSubtype(type, methodThrows)) {
                return true;
            }
        }
        return false;
    }

    // ------------------------------------------------------------------------

    /** Returns the erasure of the type. */
    public TypeMirror erasure(TypeMirror type) {
        return this.types.erasure(type);
    }

    /** Tests whether t1 is assignable to t2; returns {@code false} for erroneous types. */
    public boolean isAssignable(TypeMirror t1, TypeMirror t2) {
        if (t1.getKind() == TypeKind.ERROR || t2.getKind() == TypeKind.ERROR) {
            return false;
        }
        return this.types.isAssignable(t1, t2);
    }

    /**
     * Tests whether t1 is a subtype of t2. Any type is considered to be a subtype of itself.
     * Returns {@code false} for erroneous types.
     */
    public boolean isSubtype(TypeMirror t1, TypeMirror t2) {
        // in some rare cases the annotation processor is called with an invalid TypeMirror
        if (t1.getKind() == TypeKind.ERROR || t2.getKind() == TypeKind.ERROR) {
            return false;
        }
        return this.types.isSubtype(t1, t2);
    }

    /** Tests whether t1 and t2 are the same type; returns {@code false} for erroneous types. */
    public boolean isSameType(TypeMirror t1, TypeMirror t2) {
        // in some rare cases the annotation processor is called with an invalid TypeMirror
        if (t1.getKind() == TypeKind.ERROR || t2.getKind() == TypeKind.ERROR) {
            return false;
        }
        return types.isSameType(t1, t2);
    }

    // ------------------------------------------------------------------------

    /** Returns the element of the type, or {@code null} if it has none. */
    public @Nullable Element asElementOpt(TypeMirror tm) {
        return this.types.asElement(tm);
    }

    /**
     * Returns the element of the type.
     *
     * @throws NullPointerException if the type has no element
     */
    public Element asElement(TypeMirror tm) {
        return requireNonNull(this.types.asElement(tm));
    }

    /**
     * Returns the type element of the type.
     *
     * @throws IllegalArgumentException if the type's element is not a {@link TypeElement}
     */
    public TypeElement asTypeElement(TypeMirror tm) {
        var element = this.types.asElement(tm);
        if (element instanceof TypeElement te) {
            return te;
        }
        throw new IllegalArgumentException("Could not convert into type:" + tm);
    }

    // ------------------------------------------------------------------------

    /**
     * Resolves the component type of {@code mirror} (as defined by
     * {@code JavaModelUtils.componentOpt}) together with its nullness. If the component is
     * itself an optional type, that type's component is returned instead.
     *
     * @return the component and its flag, or {@code null} if {@code mirror} has no component
     */
    @Nullable
    public TypeOptional resolveComponentOpt(Element element, TypeMirror mirror) {
        var cmp = componentOpt(mirror);
        if (cmp == null) {
            return null;
        }
        var optional = optionalFlag(element, cmp);
        return optional.isOptionalType()
                ? new TypeOptional(component(cmp), optional)
                : new TypeOptional(cmp, optional);
    }

    /**
     * Like {@link #resolveComponentOpt}, but the component is required.
     *
     * @throws NullPointerException if {@code mirror} has no component
     */
    public TypeOptional resolveComponent(Element element, TypeMirror mirror) {
        return requireNonNull(resolveComponentOpt(element, mirror), "TypeMirror has no component.");
    }

    // ------------------------------------------------------------------------

    /** Creates the JDBC type view of an executable element. */
    public JdbcTypeMirror<ExecutableElement> jdbcTypeMirror(ExecutableElement element) {
        return JdbcTypeMirror.of(this, element);
    }

    // ------------------------------------------------------------------------

    /**
     * Collects the type and all of its transitive supertypes together with their depth.
     * Depth is 1 for types without direct supertypes; every other type gets
     * 1 + the maximum depth of its direct supertypes. Supertypes are inserted before the
     * types that extend them.
     */
    public Map<SearchKey, Integer> collectTypeHierarchy2(TypeMirror type) {
        var levels = new LinkedHashMap<SearchKey, Integer>();
        visitTypeHierarchy2(levels, type);
        return levels;
    }

    /** Memoized depth-first computation of the depth of {@code current}; see {@link #collectTypeHierarchy2}. */
    private int visitTypeHierarchy2(Map<SearchKey, Integer> levels, TypeMirror current) {
        var id = SearchKey.of(current);
        var level = levels.get(id);
        if (level == null) {
            level = 0;
            for (TypeMirror parent : this.types.directSupertypes(current)) {
                var parentLevel = visitTypeHierarchy2(levels, parent);
                if (parentLevel > level) {
                    level = parentLevel;
                }
            }
            level++;
            levels.put(id, level);
        }
        return level;
    }

    // ------------------------------------------------------------------------

    /**
     * Indexes annotation types by the qualified name of their type element.
     *
     * @return an unmodifiable map
     * @throws ProcessorException if two annotation types share a qualified name
     */
    private static Map<Name, JavaAnnoType<?>> toMap(JavaAnnoType<?>... all) {
        var map = new HashMap<Name, JavaAnnoType<?>>();
        for (var a : all) {
            if (map.put(a.typeElement().getQualifiedName(), a) != null) {
                throw new ProcessorException("Duplicate key: " + a.typeElement());
            }
        }
        return Collections.unmodifiableMap(map);
    }

}