std::unique_ptr<Expression> Swizzle::Convert()

in src/sksl/ir/SkSLSwizzle.cpp [246:439]


std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
                                             Position pos,
                                             Position maskPos,
                                             std::unique_ptr<Expression> base,
                                             std::string_view componentString) {
    if (componentString.size() > 4) {
        context.fErrors->error(Position::Range(maskPos.startOffset() + 4,
                                               maskPos.endOffset()),
                               "too many components in swizzle mask");
        return nullptr;
    }

    // Convert the component string into an equivalent array.
    ComponentArray components;
    for (size_t i = 0; i < componentString.length(); ++i) {
        char field = componentString[i];
        switch (field) {
            case '0': components.push_back(SwizzleComponent::ZERO); break;
            case '1': components.push_back(SwizzleComponent::ONE);  break;
            case 'x': components.push_back(SwizzleComponent::X);    break;
            case 'r': components.push_back(SwizzleComponent::R);    break;
            case 's': components.push_back(SwizzleComponent::S);    break;
            case 'L': components.push_back(SwizzleComponent::UL);   break;
            case 'y': components.push_back(SwizzleComponent::Y);    break;
            case 'g': components.push_back(SwizzleComponent::G);    break;
            case 't': components.push_back(SwizzleComponent::T);    break;
            case 'T': components.push_back(SwizzleComponent::UT);   break;
            case 'z': components.push_back(SwizzleComponent::Z);    break;
            case 'b': components.push_back(SwizzleComponent::B);    break;
            case 'p': components.push_back(SwizzleComponent::P);    break;
            case 'R': components.push_back(SwizzleComponent::UR);   break;
            case 'w': components.push_back(SwizzleComponent::W);    break;
            case 'a': components.push_back(SwizzleComponent::A);    break;
            case 'q': components.push_back(SwizzleComponent::Q);    break;
            case 'B': components.push_back(SwizzleComponent::UB);   break;
            default:
                context.fErrors->error(Position::Range(maskPos.startOffset() + i,
                                                       maskPos.startOffset() + i + 1),
                                       String::printf("invalid swizzle component '%c'", field));
                return nullptr;
        }
    }

    if (!validate_swizzle_domain(components)) {
        context.fErrors->error(maskPos, "invalid swizzle mask '" + MaskString(components) + "'");
        return nullptr;
    }

    const Type& baseType = base->type().scalarTypeForLiteral();

    if (!baseType.isVector() && !baseType.isScalar()) {
        context.fErrors->error(pos, "cannot swizzle value of type '" +
                                    baseType.displayName() + "'");
        return nullptr;
    }

    ComponentArray maskComponents;
    bool foundXYZW = false;
    for (int i = 0; i < components.size(); ++i) {
        switch (components[i]) {
            case SwizzleComponent::ZERO:
            case SwizzleComponent::ONE:
                // Skip over constant fields for now.
                break;
            case SwizzleComponent::X:
            case SwizzleComponent::R:
            case SwizzleComponent::S:
            case SwizzleComponent::UL:
                foundXYZW = true;
                maskComponents.push_back(SwizzleComponent::X);
                break;
            case SwizzleComponent::Y:
            case SwizzleComponent::G:
            case SwizzleComponent::T:
            case SwizzleComponent::UT:
                foundXYZW = true;
                if (baseType.columns() >= 2) {
                    maskComponents.push_back(SwizzleComponent::Y);
                    break;
                }
                [[fallthrough]];
            case SwizzleComponent::Z:
            case SwizzleComponent::B:
            case SwizzleComponent::P:
            case SwizzleComponent::UR:
                foundXYZW = true;
                if (baseType.columns() >= 3) {
                    maskComponents.push_back(SwizzleComponent::Z);
                    break;
                }
                [[fallthrough]];
            case SwizzleComponent::W:
            case SwizzleComponent::A:
            case SwizzleComponent::Q:
            case SwizzleComponent::UB:
                foundXYZW = true;
                if (baseType.columns() >= 4) {
                    maskComponents.push_back(SwizzleComponent::W);
                    break;
                }
                [[fallthrough]];
            default:
                // The swizzle component references a field that doesn't exist in the base type.
                context.fErrors->error(Position::Range(maskPos.startOffset() + i,
                                                       maskPos.startOffset() + i + 1),
                                       String::printf("invalid swizzle component '%c'",
                                                      mask_char(components[i])));
                return nullptr;
        }
    }

    if (!foundXYZW) {
        context.fErrors->error(maskPos, "swizzle must refer to base expression");
        return nullptr;
    }

    // Coerce literals in expressions such as `(12345).xxx` to their actual type.
    base = baseType.coerceExpression(std::move(base), context);
    if (!base) {
        return nullptr;
    }

    // Swizzles are complicated due to constant components. The most difficult case is a mask like
    // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that
    // evaluates 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0')
    // together and use a secondary swizzle to put them back into the right order, so in this case
    // we end up with 'float4(base.xw, 1, 0).xzyw'.
    //
    // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
    //   scalar.xxx  -> type3(scalar)
    //   scalar.x0x0 -> type2(scalar)
    //   vector.zyx  -> vector.zyx
    //   vector.x0y0 -> vector.xy
    std::unique_ptr<Expression> expr = Swizzle::Make(context, pos, std::move(base), maskComponents);

    // If we have processed the entire swizzle, we're done.
    if (maskComponents.size() == components.size()) {
        return expr;
    }

    // Now we create a constructor that has the correct number of elements for the final swizzle,
    // with all fields at the start. It's not finished yet; constants we need will be added below.
    //   scalar.x0x0 -> type4(type2(x), ...)
    //   vector.y111 -> type4(vector.y, ...)
    //   vector.z10x -> type4(vector.zx, ...)
    //
    // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
    ExpressionArray constructorArgs;
    constructorArgs.reserve_exact(3);
    constructorArgs.push_back(std::move(expr));

    // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
    // need are also tacked on to the end of the constructor.
    //   scalar.x0x0 -> type4(type2(x), 0).xyxy
    //   vector.y111 -> type2(vector.y, 1).xyyy
    //   vector.z10x -> type4(vector.zx, 1, 0).xzwy
    const Type* scalarType = &baseType.componentType();
    ComponentArray swizzleComponents;
    int maskFieldIdx = 0;
    int constantFieldIdx = maskComponents.size();
    int constantZeroIdx = -1, constantOneIdx = -1;

    for (int i = 0; i < components.size(); i++) {
        switch (components[i]) {
            case SwizzleComponent::ZERO:
                if (constantZeroIdx == -1) {
                    // Synthesize a '0' argument at the end of the constructor.
                    constructorArgs.push_back(Literal::Make(pos, /*value=*/0, scalarType));
                    constantZeroIdx = constantFieldIdx++;
                }
                swizzleComponents.push_back(constantZeroIdx);
                break;
            case SwizzleComponent::ONE:
                if (constantOneIdx == -1) {
                    // Synthesize a '1' argument at the end of the constructor.
                    constructorArgs.push_back(Literal::Make(pos, /*value=*/1, scalarType));
                    constantOneIdx = constantFieldIdx++;
                }
                swizzleComponents.push_back(constantOneIdx);
                break;
            default:
                // The non-constant fields are already in the expected order.
                swizzleComponents.push_back(maskFieldIdx++);
                break;
        }
    }

    expr = ConstructorCompound::Make(context, pos,
                                     scalarType->toCompound(context, constantFieldIdx, /*rows=*/1),
                                     std::move(constructorArgs));

    // Create (and potentially optimize-away) the resulting swizzle-expression.
    return Swizzle::Make(context, pos, std::move(expr), swizzleComponents);
}