def split_pt_tf_code_blocks()

in src/doc_builder/convert_rst_to_mdx.py [0:0]


def split_pt_tf_code_blocks(text):
    """
    Split PyTorch and TensorFlow specific block codes.

    A code block starts at any line beginning with three backticks and ends at the next line that is
    exactly three backticks once surrounding whitespace is stripped. Inside a block, a line containing
    ``## PYTORCH CODE`` routes every following line to the PyTorch version, and a line containing
    ``## TENSORFLOW CODE`` routes every following line to the TensorFlow version. The marker lines
    themselves are dropped. Lines before the first marker (including the opening fence) are shared by
    both versions.

    When at least one framework-specific line is found, the block is replaced by
    ``<frameworkcontent>`` / ``<pt>`` / ``<tf>`` tags. Each framework gets its own fenced copy made of
    the shared lines followed by its specific lines; that list may be empty. Otherwise the block is
    emitted as-is, except that its closing fence is rewritten as a bare three-backtick line. A block
    that is never closed before the end of the text still receives a closing fence.

    Args:
        text (`str`): The text to process.

    Returns:
        `str`: The text with mixed PyTorch/TensorFlow code blocks split.
    """
    lines = text.split("\n")
    new_lines = []
    idx = 0
    while idx < len(lines):
        if not lines[idx].startswith("```"):
            new_lines.append(lines[idx])
            idx += 1
            continue

        # The opening fence (with any language tag) is shared so both framework versions reuse it.
        code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
        # Key into `code_lines` for the section currently being filled; markers only move it forward.
        section = "common"
        idx += 1
        while idx < len(lines) and lines[idx].strip() != "```":
            if "## PYTORCH CODE" in lines[idx]:
                section = "pytorch"
            elif "## TENSORFLOW CODE" in lines[idx]:
                section = "tensorflow"
            else:
                code_lines[section].append(lines[idx])
            idx += 1

        if code_lines["pytorch"] or code_lines["tensorflow"]:
            new_lines.extend(["<frameworkcontent>", "<pt>"])
            new_lines.extend(code_lines["common"] + code_lines["pytorch"])
            new_lines.extend(["```", "</pt>", "<tf>"])
            new_lines.extend(code_lines["common"] + code_lines["tensorflow"])
            new_lines.extend(["```", "</tf>", "</frameworkcontent>"])
        else:
            new_lines.extend(code_lines["common"] + ["```"])
        # Step over the closing fence (or past the end when the block was never closed).
        idx += 1
    return "\n".join(new_lines)