GenerateService.java

/*
 * SPDX-FileCopyrightText: 2025 kaumei.io
 * SPDX-License-Identifier: Apache-2.0
 */
package io.kaumei.jdbc.anno.gen;

import com.palantir.javapoet.*;
import io.kaumei.jdbc.CodeGenerationException;
import io.kaumei.jdbc.anno.ProcessorEnvironment;
import io.kaumei.jdbc.anno.ProcessorSteps;
import io.kaumei.jdbc.anno.ctx.Context;
import io.kaumei.jdbc.anno.model.SourceMethod;
import io.kaumei.jdbc.anno.model.SourceMethodParameter;
import io.kaumei.jdbc.anno.msg.JdbcMsg;
import io.kaumei.jdbc.anno.msg.Msg;
import io.kaumei.jdbc.anno.store.SourceDV;
import io.kaumei.jdbc.anno.utils.SqlTokenizer;
import org.jspecify.annotations.Nullable;

import javax.lang.model.element.*;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.type.TypeVariable;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

import static io.kaumei.jdbc.anno.ctx.JavaModelUtils.isTopLevel;
import static io.kaumei.jdbc.anno.utils.PrintStackTrace.appendStackTrace;
import static java.util.Objects.requireNonNull;

public class GenerateService implements ProcessorSteps {

    // ----- services
    final Context ctx;

    public GenerateService(Context ctx) {
        this.ctx = requireNonNull(ctx);
    }

    // ------------------------------------------------------------------------

    @Override
    public void process(ProcessorEnvironment roundEnv) {
        for (var entry : roundEnv.jdbcInterfaces()) {
            if (isTopLevel(entry)) {
                this.ctx.logger.acceptWithDebugFlag(entry, this::generateImplementation);
            } else {
                // batch interfaces are processed separately
                this.ctx.logger.debug("Skip interface because it is not top level:", entry);
            }
        }
    }

    // ------------------------------------------------------------------------

    private void generateImplementation(TypeElement iface) {
        this.ctx.logger.info("Process interface", iface);
        var implBuilder = this.create(iface);
        // ----- process methods
        for (var child : iface.getEnclosedElements()) {
            if (!child.getModifiers().contains(Modifier.STATIC)
                    && !child.getModifiers().contains(Modifier.DEFAULT)
                    && child.getKind() == ElementKind.METHOD
                    && child instanceof ExecutableElement method) {
                generateMethod(implBuilder, method);
            }
        }
        // ----- write to disk
        this.ctx.filer.writeJava(iface, implBuilder.packageName(), implBuilder.build());
    }

    void generateMethod(KaumeiClassBuilder implBuilder, ExecutableElement method0) {
        this.ctx.logger.acceptWithDebugFlag(method0, (method) -> {
            try {
                implBuilder.addMethod(generator(implBuilder, method));
            } catch (Exception e) {
                // sanity-check:on
                this.ctx.logger.warn("Catch exception.", e);
                var sb = new StringBuilder();
                sb.append("Annotation processing caught internal exception:\n");
                appendStackTrace(sb, e);
                implBuilder.addMethod(MethodSpec.overriding(method)
                        .addStatement("throw new $T($S)", CodeGenerationException.class, sb.toString())
                        .build());
                // sanity-check:off
            }
        });
    }

    private MethodSpec generator(KaumeiClassBuilder implBuilder, ExecutableElement method) {
        this.ctx.logger.debug("generateMethod", method);
        var sourceMethod = ctx.sourceMethodService.get(method);

        if (sourceMethod.hasMessages()) {
            return new GenerateJdbcThrowCodeGeneration(this.ctx, method).generateMethod(sourceMethod.messages(), invalidMethodHeader(sourceMethod));
        } else if (sourceMethod.entryPoint().selectOpt() != null) {
            return new GenerateJdbcSelect(this.ctx, sourceMethod).generateMethod();
        } else if (sourceMethod.entryPoint().updateOpt() != null) {
            return new GenerateJdbcUpdate(this.ctx, sourceMethod).generateMethod();
        } else if (sourceMethod.entryPoint().nativeJdbcOpt() != null) {
            return new GenerateJdbcNative(this.ctx, sourceMethod).generateMethod();
        } else if (sourceMethod.entryPoint().batchUpdateOpt() != null) {
            return new GenerateJdbcBatchUpdate(this.ctx, implBuilder, sourceMethod).generateMethod();
        } else {
            throw new IllegalStateException("Illegal state: " + sourceMethod.hasMessages() + " and " + sourceMethod.entryPoint()); // sanity-check
        }
    }

