GenerateJdbcSelect.java

/*
 * SPDX-FileCopyrightText: 2025 kaumei.io
 * SPDX-License-Identifier: Apache-2.0
 */
package io.kaumei.jdbc.anno.gen;

import com.palantir.javapoet.CodeBlock;
import com.palantir.javapoet.MethodSpec;
import io.kaumei.jdbc.JdbcException;
import io.kaumei.jdbc.anno.ctx.ConfigService;
import io.kaumei.jdbc.anno.ctx.Context;
import io.kaumei.jdbc.anno.jdbc2java.ColumnIndex;
import io.kaumei.jdbc.anno.model.*;
import io.kaumei.jdbc.anno.msg.JdbcMsg;
import io.kaumei.jdbc.anno.msg.Msg;
import io.kaumei.jdbc.anno.store.SearchKey;
import io.kaumei.jdbc.anno.store.SourceDV;
import io.kaumei.jdbc.annotation.config.JdbcNoMoreRows;
import io.kaumei.jdbc.annotation.config.JdbcNoRows;
import io.kaumei.jdbc.annotation.config.JdbcResultSetConcurrency;
import io.kaumei.jdbc.annotation.config.JdbcResultSetType;
import io.kaumei.jdbc.impl.JdbcUtils;
import io.kaumei.jdbc.impl.ResultSetUtils;
import org.jspecify.annotations.Nullable;

import javax.lang.model.element.ExecutableElement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Optional;

/**
 * Generates the implementation of a {@code @JdbcSelect} annotated method.
 * <p>
 * The shape of the generated code is chosen by the {@code JdbcTypeKind} of the method's return type:
 * <ul>
 *   <li>scalar, {@code RECORD}, {@code ENUM}, {@code UNKNOWN}, {@code ARRAY} and {@code KAUMEI_JDBC_ROW}
 *       kinds read a single row ({@code selectValue}),</li>
 *   <li>{@code LIST} collects every row into a list ({@code selectJavaList}),</li>
 *   <li>{@code STREAM}, {@code KAUMEI_JDBC_ITERABLE} and {@code KAUMEI_JDBC_RESULT_SET} pass the open
 *       statement and result set to a {@code ResultSetUtils} factory
 *       ({@code selectStreamIterableResultSet}),</li>
 *   <li>{@code VOID} is rejected.</li>
 * </ul>
 * Validation problems are collected in {@code messages} instead of being thrown; they are passed to
 * {@code TargetMethod.build} together with the generated code.
 */
class GenerateJdbcSelect implements GenerateJdbc {
    // ----- services
    private final Context ctx;
    // ----- state
    private final SourceMethod sourceMethod;
    private final TargetMethod targetMethod;
    private final Msg.Builder messages;

    GenerateJdbcSelect(Context ctx, SourceMethod sourceMethod) {
        this.ctx = ctx;
        this.sourceMethod = sourceMethod;
        this.targetMethod = new TargetMethod(this.ctx, sourceMethod.method());
        this.messages = Msg.builder();
    }

    // ------------------------------------------------------------------------

    /**
     * Validates the {@code @JdbcSelect} method and generates its implementation.
     *
     * @return the result of {@code TargetMethod.build}, which receives all collected messages
     *         (including unused-annotation messages) and the text {@code "@JdbcSelect method invalid"}
     * @throws IllegalStateException if the return type kind is not handled by any branch (sanity check)
     */
    public MethodSpec generateMethod() {
        this.ctx.logger.debug("---- JdbcSelect ---- ",
                "method", sourceMethod,
                "returnType", sourceMethod.method().getReturnType());

        var returnType = sourceMethod.returnType();
        switch (returnType.kind()) { // JaCoCo:ignore
            case VOID -> messages.add(JdbcMsg.jdbcSelectRequiresSupportedReturnType(returnType));
            case BOOLEAN, BYTE, SHORT, INT, LONG, CHAR, FLOAT, DOUBLE,
                 RECORD, ENUM, UNKNOWN, ARRAY, KAUMEI_JDBC_ROW -> selectValue(returnType);
            case LIST -> selectJavaList(returnType);
            case STREAM, KAUMEI_JDBC_ITERABLE, KAUMEI_JDBC_RESULT_SET -> selectStreamIterableResultSet(returnType);
            default -> throw new IllegalStateException(sourceMethod + ": illegal kind:" + returnType.kind()); // sanity-check
        }
        this.messages.add(sourceMethod.unusedAnno());
        return this.targetMethod.build(this.messages.build(), "@JdbcSelect method invalid");
    }

    // ------------------------------------------------------------------------

