def modify_segment()

in deep_gemm/jit/interleave_ffma.py [0:0]


def modify_segment(m, name, ffma_lines):
    """Patch the `reuse`/`yield` control bits of one instruction segment in ``m``.

    Only a prefix of ``ffma_lines`` is processed: 9/16 of the line count, rounded
    down to an even number. Lines are consumed in pairs: the low and high 64-bit
    words of one 16-byte instruction, serialized little-endian.

    For each instruction whose high word has the `reuse` bit set, both the
    `reuse` and `yield` bits are toggled (XOR) when either condition holds:

    * its ``dst_reg`` has not been seen earlier in this segment, or
    * the previous instruction still has `reuse` set and the same ``dst_reg``.

    Otherwise the instruction index is recorded in the reused list. That list is
    printed when the ``DG_PRINT_REG_REUSE`` environment variable is set.

    Every occurrence of the first original instruction in ``m`` that passes
    ``validate`` is then overwritten with the patched instruction bytes.

    Args:
        m: Mutable byte buffer holding the machine code. It must support
            ``find`` and slice assignment (e.g. ``bytearray``) and is modified
            in place.
        name: Segment name, used only in the debug print.
        ffma_lines: Disassembly text lines for the segment, two per instruction.

    Returns:
        None. If no instruction pair is selected (fewer than 4 lines), nothing
        is searched for or patched.
    """
    # Bits in the high 64-bit word, per the original "Modify the `reuse` and
    # `yield` bits" comment.
    reuse_bit = 0x0800000000000000
    reuse_and_yield_bits = 0x0800200000000000

    # Process only a prefix of the segment. The count is kept even so every
    # instruction has both of its halves. NOTE(review): the reason for the 9/16
    # fraction is not visible here.
    num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
    num_insts = num_lines // 2

    le_bytes, new_le_bytes = [], []
    reused_list = []
    dst_reg_set = set()
    last_reused, last_dst_reg = False, ''
    num_changed = 0
    for i in range(num_insts):
        low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
        # NOTE(review): assumes the second-to-last parsed register is the
        # destination register — TODO confirm against `parse_registers`.
        dst_reg = parse_registers(low_line)[-2]
        low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
        le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))

        reused = (high_hex & reuse_bit) != 0
        if reused:
            is_first_occurred = dst_reg not in dst_reg_set
            if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
                # NOTE(review): the mask includes `reuse_bit`, which is known to
                # be set here, so this assert can never fail. If the intent is to
                # also require the `yield` bit, it should compare
                # `(high_hex & mask) == mask` — TODO confirm before tightening.
                assert high_hex & reuse_and_yield_bits, f'{hex(high_hex)}'
                # XOR clears `reuse` (known set) and toggles `yield`.
                high_hex ^= reuse_and_yield_bits
                reused = False
                num_changed += 1
            else:
                reused_list.append(i)
        dst_reg_set.add(dst_reg)
        new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
        # `last_reused` reflects the post-patch state of the previous instruction.
        last_reused, last_dst_reg = reused, dst_reg

    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')

    # With no selected instructions there is no byte pattern to search for.
    # Leave `m` untouched instead of failing on `le_bytes[0]`.
    if not le_bytes:
        return

    # Collect every occurrence of the first original instruction (the search
    # advances by one byte, so overlapping hits are also found). Then keep only
    # the offsets `validate` accepts. The list is built eagerly, so every check
    # runs before `m` is modified.
    pattern = le_bytes[0]
    offsets = []
    offset = m.find(pattern)
    while offset != -1:
        offsets.append(offset)
        offset = m.find(pattern, offset + 1)
    offsets = [x for x in offsets if validate(m, x, le_bytes, num_lines)]

    # Overwrite each validated occurrence, one 16-byte instruction at a time.
    for offset in offsets:
        for i, inst_bytes in enumerate(new_le_bytes):
            start = offset + i * 16
            m[start:start + 16] = inst_bytes