GenerateJdbcUpdate.java

/*
 * SPDX-FileCopyrightText: 2025 kaumei.io
 * SPDX-License-Identifier: Apache-2.0
 */
package io.kaumei.jdbc.anno.gen;

import com.palantir.javapoet.MethodSpec;
import io.kaumei.jdbc.JdbcException;
import io.kaumei.jdbc.anno.ProcessorException;
import io.kaumei.jdbc.anno.ctx.Context;
import io.kaumei.jdbc.anno.jdbc2java.ColumnIndex;
import io.kaumei.jdbc.anno.model.JdbcTypeKind;
import io.kaumei.jdbc.anno.model.JdbcTypeMirror;
import io.kaumei.jdbc.anno.model.SourceMethod;
import io.kaumei.jdbc.anno.msg.JdbcMsg;
import io.kaumei.jdbc.anno.msg.Msg;
import io.kaumei.jdbc.anno.store.SearchKey;
import io.kaumei.jdbc.anno.store.SourceDV;
import io.kaumei.jdbc.annotation.JdbcUpdate;
import io.kaumei.jdbc.annotation.config.JdbcNoMoreRows;
import io.kaumei.jdbc.annotation.config.JdbcNoRows;

import javax.lang.model.element.ExecutableElement;
import java.sql.SQLException;
import java.sql.Statement;

/**
 * Generates the implementation of a method annotated with {@link JdbcUpdate}.
 * <p>
 * Two shapes of code are produced:
 * <ul>
 *     <li>a plain update via {@code executeUpdate()} for {@code void}, {@code int} and {@code boolean}
 *     return kinds, and</li>
 *     <li>an update that reads a value back, either via {@code getGeneratedKeys()} or via
 *     {@code executeQuery()}, selected by {@link JdbcUpdate.GeneratedValues}.</li>
 * </ul>
 * Validation problems are collected in a {@code Msg.Builder} and passed to {@code TargetMethod.build}.
 */
public class GenerateJdbcUpdate implements GenerateJdbc {
    // ----- services
    private final Context ctx;
    // ----- state
    private final SourceMethod sourceMethod;
    private final TargetMethod targetMethod;
    private final Msg.Builder messages;

    GenerateJdbcUpdate(Context ctx, SourceMethod sourceMethod) {
        this.ctx = ctx;
        this.sourceMethod = sourceMethod;
        this.targetMethod = new TargetMethod(this.ctx, sourceMethod.method());
        this.messages = Msg.builder();
    }

    // ------------------------------------------------------------------------

    /**
     * Generates the method implementation for the wrapped {@code @JdbcUpdate} source method.
     * <p>
     * If the configured generated-values mode is {@code GENERATED_KEYS} or {@code EXECUTE_QUERY}, a
     * "returning" update is generated. Otherwise the return kind must be {@code VOID}, {@code INT} or
     * {@code BOOLEAN}, or an error message is recorded.
     *
     * @return the method built by {@code TargetMethod.build} from the collected messages
     */
    public MethodSpec generateMethod() {
        this.ctx.logger.debug("---- @JdbcUpdate ----",
                "method", sourceMethod,
                "returnType", sourceMethod.method().getReturnType());

        var returnType = sourceMethod.returnType();

        var gen = sourceMethod.jdbcReturnGeneratedValues();
        if (gen == JdbcUpdate.GeneratedValues.GENERATED_KEYS || gen == JdbcUpdate.GeneratedValues.EXECUTE_QUERY) {
            updateReturning(returnType, gen);
        } else if (returnType.kind() == JdbcTypeKind.VOID
                || returnType.kind() == JdbcTypeKind.INT
                || returnType.kind() == JdbcTypeKind.BOOLEAN) {
            updateSimple(returnType);
        } else {
            this.messages.add(JdbcMsg.jdbcUpdateRequiresSupportedReturnType(returnType));
        }

        this.messages.add(sourceMethod.unusedAnno());
        return this.targetMethod.build(this.messages.build(), "@JdbcUpdate method invalid");
    }

    /**
     * Emits a plain {@code executeUpdate()} call whose result is discarded ({@code VOID}), returned
     * ({@code INT}), or converted to {@code != 0} ({@code BOOLEAN}).
     *
     * @param returnType the return type already validated by {@link #generateMethod()}
     */
    private void updateSimple(JdbcTypeMirror<ExecutableElement> returnType) {
        this.ctx.logger.debug("updateSimple");

        beginConnectionTry();
        targetMethod.beginControlFlow("try (var stmt = con.prepareStatement(sql))");
        bindParametersAndTimeout();
        // ----
        // Switch on the same JdbcTypeKind that generateMethod() used to select this path,
        // so the guard and the dispatch cannot disagree.
        switch (returnType.kind()) { // JaCoCo:ignore
            case VOID -> targetMethod.addStatement("stmt.executeUpdate()");
            case INT -> targetMethod.addStatement("return stmt.executeUpdate()");
            case BOOLEAN -> targetMethod.addStatement("return stmt.executeUpdate() != 0");
            default ->
                    throw new IllegalStateException("Unexpected return type: " + returnType.kind()); // sanity-check
        }
        // ----
        targetMethod.endControlFlow(); // try (stmt)
        endConnectionTryWithSqlExceptionCatch();
    }

