def writing_assistant()

in genai-design-marketing-studio/streamlit_app.py [0:0]


def writing_assistant(key: str, persona: str) -> None:
    """
    Manage the conversation for the writing assistant.

    Renders a chat input and a fixed-height chat container, replays the
    stored conversation on every rerun, and, when the user submits a
    message, streams the model reply into the UI and appends both turns to
    the history kept in ``st.session_state``.

    :param key: streamlit key to identify the UI element; also used as the
        prefix of the session-state entries owned by this assistant.
    :param persona: the persona of the chatbot, forwarded to
        ``create_chat_session`` when the chat session is created.
    :return: None
    """

    # model name
    model_name = "gemini-1.5-flash-001"

    # Session-state entries owned by this assistant instance.
    history_key = f"{key}_text_chat_history"
    session_key = f"{key}_text_chat_session"

    # Init the conversation
    chat_input_container = st.container()
    with chat_input_container:
        message = st.chat_input("👋 Hello, How can I help you today?", key=f"{key}_input")

    main_chat_container = st.container(height=400)
    with main_chat_container:
        has_history = history_key in st.session_state and bool(st.session_state[history_key])
        if not message and not has_history:
            # Nothing to show yet and nothing to send.
            return

        # Initialise the two entries independently so that a missing chat
        # session (with the history still present) does not end in a KeyError
        # when the message is sent.
        if session_key not in st.session_state:
            st.session_state[session_key] = create_chat_session(model_name, persona)
        if history_key not in st.session_state:
            st.session_state[history_key] = []

        history = st.session_state[history_key]

        # Display chat messages from history on app rerun
        for past in history:
            with st.chat_message(past["role"]):
                st.markdown(past["content"])

        if not message:
            return

        # Add user message to chat history
        history.append({"role": "user", "content": message})

        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(message)

        with st.chat_message("assistant"):
            chat_session = st.session_state[session_key]
            responses = chat_session.send_message(message, stream=True)

            chunks = []

            def _stream_text():
                # Yield each chunk as it arrives so the UI renders the reply
                # incrementally, while keeping a copy for the history.
                for response in responses:
                    text = response.text
                    chunks.append(text)
                    yield text

            st.write_stream(_stream_text())

            # Chunks are contiguous pieces of one reply: join them verbatim so
            # the stored message matches what was displayed.
            history.append({"role": "assistant", "content": "".join(chunks)})