    /**
     * Generates a select that reads a single row and returns one converted value.
     * <p>
     * For {@code KAUMEI_JDBC_ROW} the converter is searched for the row's component type and the
     * generated code wraps the converted value in the row type (and in {@code Optional} if the
     * declared return type is an {@code Optional}).
     * <p>
     * Nothing is generated if validation adds a message or no converter is found.
     */
    private void selectValue(JdbcTypeMirror<ExecutableElement> returnType) {
        var returnTypeOptional = returnType.optional();
        var request = SearchKey.of(returnType.type(), sourceMethod.jdbcConverterName());
        var noRows = this.sourceMethod.jdbcNoRows();
        if (returnType.kind() == JdbcTypeKind.KAUMEI_JDBC_ROW) {
            if (returnTypeOptional.isNonNull()) {
                this.messages.add(JdbcMsg.jdbcSelectJdbcRowRequiresNullableOrOptional(returnTypeOptional));
            } else {
                // adjust the search type and optional to the component of JdbcRow
                request = SearchKey.of(returnType.cmpType(), sourceMethod.jdbcConverterName());
                returnTypeOptional = returnType.cmpOptional();
            }
        } else if (noRows == JdbcNoRows.Kind.RETURN_NULL) {
            // returning null for "no rows" is only possible for a @Nullable (non-Optional) return type
            if (returnTypeOptional.isNonNull()) {
                this.messages.add(JdbcMsg.jdbcSelectNoRowsReturnNullRequiresNullableOrOptional(returnTypeOptional));
            } else if (returnTypeOptional.isOptionalType()) {
                this.messages.add(JdbcMsg.jdbcSelectNoRowsReturnNullRequiresNullable(returnTypeOptional));
            }
        }
        if (this.messages.hasMessages()) {
            return;
        }

        var searchResult = ctx.kaumeiJdbc2Java.searchJava(sourceMethod.method(), request);
        if (searchResult.hasMessages()) {
            this.messages.add(JdbcMsg.invalidJdbcToJavaResultConverter(
                    SourceDV.converter(ctx, sourceMethod.method()), request, searchResult.messages()));
            return;
        }

        targetMethod.beginControlFlow("try");
        addConnectionAndSql();
        targetMethod.beginControlFlow("try (var stmt = con.prepareStatement(sql))");
        this.ctx.kaumeiJdbcGenerator.processParameter(messages, sourceMethod, targetMethod);
        targetMethod.addIfAnnotationIsPresent("stmt.setQueryTimeout($L)", this.sourceMethod.jdbcQueryTimeout());

        // Fetch two rows when a second row must be reported by the no-more-rows check, otherwise one.
        var noMoreRows = this.sourceMethod.jdbcNoMoreRows();
        var rowLimit = noMoreRows == JdbcNoMoreRows.Kind.THROW_EXCEPTION ? 2 : 1;
        targetMethod.addStatement("stmt.setFetchSize($L)", rowLimit);
        targetMethod.addStatement("stmt.setMaxRows($L)", rowLimit);
        targetMethod.beginControlFlow("try (var rs = stmt.executeQuery())");

        targetMethod.addCheckNoRows(noRows, returnType.optional());
        var converter = searchResult.value();
        if (converter.isColumn()) {
            var jdbcName = sourceMethod.jdbcName();
            if (jdbcName.hasName()) {
                converter.addColumnByName(targetMethod, "result", jdbcName, returnTypeOptional);
            } else {
                converter.addColumnByIndex(targetMethod, "result", ColumnIndex.ofValue(1), returnTypeOptional);
            }
        } else {
            converter.addResultSetToRow(targetMethod, "result", returnTypeOptional);
        }
        targetMethod.addCheckNoMoreRows(noMoreRows);

        if (returnType.kind() == JdbcTypeKind.KAUMEI_JDBC_ROW) {
            if (returnType.optional().isOptionalType()) {
                targetMethod.addStatement("return $T.of(new $T(result))", Optional.class, returnType.type());
            } else {
                targetMethod.addStatement("return new $T(result)", returnType.type());
            }
        } else {
            targetMethod.addStatement("return result");
        }

        targetMethod.endControlFlow(); // try (rs)
        targetMethod.endControlFlow(); // try (stmt)
        addCatchSqlException();
    }

