get_rope_index()

in src/models.js [4736:4902]


    get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) {
        // @ts-ignore
        const { vision_config, image_token_id, video_token_id, vision_start_token_id } = this.config;
        const spatial_merge_size = vision_config.spatial_merge_size ?? 2;

        const mrope_position_deltas = [];
        if (image_grid_thw || video_grid_thw) {
            let total_input_ids = input_ids.tolist();
            if (!attention_mask) {
                attention_mask = ones_like(input_ids);
            }

            const attention_mask_list = attention_mask.tolist();
            const position_ids_list = Array.from({ length: 3 }, _ => Array.from({ length: input_ids.dims[0] }, _ => Array.from({ length: input_ids.dims[1] }, _ => 1)));

            const image_grid_thw_list = image_grid_thw ? image_grid_thw.tolist() : [];
            const video_grid_thw_list = video_grid_thw ? video_grid_thw.tolist() : [];

            let image_index = 0;
            let video_index = 0;
            for (let i = 0; i < total_input_ids.length; ++i) {
                const ids = total_input_ids[i].filter((_, j) => attention_mask_list[i][j] == 1);

                const vision_start_indices = ids.reduce((acc, x, idx) => {
                    if (x == vision_start_token_id) acc.push(idx);
                    return acc;
                }, []);

                const vision_tokens = vision_start_indices.map(x => ids[x + 1]);
                const image_nums = vision_tokens.filter(x => x == image_token_id).length;
                const video_nums = vision_tokens.filter(x => x == video_token_id).length;

                /** @type {number[][]} */
                let llm_pos_ids_list = [];
                let st = 0;
                let remain_images = image_nums;
                let remain_videos = video_nums;
                for (let j = 0; j < vision_tokens.length; ++j) {
                    const next_image_token = ids.findIndex((x, i) => i > st && x == image_token_id);
                    const next_video_token = ids.findIndex((x, i) => i > st && x == video_token_id);

                    const ed_image = (remain_images > 0 && next_image_token !== -1)
                        ? next_image_token
                        : ids.length + 1;

                    const ed_video = (remain_videos > 0 && next_video_token !== -1)
                        ? next_video_token
                        : ids.length + 1;

                    let ed;
                    let t, h, w;
                    if (ed_image < ed_video) {
                        ([t, h, w] = image_grid_thw_list[image_index]);
                        ++image_index;
                        --remain_images;
                        ed = ed_image;
                    } else {
                        ([t, h, w] = video_grid_thw_list[video_index]);
                        ++video_index;
                        --remain_videos;
                        ed = ed_video;
                    }

                    const [llm_grid_t, llm_grid_h, llm_grid_w] = [
                        Number(t),
                        Math.floor(Number(h) / spatial_merge_size),
                        Math.floor(Number(w) / spatial_merge_size)
                    ]
                    const text_len = ed - st;
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;

                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, i) => st_idx + (i % text_len))
                    )

                    const offset = text_len + st_idx;
                    const grid_size = llm_grid_t * llm_grid_h * llm_grid_w;
                    const t_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / (llm_grid_h * llm_grid_w)))
                    const h_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / llm_grid_w) % llm_grid_h)
                    const w_index = Array.from({ length: grid_size }, (_, i) => offset + i % llm_grid_w)

                    llm_pos_ids_list.push([t_index, h_index, w_index].flat())

                    st = ed + grid_size;
                }

                if (st < ids.length) {
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;
                    const text_len = ids.length - st;

                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, i) => (st_idx + (i % text_len)))
                    )
                }

                // NOTE: Each item in llm_pos_ids_list is an array of shape (3, text_len),
                // meaning to perform concatenation along dim=1, we can do the following:
                const num_items = llm_pos_ids_list.reduce((acc, x) => acc + x.length, 0);
                /** @type {number[]} */
                const llm_positions = new Array(num_items);
                let index = 0;
                for (let x = 0; x < 3; ++x) {
                    for (let y = 0; y < llm_pos_ids_list.length; ++y) {
                        const val = llm_pos_ids_list[y];
                        const text_len = val.length / 3;
                        for (let z = x * text_len; z < (x + 1) * text_len; ++z) {
                            llm_positions[index++] = val[z];
                        }
                    }
                }

                let count = 0;
                const attn_mask = attention_mask_list[i];
                for (let y = 0; y < attn_mask.length; ++y) {
                    if (attn_mask[y] == 1) {
                        for (let x = 0; x < 3; ++x) {
                            position_ids_list[x][i][y] = llm_positions[x * num_items / 3 + count];
                        }
                        ++count;
                    }
                }

                const max_llm_positions = max(llm_positions)[0];
                mrope_position_deltas.push(max_llm_positions + 1 - total_input_ids[i].length);
            }

            return [
                new Tensor('int64', position_ids_list.flat(Infinity), [3, input_ids.dims[0], input_ids.dims[1]]),
                new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
            ];

        } else { // Text-only
            if (attention_mask) {
                const { data, dims } = cumsum_masked_fill(attention_mask);

                const position_ids = BigInt64Array.from(
                    { length: 3 * data.length },
                    (_, i) => data[i % data.length]
                );
                /** @type {bigint[]} */
                const mrope_position_deltas = Array.from(
                    { length: dims[0] },
                    (_, i) => max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0] + 1n + BigInt(dims[1])
                );

                return [
                    new Tensor('int64', position_ids, [3, ...dims]),
                    new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
                ]
            } else {
                const [batch_size, seq_length] = input_ids.dims;
                const position_ids = BigInt64Array.from(
                    { length: 3 * batch_size * seq_length },
                    (_, i) => BigInt(Math.floor(i % seq_length / batch_size)),
                );

                return [
                    new Tensor('int64', position_ids, [3, ...input_ids.dims]),
                    zeros([batch_size, 1]),
                ]
            }
        }
    }