def display_examples()

in vision/m4/evaluation/scripts/visualize_generations.py [0:0]


def _get_question(example):
    """Return the question text of ``example``, read from its ``question`` or ``query`` column.

    Raises:
        ValueError: If the example has neither column.
    """
    if "question" in example:
        return example["question"]
    if "query" in example:
        return example["query"]
    raise ValueError("Dataset must contain a column question or query")


def _get_answers(example):
    """Return the reference answers of ``example`` as a list.

    The first present column among ``answers``, ``answer`` and ``label`` is used. A value that
    is not already a list/tuple (e.g. a single string or an integer label) is wrapped in a
    one-element list so callers can always iterate and count over it.

    Raises:
        ValueError: If the example has none of these columns.
    """
    for column in ("answers", "answer", "label"):
        if column in example:
            answers = example[column]
            break
    else:
        raise ValueError("Dataset must contain a column answers, answer or label")
    if isinstance(answers, (list, tuple)):
        return list(answers)
    return [answers]


def _dominant_answer(answers):
    """Return the most frequent answer (compared as strings), lower-cased.

    Ties are resolved in favour of the answer encountered first, so the result is
    deterministic across runs. Returns ``None`` when ``answers`` is empty.
    """
    from collections import Counter

    if not answers:
        return None
    most_common, _ = Counter(str(answer) for answer in answers).most_common(1)[0]
    return most_common.lower()


def _to_html(text):
    """HTML-escape ``text`` and turn its newlines into ``<br>`` for display inside ``<pre>``."""
    import html

    return html.escape(str(text)).replace("\n", "<br>")


def display_examples(dataset, generation_data, dataset_question_id_index, multiple_images, num_examples=300):
    """Render dataset examples next to each model's generated answer in Streamlit.

    For every shown example this displays its question id, dataset index, image(s), question,
    reference answers and one line per model. Each model line is coloured green when the most
    frequent reference answer occurs (case-insensitively) in the generation, red otherwise.

    Args:
        dataset: Indexable dataset. Each example holds an ``image`` column (or ``images``, a list
            possibly containing ``None``, when ``multiple_images`` is true), a ``question`` or
            ``query`` column, and an ``answers``, ``answer`` or ``label`` column.
        generation_data: Mapping with ``"model_name"`` (model names) and ``"data"`` (one list of
            generation records per model, aligned by index; each record has ``"question_id"``
            and ``"answer"`` keys).
        dataset_question_id_index: Mapping from question id to the example's index in ``dataset``.
        multiple_images: Whether examples store their images as a list under ``images``.
        num_examples: Maximum number of examples to display; clamped to the length of the
            shortest generation list.

    Raises:
        ValueError: If an example has no question column or no answer column.
    """
    all_generations = generation_data["data"]
    # Never index past the shortest per-model generation list.
    available = min((len(generations) for generations in all_generations), default=0)
    num_examples = min(num_examples, available)
    st.header(f"Sample Examples - First {num_examples}")

    for idx in range(num_examples):
        st.subheader(f"Example {idx}")

        # The first model's record identifies the example; the other models' records
        # are read at the same index.
        question_id = all_generations[0][idx]["question_id"]
        st.write(f"Q id: {question_id}")

        idx_dataset = dataset_question_id_index[question_id]
        st.subheader(f"Dataset idx  {idx_dataset}")

        example = dataset[idx_dataset]
        images = example["images"] if multiple_images else [example["image"]]
        for im in images:
            if im is None:  # missing image slots; `im.size` would crash
                continue
            st.image(im, width=250, caption=f"Image dimension: {im.size}")

        question = _get_question(example)
        st.markdown("<pre><strong>Question:</strong></pre>", unsafe_allow_html=True)
        # Might look strange, but it's a more accurate display of the question because it
        # preserves the newlines and doesn't affect <image> tags
        st.write([question])

        answers = _get_answers(example)
        for i, answer in enumerate(answers):
            display_text = f"<strong>Answer {i}:</strong> {_to_html(answer)}<br>"
            st.markdown(f"<pre>{display_text}</pre>", unsafe_allow_html=True)

        dominant_answer = _dominant_answer(answers)
        for model_name, data in zip(generation_data["model_name"], all_generations):
            generated_answer = str(data[idx]["answer"])
            # Without any reference answer a generation cannot be confirmed, so it is shown in red.
            is_match = dominant_answer is not None and dominant_answer in generated_answer.lower()
            color = "green" if is_match else "red"
            display_text = (
                f"<strong>{_to_html(model_name)} Answer:</strong> "
                f'<span style="color:{color}">{_to_html(generated_answer)}<br></span>'
            )
            st.markdown(f"<pre>{display_text}</pre>", unsafe_allow_html=True)

        st.divider()