    /**
     * Generates a select that converts every row with the component converter and returns
     * {@code ResultSetUtils.toList(resultSet, lambda)}. Statement and result set are closed by
     * try-with-resources in the generated code.
     */
    private void selectJavaList(JdbcTypeMirror<ExecutableElement> returnType) {
        if (hasInvalidMultiRowReturnType(returnType)) {
            return;
        }

        var request = SearchKey.of(returnType.cmpType(), sourceMethod.jdbcConverterName());
        var searchResult = ctx.kaumeiJdbc2Java.searchJava(sourceMethod.method(), request);
        this.ctx.logger.debug("selectJavaList", "returnType", returnType, "searchResult", searchResult);
        if (searchResult.hasMessages()) {
            this.messages.add(JdbcMsg.invalidJdbcToJavaResultConverter(
                    SourceDV.converter(ctx, sourceMethod.method()), request, searchResult.messages()));
            return;
        } else if (!searchResult.value().isColumn() && hasInvalidRowReturnComponent(returnType)) {
            return;
        }

        targetMethod.beginControlFlow("try");
        addConnectionAndSql();
        // prepareStatement(..) may emit check code itself, so it must run before the try header is emitted
        var prepare = prepareStatement(sourceMethod.jdbcResultSetType(), sourceMethod.jdbcResultSetConcurrency());
        targetMethod.beginControlFlow("try (var stmt = $L)", prepare);
        this.ctx.kaumeiJdbcGenerator.processParameter(messages, sourceMethod, targetMethod);
        addStatementOptions();
        targetMethod.beginControlFlow("try (var resultSet = stmt.executeQuery())");

        var jdbcName = searchResult.value().isColumn() ? sourceMethod.jdbcName() : SQLNameDV.noName();
        var lambda = targetMethod.lambda(jdbcName, returnType.cmpOptional(), searchResult.value());
        targetMethod.addStatement("return $T.toList(resultSet, $L)", ResultSetUtils.class, lambda.toString());

        targetMethod.endControlFlow(); // try (resultSet)
        targetMethod.endControlFlow(); // try (stmt)
        addCatchSqlException();
    }

    /**
     * Generates a select whose statement and result set stay open after the method returns: they are
     * passed to {@code ResultSetUtils.toStream}, {@code toJdbcResultSet} or {@code toJdbcIterable}.
     * If any exception occurs while setting up, the generated code calls
     * {@code JdbcUtils.close(e, stmt, resultSet)} and rethrows runtime exceptions as-is, wrapping all
     * others in {@code JdbcException}.
     */
    private void selectStreamIterableResultSet(JdbcTypeMirror<ExecutableElement> returnType) {
        if (hasInvalidMultiRowReturnType(returnType)) {
            return;
        }
        var request = SearchKey.of(returnType.cmpType(), sourceMethod.jdbcConverterName());
        var searchResult = ctx.kaumeiJdbc2Java.searchJava(sourceMethod.method(), request);
        if (searchResult.hasMessages()) {
            this.messages.add(JdbcMsg.invalidJdbcToJavaResultConverter(
                    SourceDV.converter(ctx, sourceMethod.method()), request, searchResult.messages()));
            return;
        } else if (!searchResult.value().isColumn() && hasInvalidRowReturnComponent(returnType)) {
            return;
        }

        targetMethod.addComment("stream");

        // declared outside the try so the catch block can close them
        targetMethod.addStatement("$T stmt = null", PreparedStatement.class);
        targetMethod.addStatement("$T resultSet = null", ResultSet.class);
        targetMethod.beginControlFlow("try");
        addConnectionAndSql();
        targetMethod.addStatement("stmt = $L", prepareStatement(sourceMethod.jdbcResultSetType(), sourceMethod.jdbcResultSetConcurrency()));
        this.ctx.kaumeiJdbcGenerator.processParameter(messages, sourceMethod, targetMethod);
        addStatementOptions();
        targetMethod.addStatement("resultSet = stmt.executeQuery()");

        var jdbcName = searchResult.value().isColumn() ? sourceMethod.jdbcName() : SQLNameDV.noName();
        var lambda = targetMethod.lambda(jdbcName, returnType.cmpOptional(), searchResult.value());

        var factory = switch (returnType.kind()) { // JaCoCo:ignore
            case STREAM -> "toStream";
            case KAUMEI_JDBC_RESULT_SET -> "toJdbcResultSet";
            case KAUMEI_JDBC_ITERABLE -> "toJdbcIterable";
            default -> // sanity-check
                    throw new IllegalStateException("Invalid kind: " + returnType.kind()); // sanity-check
        };
        targetMethod.addStatement("return $T.$L(stmt, resultSet, $L)", ResultSetUtils.class, factory, lambda.toString());

        targetMethod.nextControlFlow("catch ($T e)", Exception.class);
        targetMethod.addStatement("$T.close(e, stmt, resultSet)", JdbcUtils.class);
        targetMethod.addStatement("throw e instanceof $T re ? re :new $T(e.getMessage(), e)", RuntimeException.class, JdbcException.class);
        targetMethod.endControlFlow();
    }

