private Variable visitProjection()

in codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java [236:490]


    /**
     * Generates a JMESPath projection: evaluates the right-hand side once per element of the left-hand list/map and
     * collects the (non-nil) results into a new Go slice.
     *
     * @param expr The projection expression.
     * @param current The variable the left-hand side is evaluated against.
     * @return A variable for the generated slice, whose shape is a synthetic list of the projected member shape.
     */
    private Variable visitProjection(ProjectionExpression expr, Variable current) {
        var source = visit(expr.getLeft(), current);
        var rhs = expr.getRight();

        // "Field[]" - projecting the current node is a no-op, so the left side is the answer
        if (rhs instanceof CurrentExpression) {
            return source;
        }

        // by spec the left side of a projection must be a list or a map; anything else is a generator bug
        final Shape elementShape;
        if (source.shape instanceof CollectionShape collection) {
            elementShape = expectMember(collection);
        } else if (source.shape instanceof MapShape mapShape) {
            elementShape = expectMember(mapShape);
        } else {
            throw new CodegenException("projection did not create a list: " + expr);
        }

        var elementSymbol = ctx.symbolProvider().toSymbol(elementShape);
        var loopVar = new Variable(elementShape, "v", elementSymbol);

        // the element type of the slice must be declared before the loop, so dry-run the RHS against a
        // throwaway writer to learn what it evaluates to
        var preview = new GoJmespathExpressionGenerator(ctx, new GoWriter("")).generate(rhs, loopVar);

        var result = nextIdent();
        writer.write("""
                var $L []$T
                for _, v := range $L {""", result, ctx.symbolProvider().toSymbol(preview.shape), source.ident);

        writer.indent();
        // element.shape is the *member* shape of the slice being built
        var element = visit(rhs, loopVar);
        if (!isPointable(preview.type)) {
            writer.write("$1L = append($1L, $2L)", result, element.ident);
        } else {
            // projections silently drop nil RHS results; slices/maps are appended as-is, scalars are dereferenced
            var isContainer = preview.shape instanceof CollectionShape || preview.shape instanceof MapShape;
            writer.write("""
                    if $1L != nil {
                        $2L = append($2L, $3L$1L)
                    }""", element.ident, result, isContainer ? "" : "*");
        }
        writer.dedent();
        writer.write("}");

        var resultShape = listOf(element.shape);
        return new Variable(resultShape, result, sliceOf(ctx.symbolProvider().toSymbol(element.shape)));
    }

    /**
     * Generates a JMESPath subexpression ({@code left.right}). If the left-hand result can be nil, the right-hand
     * evaluation is wrapped in a nil guard and its result is stored in a pre-declared holder variable.
     *
     * @param expr The subexpression.
     * @param current The variable the left-hand side is evaluated against.
     * @return The variable holding the right-hand result.
     */
    private Variable visitSub(Subexpression expr, Variable current) {
        var parent = visit(expr.getLeft(), current);
        var rhs = expr.getRight();

        if (isNilable(parent.type)) {
            // the holder has to be declared outside the nil guard, so peek at the RHS type with a throwaway writer
            var peek = new GoJmespathExpressionGenerator(ctx, new GoWriter("")).generate(rhs, parent);
            var holder = nextIdent();
            writer.write("var $L $P", holder, peek.type);
            writer.write("if $L != nil {", parent.ident);
            writer.indent();
            var child = visit(rhs, parent);
            writer.write("$L = $L", holder, child.ident);
            writer.dedent();
            writer.write("}");
            return new Variable(child.shape, holder, child.type);
        }

        // non-nilable parent: no guard required, evaluate the RHS directly against it
        return visit(rhs, parent);
    }

    /**
     * Generates a field access on the current variable, e.g. {@code v2 := v1.Name}.
     *
     * @param expr The field expression.
     * @param current The variable whose shape must contain the referenced member.
     * @return A variable for the accessed field, typed by the member's symbol.
     * @throws CodegenException If the current shape has no member with the given name.
     */
    private Variable visitField(FieldExpression expr, Variable current) {
        var name = expr.getName();
        var member = current.shape.getMember(name)
                .orElseThrow(() -> new CodegenException("field expression referenced nonexistent member: " + name));

        var memberTarget = ctx.model().expectShape(member.getTarget());
        var fieldIdent = nextIdent();
        writer.write("$L := $L.$L", fieldIdent, current.ident, capitalize(name));

        // the Go type comes from the member (not the target) symbol
        return new Variable(memberTarget, fieldIdent, ctx.symbolProvider().toSymbol(member));
    }

    /**
     * Dispatches a JMESPath function call to its generator. Only keys(), length() and contains() are supported.
     *
     * @param expr The function expression.
     * @param current The variable the function arguments are evaluated against.
     * @return The variable holding the function result.
     * @throws CodegenException If the function is not supported.
     */
    private Variable visitFunction(FunctionExpression expr, Variable current) {
        var args = expr.arguments;
        switch (expr.name) {
            case "keys":
                return visitKeysFunction(args, current);
            case "length":
                return visitLengthFunction(args, current);
            case "contains":
                return visitContainsFunction(args, current);
            default:
                throw new CodegenException("unsupported function " + expr.name);
        }
    }

    /**
     * Generates JMESPath contains(subject, search) as a linear scan that sets a bool when a matching element is
     * found.
     *
     * @param args Exactly two arguments: the collection to scan and the value to look for.
     * @param current The variable the arguments are evaluated against.
     * @return A bool variable that is true when the value was found.
     * @throws CodegenException If the argument count is not 2.
     */
    private Variable visitContainsFunction(List<JmespathExpression> args, Variable current) {
        var argCount = args.size();
        if (argCount != 2) {
            throw new CodegenException("unexpected contains() arg length " + argCount);
        }

        var haystack = visit(args.get(0), current);
        var needle = visit(args.get(1), current);
        var found = nextIdent();
        // NOTE(review): elements are compared to the needle with a plain Go ==; if the needle ends up a pointer this
        // is an address comparison rather than a value comparison - confirm callers only produce non-pointer needles
        writer.write("""
                var $1L bool
                for _, v := range $2L {
                    if v == $3L {
                        $1L = true
                        break
                    }
                }""", found, haystack.ident, needle.ident);
        return new Variable(BOOL_SHAPE, found, GoUniverseTypes.Bool);
    }

    /**
     * Generates JMESPath length(arg).
     *
     * <p>For strings the JMESPath spec defines length as the number of code points, whereas Go's {@code len} on a
     * string counts UTF-8 bytes, so strings are measured via {@code len([]rune(s))}, which Go evaluates as a rune
     * count. Any other argument (lists, maps) uses {@code len} directly.
     *
     * @param args Exactly one argument.
     * @param current The variable the argument is evaluated against.
     * @return An int variable holding the length.
     * @throws CodegenException If the argument count is not 1.
     */
    private Variable visitLengthFunction(List<JmespathExpression> args, Variable current) {
        if (args.size() != 1) {
            throw new CodegenException("unexpected length() arg length " + args.size());
        }

        var arg = visit(args.get(0), current);
        var ident = nextIdent();

        if (arg.shape instanceof StringShape) {
            if (isPointable(arg.type)) {
                // *string: a nil pointer is treated as the empty string, i.e. length 0
                writer.write("""
                        var _$1L string
                        if $1L != nil {
                            _$1L = *$1L
                        }
                        $2L := len([]rune(_$1L))""", arg.ident, ident);
            } else {
                writer.write("$L := len([]rune($L))", ident, arg.ident);
            }
        } else {
            writer.write("$L := len($L)", ident, arg.ident);
        }

        return new Variable(INT_SHAPE, ident, GoUniverseTypes.Int);
    }

    /**
     * Generates JMESPath keys(arg) by collecting the keys of a Go map into a string slice. The resulting order
     * follows Go map iteration order, which is unspecified.
     *
     * @param args Exactly one argument, expected to evaluate to a map.
     * @param current The variable the argument is evaluated against.
     * @return A variable for the generated {@code []string}.
     * @throws CodegenException If the argument count is not 1.
     */
    private Variable visitKeysFunction(List<JmespathExpression> args, Variable current) {
        if (args.size() != 1) {
            throw new CodegenException("unexpected keys() arg length " + args.size());
        }

        var arg = visit(args.get(0), current);
        // allocate the result identifier through the shared helper, like every other visitor
        var ident = nextIdent();
        writer.write("""
                var $1L []string
                for k := range $2L {
                    $1L = append($1L, k)
                }""", ident, arg.ident);

        return new Variable(listOf(STRING_SHAPE), ident, sliceOf(GoUniverseTypes.String));
    }

    /**
     * Allocates the next generated Go identifier ("v1", "v2", ...).
     */
    private String nextIdent() {
        return "v" + (++idIndex);
    }

    /**
     * Builds a synthetic list shape over the given member and remembers the member so that expectMember() can
     * resolve it later without consulting the model.
     */
    private Shape listOf(Shape shape) {
        var synthetic = ShapeUtil.listOf(shape);
        synthetics.putIfAbsent(synthetic, shape);
        return synthetic;
    }

    /**
     * Resolves the member shape of a list/set, preferring members recorded for synthetic shapes before falling back
     * to the model.
     */
    private Shape expectMember(CollectionShape shape) {
        if (synthetics.containsKey(shape)) {
            return synthetics.get(shape);
        }
        return ShapeUtil.expectMember(ctx.model(), shape);
    }

    /**
     * Resolves the value shape of a map, preferring members recorded for synthetic shapes before falling back to the
     * model.
     */
    private Shape expectMember(MapShape shape) {
        if (synthetics.containsKey(shape)) {
            return synthetics.get(shape);
        }
        return ShapeUtil.expectMember(ctx.model(), shape);
    }

    /**
     * Generates a comparison between two traversal results, automatically dereferencing pointer operands and handling
     * nil (JMESPath null) operands in the process.
     *
     * @param ident The identifier to declare for the boolean result.
     * @param left The left-hand operand.
     * @param right The right-hand operand.
     * @param cmp The comparator to apply.
     * @param cast A Go conversion applied to both (dereferenced) operands before they are compared.
     * @return A writable that declares {@code ident} and assigns the comparison result to it.
     */
    private GoWriter.Writable compareVariables(String ident, Variable left, Variable right, ComparatorType cmp,
                                               String cast) {
        var isLPtr = isPointable(left.type);
        var isRPtr = isPointable(right.type);
        if (!isLPtr && !isRPtr) {
            return goTemplate("$1L := $5L($2L) $4L $5L($3L)", ident, left.ident, right.ident, cmp, cast);
        }

        // undocumented jmespath behavior: null in numeric _ordering_ comparisons coerces to 0
        // this means the subsequent nil checks for numerics are moot, but it's either this or branch the codegen even
        // further for questionable benefit
        var nilCoerceLeft = emptyGoTemplate();
        var nilCoerceRight = emptyGoTemplate();
        if (isOrderComparator(cmp)) {
            if (isLPtr && left.shape instanceof NumberShape) {
                nilCoerceLeft = goTemplate("""
                        if ($1L == nil) {
                            $1L = new($2T)
                            *$1L = 0
                        }""", left.ident, left.type);
            }
            if (isRPtr && right.shape instanceof NumberShape) {
                nilCoerceRight = goTemplate("""
                        if ($1L == nil) {
                            $1L = new($2T)
                            *$1L = 0
                        }""", right.ident, right.type);
            }
        }

        // the guarded comparison below only runs when every pointer operand is non-nil; the else branch fills in the
        // (in)equality cases where a nil operand (JMESPath null) must still yield true
        var elseCheckPtrs = emptyGoTemplate();
        if (isLPtr && isRPtr) {
            if (cmp == ComparatorType.EQUAL) {
                // null == null
                elseCheckPtrs = goTemplate("else { $L = $L == nil && $L == nil }",
                        ident, left.ident, right.ident);
            } else if (cmp == ComparatorType.NOT_EQUAL) {
                // exactly one side is null
                elseCheckPtrs = goTemplate("else { $1L = ($2L == nil && $3L != nil) || ($2L != nil && $3L == nil) }",
                        ident, left.ident, right.ident);
            }
        } else if (cmp == ComparatorType.NOT_EQUAL) {
            // only one operand is a pointer: reaching the else branch means it is nil while the other side holds a
            // concrete value, and null != <value> is true in JMESPath
            elseCheckPtrs = goTemplate("else { $L = true }", ident);
        }

        return goTemplate("""
                 var $ident:L bool
                 $nilCoerceLeft:W
                 $nilCoerceRight:W
                 if $lif:L $amp:L $rif:L {
                     $ident:L = $cast:L($lhs:L) $cmp:L $cast:L($rhs:L)
                 }$elseCheckPtrs:W""",
                Map.of(
                        "ident", ident,
                        "lif", isLPtr ? left.ident + " != nil" : "",
                        "rif", isRPtr ? right.ident + " != nil" : "",
                        "amp", isLPtr && isRPtr ? "&&" : "",
                        "cmp", cmp,
                        "lhs", isLPtr ? "*" + left.ident : left.ident,
                        "rhs", isRPtr ? "*" + right.ident : right.ident,
                        "cast", cast,
                        "nilCoerceLeft", nilCoerceLeft,
                        "nilCoerceRight", nilCoerceRight
                ),
                Map.of(
                        "elseCheckPtrs", elseCheckPtrs
                ));
    }

    /**
     * Reports whether the comparator is one of the ordering comparators (&lt;, &lt;=, &gt;, &gt;=) as opposed to
     * (in)equality.
     */
    private static boolean isOrderComparator(ComparatorType cmp) {
        var isLess = cmp == ComparatorType.LESS_THAN || cmp == ComparatorType.LESS_THAN_EQUAL;
        var isGreater = cmp == ComparatorType.GREATER_THAN || cmp == ComparatorType.GREATER_THAN_EQUAL;
        return isLess || isGreater;
    }

    /**
     * A variable produced while generating a JMESPath traversal: the input, an intermediate value, or the final
     * output.
     *
     * @param shape The Smithy shape this variable refers to. Some expressions (e.g. literals, projections) produce
     *              synthetic shapes that are not part of the model and carry little meaning of their own.
     * @param ident The Go identifier used for this variable in the generated traversal.
     * @param type  The Go type of the variable. This is NOT guaranteed to equal toSymbol(shape), since expressions such
     *              as projections can change the resulting type. Callers MUST consult this field, not the shape, to
     *              decide whether the variable is pointable/nillable.
     */
    public record Variable(Shape shape, String ident, Symbol type) {
        /**
         * Creates a variable with no recorded Go type ({@code type} is {@code null}).
         */
        public Variable(Shape shape, String ident) {
            this(shape, ident, (Symbol) null);
        }
    }
}