    private static String invalidMethodHeader(SourceMethod sourceMethod) {
        var entryPoint = sourceMethod.entryPoint();
        if (entryPoint.selectOpt() != null) {
            return "@JdbcSelect method invalid";
        } else if (entryPoint.updateOpt() != null) {
            return "@JdbcUpdate method invalid";
        } else if (entryPoint.nativeJdbcOpt() != null) {
            return "@JdbcNative method invalid";
        } else if (entryPoint.batchUpdateOpt() != null) {
            return "@JdbcBatchUpdate method invalid";
        }
        return "JDBC method invalid";
    }

    // ------------------------------------------------------------------------

    public KaumeiClassBuilder create(TypeElement iface) {
        var packageName = this.ctx.getPackageOf(iface).getQualifiedName().toString();
        return new KaumeiClassBuilder(this.ctx, packageName, iface);
    }

    // ------------------------------------------------------------------------

    /**
     * @return the AnnotationSpec if the mirror, return null if it was a Kaumei internal annotation
     */
    @Nullable
    AnnotationSpec annotationSpec(AnnotationMirror mirror) {
        return this.ctx.hasKaumeiPkg(mirror.getAnnotationType().asElement())
                ? null
                : AnnotationSpec.get(mirror);
    }

    private TypeName typeNameWithAnnotations(TypeMirror typeMirror) {
        var annotationMirrors = typeMirror.getAnnotationMirrors();
        if (annotationMirrors.isEmpty()) {
            return TypeName.get(typeMirror);
        }
        List<AnnotationSpec> specs = new ArrayList<>();
        for (AnnotationMirror annoMirror : annotationMirrors) {
            var spec = this.annotationSpec(annoMirror);
            if (spec != null) {
                specs.add(spec);
            }
        }
        return specs.isEmpty() ? TypeName.get(typeMirror) : TypeName.get(typeMirror).annotated(specs);
    }

    // ------------------------------------------------------------------------

    MethodSpec.Builder createMethodBuilder(ExecutableElement method, TargetMethod target) {
        var methodName = method.getSimpleName().toString();
        var methodBuilder = MethodSpec.methodBuilder(methodName);

        // ----- method annotations
        methodBuilder.addAnnotation(Override.class);
        for (var mirror : method.getAnnotationMirrors()) {
            if (!this.ctx.hasKaumeiPkg(mirror.getAnnotationType().asElement())) {
                methodBuilder.addAnnotation(AnnotationSpec.get(mirror));
            }
        }

        // ----- modifiers
        var modifiers = new LinkedHashSet<>(method.getModifiers());
        modifiers.remove(Modifier.ABSTRACT);
        modifiers.remove(Modifier.DEFAULT);
        methodBuilder.addModifiers(modifiers);

        // ----- type parameter
        for (TypeParameterElement typeParameterElement : method.getTypeParameters()) {
            TypeVariable var = (TypeVariable) typeParameterElement.asType();
            methodBuilder.addTypeVariable(TypeVariableName.get(var));
        }

        // ----- return type with annotations
        methodBuilder.returns(this.typeNameWithAnnotations(method.getReturnType()));

        // ----- parameter
        for (VariableElement param : method.getParameters()) {
            var type = this.typeNameWithAnnotations(param.asType());
            var name = target.paramName(param.getSimpleName().toString());
            var paramBuilder = ParameterSpec.builder(type, name);
            paramBuilder.addModifiers(param.getModifiers());
            for (AnnotationMirror mirror : param.getAnnotationMirrors()) {
                var spec = this.annotationSpec(mirror);
                if (spec != null) {
                    paramBuilder.addAnnotation(AnnotationSpec.get(mirror));
                }
            }
            methodBuilder.addParameter(paramBuilder.build());
        }
        methodBuilder.varargs(method.isVarArgs());

        // ----- throws
        for (TypeMirror thrownType : method.getThrownTypes()) {
            methodBuilder.addException(TypeName.get(thrownType));
        }

        return methodBuilder;
    }