    /**
     * Emits an update that reads a single value back from a result set. The result set comes from
     * {@code stmt.getGeneratedKeys()} ({@code GENERATED_KEYS}) or from {@code stmt.executeQuery()}
     * ({@code EXECUTE_QUERY}).
     * <p>
     * Records an error message and emits nothing if any of these hold:
     * <ul>
     *     <li>the result kind is unsupported,</li>
     *     <li>no JDBC-to-Java converter is found,</li>
     *     <li>the result type's nullability is not non-null or unspecified.</li>
     * </ul>
     *
     * @param resultType                the method's return type
     * @param jdbcReturnGeneratedValues how the value is read back; must be {@code GENERATED_KEYS} or
     *                                  {@code EXECUTE_QUERY}
     * @throws ProcessorException if {@code jdbcReturnGeneratedValues} is any other mode (sanity check)
     */
    private void updateReturning(JdbcTypeMirror<ExecutableElement> resultType,
                                 JdbcUpdate.GeneratedValues jdbcReturnGeneratedValues) {
        this.ctx.logger.debug("updateReturning", "resultType", resultType, "jdbcReturnGeneratedValues", jdbcReturnGeneratedValues);

        // Reading back generated values is not supported for void or for the container-like kinds below.
        switch (resultType.kind()) {
            case VOID,
                 KAUMEI_JDBC_ITERABLE, KAUMEI_JDBC_RESULT_SET, KAUMEI_JDBC_BATCH,
                 ARRAY, LIST, STREAM:
                this.messages.add(JdbcMsg.jdbcUpdateGeneratedValuesRequiresSupportedReturnType(resultType));
                return;
        }

        var request = SearchKey.of(resultType.type(), sourceMethod.jdbcConverterName());
        var searchResult = ctx.kaumeiJdbc2Java.searchJava(sourceMethod.method(), request);
        if (searchResult.hasMessages()) {
            this.messages.add(JdbcMsg.invalidJdbcToJavaResultConverter(
                    SourceDV.converter(ctx, sourceMethod.method()), request, searchResult.messages()));
            return;
        }

        if (resultType.optional().checkNonNullOrUnspecified() != null) {
            this.messages.add(JdbcMsg.jdbcUpdateGeneratedValuesRequiresNonNullOrUnspecifiedReturnType(resultType.optional()));
            return;
        }

        beginConnectionTry();
        switch (jdbcReturnGeneratedValues) { // JaCoCo:ignore
            case GENERATED_KEYS -> {
                targetMethod.beginControlFlow("try (var stmt = con.prepareStatement(sql, $T.RETURN_GENERATED_KEYS))", Statement.class);
                bindParametersAndTimeout();
                targetMethod.addStatement("stmt.executeUpdate()");
                targetMethod.beginControlFlow("try(var rs = stmt.getGeneratedKeys())");
            }
            case EXECUTE_QUERY -> {
                targetMethod.beginControlFlow("try (var stmt = con.prepareStatement(sql))");
                bindParametersAndTimeout();
                targetMethod.beginControlFlow("try(var rs = stmt.executeQuery())");
            }
            default ->
                    throw new ProcessorException("Unexpected generated values: " + jdbcReturnGeneratedValues); // sanity-check
        }

        var converter = searchResult.value();
        // Row-count checks use THROW_EXCEPTION; presumably the generated code fails on zero rows
        // and on more than one row - confirm against TargetMethod.addCheckNoRows/addCheckNoMoreRows.
        this.targetMethod.addCheckNoRows(JdbcNoRows.Kind.THROW_EXCEPTION, resultType.optional());
        if (converter.isColumn()) {
            // A column converter reads only the first column of the row.
            converter.addColumnByIndex(this.targetMethod, "result", ColumnIndex.ofValue(1), resultType.optional());
        } else {
            converter.addResultSetToRow(this.targetMethod, "result", resultType.optional());
        }
        this.targetMethod.addCheckNoMoreRows(JdbcNoMoreRows.Kind.THROW_EXCEPTION);
        this.targetMethod.addStatement("return result");

        targetMethod.endControlFlow(); // try (rs)
        targetMethod.endControlFlow(); // try (stmt)
        endConnectionTryWithSqlExceptionCatch();
    }

    /**
     * Opens the outer {@code try} of the generated method, declares {@code con} from
     * {@code supplier.getConnection()}, and declares the {@code sql} variable.
     * <p>
     * NOTE(review): the generated code never closes {@code con}; presumably the supplier owns the
     * connection lifecycle - confirm.
     */
    private void beginConnectionTry() {
        targetMethod.beginControlFlow("try");
        targetMethod.addStatement("var con = supplier.getConnection()");
        targetMethod.addCodeBlock(this.ctx.kaumeiJdbcGenerator.buildSqlVariable(sourceMethod, targetMethod, "sql"));
    }

    /**
     * Emits the parameter bindings for the generated {@code stmt} and, when the query-timeout
     * annotation is present, a {@code stmt.setQueryTimeout(...)} call.
     */
    private void bindParametersAndTimeout() {
        this.ctx.kaumeiJdbcGenerator.processParameter(messages, sourceMethod, targetMethod);
        targetMethod.addIfAnnotationIsPresent("stmt.setQueryTimeout($L)", this.sourceMethod.jdbcQueryTimeout());
    }

    /**
     * Closes the outer {@code try} opened by {@link #beginConnectionTry()} with a catch clause that
     * rethrows {@link SQLException} as {@link JdbcException}, keeping the original as the cause.
     */
    private void endConnectionTryWithSqlExceptionCatch() {
        targetMethod.nextControlFlow("catch ($T e)", SQLException.class);
        targetMethod.addStatement("throw new $T(e.getMessage(), e)", JdbcException.class);
        targetMethod.endControlFlow();
    }

}