    // -----------------------------------------------------------------

    /**
     * Checks the nullability of a multi-row return type itself (the list/stream/iterable, not its elements).
     *
     * @return {@code true} if it is neither non-null nor unspecified; an error message has then been recorded
     */
    private boolean hasInvalidMultiRowReturnType(JdbcTypeMirror<ExecutableElement> returnType) {
        if (returnType.optional().checkNonNullOrUnspecified() == null) {
            return false;
        }
        this.messages.add(JdbcMsg.jdbcSelectRequiresNonNullOrUnspecifiedReturnType(returnType.optional()));
        return true;
    }

    /**
     * Checks the nullability of the component type when rows are converted by a row converter.
     *
     * @return {@code true} if the component is neither non-null nor unspecified; an error message has
     *         then been recorded
     */
    private boolean hasInvalidRowReturnComponent(JdbcTypeMirror<ExecutableElement> returnType) {
        if (returnType.cmpOptional().isNonNullOrUnspecified()) {
            return false;
        }
        this.messages.add(JdbcMsg.jdbcSelectRequiresNonNullOrUnspecifiedRowReturnComponent(returnType.cmpOptional()));
        return true;
    }

    // -----------------------------------------------------------------

    /** Emits {@code var con = supplier.getConnection()} followed by the declaration of the {@code sql} variable. */
    private void addConnectionAndSql() {
        targetMethod.addStatement("var con = supplier.getConnection()");
        targetMethod.addCodeBlock(this.ctx.kaumeiJdbcGenerator.buildSqlVariable(sourceMethod, targetMethod, "sql"));
    }

    /** Emits the fetch direction, fetch size, max rows and query timeout settings on {@code stmt}, each only if configured. */
    private void addStatementOptions() {
        targetMethod.addIfAnnotationIsPresent("stmt.setFetchDirection($L.sqlMagicNumber())", sourceMethod.jdbcFetchDirection());
        targetMethod.addIfAnnotationIsPresent("stmt.setFetchSize($L)", sourceMethod.jdbcFetchSize());
        targetMethod.addIfAnnotationIsPresent("stmt.setMaxRows($L)", sourceMethod.jdbcMaxRows());
        targetMethod.addIfAnnotationIsPresent("stmt.setQueryTimeout($L)", this.sourceMethod.jdbcQueryTimeout());
    }

    /** Closes the outer generated {@code try} with a catch that rethrows {@code SQLException} as {@code JdbcException}. */
    private void addCatchSqlException() {
        targetMethod.nextControlFlow("catch ($T e)", SQLException.class);
        targetMethod.addStatement("throw new $T(e.getMessage(), e)", JdbcException.class);
        targetMethod.endControlFlow();
    }

    /**
     * Builds the expression that creates the prepared statement.
     * <p>
     * Side effect: for dynamic option values this emits the value checks into the target method
     * immediately, so callers must call this before emitting the statement that uses the result.
     *
     * @return {@code con.prepareStatement(sql)} if neither option is set, the three-argument variant if
     *         both are set, otherwise records an error message and returns {@code null}
     */
    private CodeBlock prepareStatement(ConfigService.@Nullable OptionValue<JdbcResultSetType.Kind> resultSetType,
                                       ConfigService.@Nullable OptionValue<JdbcResultSetConcurrency.Kind> resultSetConcurrency) {
        if (resultSetType == null && resultSetConcurrency == null) {
            return CodeBlock.of("con.prepareStatement(sql)");
        } else if (resultSetType != null && resultSetConcurrency != null) {
            if (resultSetType instanceof ConfigService.OptionValue.Dynamic d) {
                this.targetMethod.addCodeBlock(targetMethod.checkValue(d));
            }
            if (resultSetConcurrency instanceof ConfigService.OptionValue.Dynamic d) {
                this.targetMethod.addCodeBlock(targetMethod.checkValue(d));
            }
            return CodeBlock.of("con.prepareStatement(sql, $L.sqlMagicNumber(), $L.sqlMagicNumber())",
                    targetMethod.accessValue(resultSetType),
                    targetMethod.accessValue(resultSetConcurrency));
        }
        // only one of type/concurrency given: JDBC has no overload for that
        this.messages.add(JdbcMsg.JDBC_SELECT_REQUIRES_RESULT_SET_TYPE_AND_CONCURRENCY);
        return CodeBlock.of("null");
    }
}