def setup_material_extraction_shader_for_material()

in shap_e/rendering/blender/blender_script.py [0:0]


def setup_material_extraction_shader_for_material(mat, capturing_material_alpha: bool):
    """
    Rewire one material so it emits a flat color instead of shading normally.

    The material's Principled BSDF node is modified so that base color is
    black, alpha is 1, specular is 0 and the blend method is "OPAQUE". The
    material's previous base color (or, when ``capturing_material_alpha`` is
    true, its previous alpha) is routed into the "Emission" input with an
    emission strength of 1. The rendered pixel color therefore reflects that
    value directly.

    Args:
        mat: the Blender material to modify. It is switched to ``use_nodes``
            mode for the duration of the change.
        capturing_material_alpha: if true, emit the old alpha value;
            otherwise emit the old base color.

    Returns:
        A zero-argument function that restores the modified sockets, the blend
        method and ``use_nodes`` to their values from before this call.

    Raises:
        AssertionError: if the material has no Principled BSDF node, or that
            node lacks one of the required input sockets. ``mat.use_nodes`` is
            restored before raising, so a failed call leaves the material as
            it was.
    """
    # Remember use_nodes so both the failure path and undo_fn can restore it.
    old_use_nodes = mat.use_nodes
    mat.use_nodes = True
    tree = mat.node_tree

    # By default, most imported models should use the regular
    # "Principled BSDF" material, so we should always find this.
    # If not, this shader manipulation logic won't work.
    # NOTE: if several Principled BSDF nodes exist, the last one wins.
    bsdf_node = None
    for node in tree.nodes:
        if node.type == "BSDF_PRINCIPLED":
            bsdf_node = node
    if bsdf_node is None:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        mat.use_nodes = old_use_nodes
        raise AssertionError("material has no Principled BSDF node to modify")

    socket_map = {socket.name: socket for socket in bsdf_node.inputs}
    for name in ["Base Color", "Emission", "Emission Strength", "Alpha", "Specular"]:
        if name not in socket_map:
            mat.use_nodes = old_use_nodes
            raise AssertionError(f"{name} not in {list(socket_map.keys())}")

    # Snapshot every socket we are about to touch so undo_fn can restore it.
    # (get_socket_value is defined elsewhere in this file; presumably it
    # captures both the default value and any incoming link — verify there.)
    old_base_color = get_socket_value(tree, socket_map["Base Color"])
    old_alpha = get_socket_value(tree, socket_map["Alpha"])
    old_emission = get_socket_value(tree, socket_map["Emission"])
    old_emission_strength = get_socket_value(tree, socket_map["Emission Strength"])
    old_specular = get_socket_value(tree, socket_map["Specular"])

    # Make sure the base color of all objects is black and the opacity
    # is 1, so that we are effectively just telling the shader what color
    # to make the pixels.
    clear_socket_input(tree, socket_map["Base Color"])
    socket_map["Base Color"].default_value = [0, 0, 0, 1]
    clear_socket_input(tree, socket_map["Alpha"])
    socket_map["Alpha"].default_value = 1
    clear_socket_input(tree, socket_map["Specular"])
    socket_map["Specular"].default_value = 0.0

    old_blend_method = mat.blend_method
    mat.blend_method = "OPAQUE"

    # Route the captured value into Emission. The alpha case feeds a scalar
    # socket's value into a color socket; this assumes set_socket_value
    # handles that conversion — TODO confirm against its definition.
    if capturing_material_alpha:
        set_socket_value(tree, socket_map["Emission"], old_alpha)
    else:
        set_socket_value(tree, socket_map["Emission"], old_base_color)
    clear_socket_input(tree, socket_map["Emission Strength"])
    socket_map["Emission Strength"].default_value = 1.0

    def undo_fn():
        mat.blend_method = old_blend_method
        set_socket_value(tree, socket_map["Base Color"], old_base_color)
        set_socket_value(tree, socket_map["Alpha"], old_alpha)
        set_socket_value(tree, socket_map["Emission"], old_emission)
        set_socket_value(tree, socket_map["Emission Strength"], old_emission_strength)
        set_socket_value(tree, socket_map["Specular"], old_specular)
        # Restore last: the node tree is kept even when use_nodes is False.
        mat.use_nodes = old_use_nodes

    return undo_fn