def setup_nodes()

in shap_e/rendering/blender/blender_script.py [0:0]


def setup_nodes(output_path, capturing_material_alpha: bool = False, basic_lighting: bool = False):
    """
    Rebuild the current scene's compositor node tree so it writes per-channel images.

    Every existing node in ``bpy.context.scene.node_tree`` is removed. It is replaced
    with a graph that reads a Render Layers node and feeds one
    ``CompositorNodeOutputFile`` node per written channel.

    - ``capturing_material_alpha=True``: only channel 0 (red) of the rendered
      "Image" output is written, to ``f"{output_path}_MatAlpha"``. No lighting,
      no color-space conversion and no depth output are produced.
    - Otherwise: the "Image" output is optionally relit (see ``basic_lighting``).
      It is converted from "Linear" to "sRGB", and its r/g/b/a channels go to
      ``f"{output_path}_r"``, ``_g``, ``_b`` and ``_a``. The "Depth" output is
      multiplied by ``1 / MAX_DEPTH``, capped at 1.0 and written to
      ``f"{output_path}_depth"``.

    :param output_path: prefix used for the ``base_path`` of every file-output node.
    :param capturing_material_alpha: write only the single "MatAlpha" channel.
    :param basic_lighting: when True (and not capturing material alpha), multiply
        the RGB channels by
        ``BASIC_AMBIENT_COLOR + BASIC_DIFFUSE_COLOR * |dot(UNIFORM_LIGHT_DIRECTION, normal)|``,
        where ``normal`` comes from the "Normal" render output. Alpha passes through
        unchanged.
    :raises KeyError: if the Render Layers node lacks an output this path reads
        ("Image", and also "Normal"/"Depth" when used). Presumably the caller
        enables those passes beforehand; TODO confirm.
    """
    tree = bpy.context.scene.node_tree
    links = tree.links

    # Snapshot the collection before removing: deleting from a bpy collection
    # while iterating it directly can skip elements and leave stale nodes behind.
    for node in list(tree.nodes):
        tree.nodes.remove(node)

    # Helpers to perform math on links and constants.
    def node_op(op: str, *args, clamp=False):
        # Create a CompositorNodeMath node. Numeric args become constant inputs,
        # anything else is treated as a socket and linked in. Returns the output socket.
        node = tree.nodes.new(type="CompositorNodeMath")
        node.operation = op
        if clamp:
            node.use_clamp = True
        for i, arg in enumerate(args):
            if isinstance(arg, (int, float)):
                node.inputs[i].default_value = arg
            else:
                links.new(arg, node.inputs[i])
        return node.outputs[0]

    def node_clamp(x, maximum=1.0):
        # Upper bound only (MINIMUM); no lower bound is applied.
        return node_op("MINIMUM", x, maximum)

    def node_mul(x, y, **kwargs):
        return node_op("MULTIPLY", x, y, **kwargs)

    def node_add(x, y, **kwargs):
        return node_op("ADD", x, y, **kwargs)

    def node_abs(x, **kwargs):
        return node_op("ABSOLUTE", x, **kwargs)

    input_node = tree.nodes.new(type="CompositorNodeRLayers")
    input_node.scene = bpy.context.scene

    # Map Render Layers output names ("Image", "Normal", "Depth", ...) to sockets.
    input_sockets = {output.name: output for output in input_node.outputs}

    if capturing_material_alpha:
        color_socket = input_sockets["Image"]
    else:
        raw_color_socket = input_sockets["Image"]
        if basic_lighting:
            # Compute diffuse lighting: |L . N|. The absolute value lights
            # surfaces facing either toward or away from the light direction.
            normal_xyz = tree.nodes.new(type="CompositorNodeSeparateXYZ")
            links.new(input_sockets["Normal"], normal_xyz.inputs[0])
            normal_x, normal_y, normal_z = [normal_xyz.outputs[i] for i in range(3)]
            dot = node_add(
                node_mul(UNIFORM_LIGHT_DIRECTION[0], normal_x),
                node_add(
                    node_mul(UNIFORM_LIGHT_DIRECTION[1], normal_y),
                    node_mul(UNIFORM_LIGHT_DIRECTION[2], normal_z),
                ),
            )
            diffuse = node_abs(dot)
            # Compute ambient + diffuse lighting
            brightness = node_add(BASIC_AMBIENT_COLOR, node_mul(BASIC_DIFFUSE_COLOR, diffuse))
            # Modulate the RGB channels using the total brightness; alpha is
            # passed through untouched.
            rgba_node = tree.nodes.new(type="CompositorNodeSepRGBA")
            links.new(raw_color_socket, rgba_node.inputs[0])
            combine_node = tree.nodes.new(type="CompositorNodeCombRGBA")
            for i in range(3):
                links.new(node_mul(rgba_node.outputs[i], brightness), combine_node.inputs[i])
            links.new(rgba_node.outputs[3], combine_node.inputs[3])
            raw_color_socket = combine_node.outputs[0]

        # We apply sRGB here so that our fixed-point depth map and material
        # alpha values are not sRGB, and so that we perform ambient+diffuse
        # lighting in linear RGB space.
        color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace")
        color_node.from_color_space = "Linear"
        color_node.to_color_space = "sRGB"
        links.new(raw_color_socket, color_node.inputs[0])
        color_socket = color_node.outputs[0]

    split_node = tree.nodes.new(type="CompositorNodeSepRGBA")
    links.new(color_socket, split_node.inputs[0])

    # Create separate file output nodes for every channel we care about.
    # The process calling this script must decide how to recombine these
    # channels, possibly into a single image. In material-alpha mode only the
    # first (red) channel is written.
    if capturing_material_alpha:
        channels = [(0, "MatAlpha")]
    else:
        channels = list(enumerate("rgba"))
    for i, channel in channels:
        output_node = tree.nodes.new(type="CompositorNodeOutputFile")
        output_node.base_path = f"{output_path}_{channel}"
        links.new(split_node.outputs[i], output_node.inputs[0])

    if capturing_material_alpha:
        # No need to re-write depth here.
        return

    # Normalize depth to [0, 1] by MAX_DEPTH; anything beyond MAX_DEPTH saturates at 1.
    depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH))
    output_node = tree.nodes.new(type="CompositorNodeOutputFile")
    output_node.base_path = f"{output_path}_depth"
    links.new(depth_out, output_node.inputs[0])