    // ------------------------------------------------------------------------

    public CodeBlock buildSqlVariable(SourceMethod source, TargetMethod target, String variableName) {
        source.messages().requireEmpty();
        var code = CodeBlock.builder();
        var maxCollectionPlaceholders = this.ctx.kaumeiConfig.maxCollectionPlaceholders();
        var maxTotalPlaceholders = this.ctx.kaumeiConfig.maxTotalPlaceholders();
        var parameter = source.parameter();

        code.addStatement("var plCount = 0");
        code.addStatement("var sqlSb = new $T()", StringBuilder.class);

        for (var token : parameter.tokens()) {
            if (token instanceof SqlTokenizer.TextToken tt) {
                code.addStatement("sqlSb.append($S)", tt.sql());
            } else if (token instanceof SqlTokenizer.SingleParameter sqlParam) {
                code.addStatement("sqlSb.append($S)", "?");
                var length = CodeBlock.of("1");
                code.addStatement("plCount = $L", KaumeiLib.checkPlaceholder("plCount", length, maxTotalPlaceholders));
            } else if (token instanceof SqlTokenizer.AllValues sqlParam) {
                var name = sqlParam.root();
                var param = parameter.getValues(name);
                if (param instanceof SourceMethodParameter.ParamArray pa) {
                    var length = CodeBlock.of("$L.length", KaumeiLib.requireNonNull(target.paramCodeBlock(name)));
                    code.addStatement("plCount = $L", KaumeiLib.checkPlaceholder("plCount", length, maxTotalPlaceholders));
                    code.addStatement("sqlSb.append($L)", KaumeiLib.marks(length, maxCollectionPlaceholders, name));
                } else if (param instanceof SourceMethodParameter.ParamList pl) {
                    var length = CodeBlock.of("$L.size()", KaumeiLib.requireNonNull(target.paramCodeBlock(name)));
                    code.addStatement("plCount = $L", KaumeiLib.checkPlaceholder("plCount", length, maxTotalPlaceholders));
                    code.addStatement("sqlSb.append($L)", KaumeiLib.marks(length, maxCollectionPlaceholders, name));
                } else if (param instanceof SourceMethodParameter.ParamComposite pc) {
                    for (int i = 0; i < pc.searchRequests().length; i++) {
                        if (i > 0) {
                            code.addStatement("sqlSb.append($S)", ",");
                        }
                        code.addStatement("sqlSb.append($S)", "?");
                    }
                    var length = CodeBlock.of("" + pc.searchRequests().length);
                    code.addStatement("plCount = $L", KaumeiLib.checkPlaceholder("plCount", length, maxTotalPlaceholders));
                } else {
                    throw new IllegalStateException("Unexpected parameter: " + param); // sanity-check
                }
            } else if (token instanceof SqlTokenizer.AllNames sqlParam) {
                var name = sqlParam.root();
                var param = parameter.getValues(name);
                if (param instanceof SourceMethodParameter.ParamComposite pc) {
                    for (int i = 0; i < pc.searchRequests().length; i++) {
                        if (i > 0) {
                            code.addStatement("sqlSb.append($S)", ",");
                        }
                        code.addStatement("sqlSb.append($S)", ctx.jdbcName(pc.variable()[i]).value());
                    }
                } else {
                    throw new IllegalStateException("type not expected: " + param); // sanity-check
                }
            }
        } // for
        code.addStatement("var $L = sqlSb.toString()", variableName);
        return code.build();
    }

