GenerateJdbcNative.java

/*
 * SPDX-FileCopyrightText: 2025 kaumei.io
 * SPDX-License-Identifier: Apache-2.0
 */
package io.kaumei.jdbc.anno.gen;

import com.palantir.javapoet.CodeBlock;
import com.palantir.javapoet.MethodSpec;
import io.kaumei.jdbc.JdbcException;
import io.kaumei.jdbc.anno.ctx.Context;
import io.kaumei.jdbc.anno.model.SourceMethod;
import io.kaumei.jdbc.anno.msg.JdbcMsg;
import io.kaumei.jdbc.anno.msg.Msg;

import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeKind;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

/**
 * Generator for methods configured through the {@code nativeJdbc()} entry-point
 * properties.
 *
 * <p>The produced method body obtains a connection via
 * {@code supplier.getConnection()} and forwards it, followed by the source
 * method's own parameters, to a static method looked up by name on a target
 * class. A {@link SQLException} raised by the generated call is rethrown as a
 * {@link JdbcException}.
 */
public class GenerateJdbcNative implements GenerateJdbc {
    // ----- services
    private final Context ctx;
    // ----- state
    private final SourceMethod sourceMethod;
    private final TargetMethod targetMethod;
    private final Msg.Builder messages;

    /**
     * @param ctx          shared processing context (logging, type lookup, callability check)
     * @param sourceMethod the annotated method an implementation is generated for
     */
    GenerateJdbcNative(Context ctx, SourceMethod sourceMethod) {
        this.ctx = ctx;
        this.sourceMethod = sourceMethod;
        this.targetMethod = new TargetMethod(this.ctx, sourceMethod.method());
        this.messages = Msg.builder();
    }

    /**
     * Builds the method specification for the source method.
     *
     * <p>The target class falls back to the type enclosing the source method
     * and the target method name falls back to the source method's simple
     * name whenever the corresponding annotation property is {@code null}.
     * If no usable static method is found, the lookup messages are recorded
     * instead of emitting a delegating body.
     *
     * @return the spec produced by the underlying {@code TargetMethod} from
     *         the collected messages
     */
    public MethodSpec generateMethod() {
        var annoProperties = sourceMethod.entryPoint().nativeJdbc();
        this.ctx.logger.debug("---- JdbcNative ----",
                "method", sourceMethod,
                "prop", annoProperties);

        // Resolve the class that is expected to declare the static delegate.
        TypeElement ownerType;
        if (annoProperties.cls() == null) {
            ownerType = (TypeElement) sourceMethod.method().getEnclosingElement();
        } else {
            ownerType = this.ctx.asTypeElement(annoProperties.cls());
        }

        // Resolve the name of the static delegate.
        String delegateName;
        if (annoProperties.method() == null) {
            delegateName = sourceMethod.method().getSimpleName().toString();
        } else {
            delegateName = annoProperties.method();
        }

        var lookup = findNativeMethod(sourceMethod.method(), ownerType, delegateName);
        if (!lookup.hasMessages()) {
            emitDelegation(lookup.value());
        } else {
            messages.add(lookup);
        }

        this.messages.add(sourceMethod.unusedAnno());
        return this.targetMethod.build(this.messages.build(), "@JdbcNative method invalid");
    }

    /**
     * Emits the {@code try}/{@code catch} body that calls the given static
     * method with the connection plus the forwarded source parameters.
     *
     * @param delegate the static method the generated code invokes
     */
    private void emitDelegation(ExecutableElement delegate) {
        targetMethod.beginControlFlow("try");
        targetMethod.addStatement("var con = supplier.getConnection()");

        // The connection is always the first argument of the call.
        List<CodeBlock> callArgs = new ArrayList<>();
        callArgs.add(CodeBlock.of("$L", "con"));

        // Every remaining delegate parameter is fed from the source
        // parameter at the same position (delegate arity minus the connection).
        var forwardedCount = delegate.getParameters().size() - 1;
        for (int index = 0; index < forwardedCount; index++) {
            var sourceParam = sourceMethod.method().getParameters().get(index);
            var paramName = sourceParam.getSimpleName().toString();
            callArgs.add(targetMethod.paramCodeBlock(paramName));
        }

        var owner = delegate.getEnclosingElement();
        var name = delegate.getSimpleName();
        var joinedArgs = CodeBlock.join(callArgs, ",");
        if (sourceMethod.method().getReturnType().getKind() != TypeKind.VOID) {
            targetMethod.addStatement("return $T.$L($L)", owner, name, joinedArgs);
        } else {
            targetMethod.addStatement("$T.$L($L)", owner, name, joinedArgs);
        }

        targetMethod.nextControlFlow("catch ($T e)", SQLException.class);
        targetMethod.addStatement("throw new $T(e.getMessage(), e)", JdbcException.class);
        targetMethod.endControlFlow();
    }

    /**
     * Looks through the elements enclosed by {@code targetClass} for a static
     * method named {@code methodName} that {@code ctx.isCallable} accepts for
     * {@code source}.
     *
     * @param source      the method the delegate has to be callable from
     * @param targetClass the class whose enclosed elements are inspected
     * @param methodName  the simple name the delegate must have
     * @return a result carrying the accepted method; otherwise a result with
     *         a "requires matching static method" message when no static
     *         method of that name exists, or a "requires callable static
     *         method" message followed by the messages of every rejected
     *         candidate
     * @throws IllegalStateException if more than one candidate is accepted
     */
    private Msg.Result<ExecutableElement> findNativeMethod(ExecutableElement source,
                                                           TypeElement targetClass,
                                                           String methodName) {
        ExecutableElement match = null;
        var rejections = Msg.builder();
        var sawStaticNamesake = false;

        for (var member : targetClass.getEnclosedElements()) {
            if (member.getKind() != ElementKind.METHOD
                    || !(member instanceof ExecutableElement candidate)) {
                continue;
            }
            if (!methodName.contentEquals(candidate.getSimpleName())) {
                continue;
            }
            if (!candidate.getModifiers().contains(Modifier.STATIC)) {
                continue;
            }

            // A static method with the requested name exists, callable or not.
            sawStaticNamesake = true;

            var verdict = this.ctx.isCallable(source, candidate);
            if (verdict.hasMessages()) {
                rejections.add(verdict);
                continue;
            }
            if (match != null) { // sanity-check
                throw new IllegalStateException("Unexpected duplicate: " + match); // sanity-check
            }
            match = candidate;
        }

        if (match != null) {
            return Msg.result(match);
        }
        if (!sawStaticNamesake) {
            return Msg.result(JdbcMsg.jdbcNativeRequiresMatchingStaticMethod(targetClass, methodName));
        }

        // Candidates existed but none was callable: headline first, details after.
        var failure = Msg.builder();
        failure.add(JdbcMsg.jdbcNativeRequiresCallableStaticMethod(targetClass, methodName));
        failure.add(rejections.build());
        return Msg.result(failure.build());
    }
}