    // ------------------------------------------------------------------------

    public void processParameter(Msg.Builder messages, SourceMethod source, TargetMethod body) {
        this.ctx.logger.debug("processParameter");
        var parameter = source.parameter();

        if (parameter.hasCollections()) {
            body.addStatement("var index = 1");
        }

        int sqlIndex = 1;
        for (var token : parameter.tokens()) {
            if (token instanceof SqlTokenizer.SingleParameter sqlParam) {
                var name = sqlParam.root();
                var param = parameter.getSingle(name);
                var indexCode = parameter.hasCollections() ? CodeBlock.of("index++") : CodeBlock.of("$L", sqlIndex++);
                var converter = ctx.kaumeiJava2Jdbc.searchJava(source.method(), param.searchRequest());
                if (converter.hasMessages()) {
                    messages.add(JdbcMsg.invalidSqlParameterConverter(
                            sqlParam.name(), SourceDV.converter(ctx, param.element()), converter.messages()));
                } else {
                    converter.value().setParameter(body, body.paramCodeBlock(name), indexCode, param.optional());
                }
            } else if (token instanceof SqlTokenizer.AllValues sqlParam) {
                var name = sqlParam.root();
                var param = parameter.getValues(name);
                if (param instanceof SourceMethodParameter.ParamArray pa) {
                    body.addStatement("// array");
                    var indexCode = parameter.hasCollections() ? CodeBlock.of("index++") : CodeBlock.of("$L", sqlIndex++);
                    var item = body.tempVarName(name + "Item");
                    body.beginControlFlow("for (var $L : $L)", item, body.paramCodeBlock(name));
                    var converter = ctx.kaumeiJava2Jdbc.searchJava(source.method(), pa.searchRequest());
                    if (converter.hasMessages()) {
                        messages.add(JdbcMsg.invalidSqlParameterConverter(
                                sqlParam.name(), SourceDV.converter(ctx, param.element()), converter.messages()));
                    } else {
                        converter.value().setParameter(body, CodeBlock.of("$L", item), indexCode, pa.cmpOptional());
                    }
                    body.endControlFlow();
                } else if (param instanceof SourceMethodParameter.ParamList pl) {
                    body.addStatement("// list");
                    var indexCode = parameter.hasCollections() ? CodeBlock.of("index++") : CodeBlock.of("$L", sqlIndex++);
                    var item = body.tempVarName(name + "Item");
                    body.beginControlFlow("for (var $L : $L)", item, body.paramCodeBlock(name));
                    var converter = ctx.kaumeiJava2Jdbc.searchJava(source.method(), pl.searchRequest());
                    if (converter.hasMessages()) {
                        messages.add(JdbcMsg.invalidSqlParameterConverter(
                                sqlParam.name(), SourceDV.converter(ctx, param.element()), converter.messages()));
                    } else {
                        converter.value().setParameter(body, CodeBlock.of("$L", item), indexCode, pl.cmpOptional());
                    }
                    body.endControlFlow();
                } else if (param instanceof SourceMethodParameter.ParamComposite pc) {
                    for (int i = 0; i < pc.names().length; i++) {
                        var access = CodeBlock.of("$L.$L()", body.paramCodeBlock(name), pc.variable()[i].getSimpleName());
                        var indexCode = parameter.hasCollections() ? CodeBlock.of("index++") : CodeBlock.of("$L", sqlIndex++);
                        var converter = ctx.kaumeiJava2Jdbc.searchJava(source.method(), pc.searchRequests()[i]);
                        if (converter.hasMessages()) {
                            messages.add(JdbcMsg.invalidSqlParameterConverter(
                                    sqlParam.name(), SourceDV.converter(ctx, pc.variable()[i]), converter.messages()));
                        } else {
                            converter.value().setParameter(body, access, indexCode, pc.optionals()[i]);
                        }
                    }
                } else {
                    throw new IllegalStateException("Unexpected parameter: " + param); // sanity-check
                }
            }
        }
    }
}