quickstarts/Get_started_LyriaRealTime.py (153 lines of code) (raw):

# -*- coding: utf-8 -*- # Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ ## Setup To install the dependencies for this script, run: ``` pip install pyaudio websockets ``` Before running this script, ensure the `GOOGLE_API_KEY` environment variable is set to the api-key you obtained from Google AI Studio. ## Run To run the script: ``` python LyriaRealTime_EAP.py ``` The script takes a prompt from the command line and streams the audio back over websockets. """ import asyncio import pyaudio import os from google import genai from google.genai import types # Longer buffer reduces chance of audio drop, but also delays audio and user commands. BUFFER_SECONDS=1 CHUNK=4200 FORMAT=pyaudio.paInt16 CHANNELS=2 MODEL='models/lyria-realtime-exp' OUTPUT_RATE=48000 api_key = os.environ.get("GOOGLE_API_KEY") if api_key is None: print("Please enter your API key") api_key = input("API Key: ").strip() client = genai.Client( api_key=api_key, http_options={'api_version': 'v1alpha',}, # v1alpha since Lyria RealTime is only experimental ) async def main(): p = pyaudio.PyAudio() config = types.LiveMusicGenerationConfig() async with client.aio.live.music.connect(model=MODEL) as session: async def receive(): chunks_count = 0 output_stream = p.open( format=FORMAT, channels=CHANNELS, rate=OUTPUT_RATE, output=True, frames_per_buffer=CHUNK) async for message in session.receive(): chunks_count += 1 if chunks_count == 1: # Introduce a delay before starting playback to have a buffer for network jitter. await asyncio.sleep(BUFFER_SECONDS) # print("Received chunk: ", message) if message.server_content: # print("Received chunk with metadata: ", message.server_content.audio_chunks[0].source_metadata) audio_data = message.server_content.audio_chunks[0].data output_stream.write(audio_data) elif message.filtered_prompt: print("Prompt was filtered out: ", message.filtered_prompt) else: print("Unknown error occured with message: ", message) await asyncio.sleep(10**-12) async def send(): await asyncio.sleep(5) # Allow initial prompt to play a bit while True: print("Set new prompt ((bpm=<number|'AUTO'>, scale=<enum|'AUTO'>, top_k=<number|'AUTO'>, 'play', 'pause', 'prompt1:w1,prompt2:w2,...', or single text prompt)") prompt_str = await asyncio.to_thread( input, " > " ) if not prompt_str: # Skip empty input continue if prompt_str.lower() == 'q': print("Sending STOP command.") await session.stop(); return False if prompt_str.lower() == 'play': print("Sending PLAY command.") await session.play() continue if prompt_str.lower() == 'pause': print("Sending PAUSE command.") await session.pause() continue if prompt_str.startswith('bpm='): if prompt_str.strip().endswith('AUTO'): del config.bpm print(f"Setting BPM to AUTO, which requires resetting context.") else: bpm_value = int(prompt_str.removeprefix('bpm=')) print(f"Setting BPM to {bpm_value}, which requires resetting context.") config.bpm=bpm_value await session.set_music_generation_config(config=config) await session.reset_context() continue if prompt_str.startswith('scale='): if prompt_str.strip().endswith('AUTO'): del config.scale print(f"Setting Scale to AUTO, which requires resetting context.") else: found_scale_enum_member = None for scale_member in types.Scale: # types.Scale is an enum if scale_member.name.lower() == prompt_str.lower(): found_scale_enum_member = scale_member break if found_scale_enum_member: print(f"Setting scale to {found_scale_enum_member.name}, which requires resetting context.") config.scale = found_scale_enum_member else: print("Error: Matching enum not found.") await session.set_music_generation_config(config=config) await session.reset_context() continue if prompt_str.startswith('top_k='): top_k_value = int(prompt_str.removeprefix('top_k=')) print(f"Setting TopK to {top_k_value}.") config.top_k = top_k_value await session.set_music_generation_config(config=config) await session.reset_context() continue # Check for multiple weighted prompts "prompt1:number1, prompt2:number2, ..." if ":" in prompt_str: parsed_prompts = [] segments = prompt_str.split(',') malformed_segment_exists = False # Tracks if any segment had a parsing error for segment_str_raw in segments: segment_str = segment_str_raw.strip() if not segment_str: # Skip empty segments (e.g., from "text1:1, , text2:2") continue # Split on the first colon only, in case prompt text itself contains colons parts = segment_str.split(':', 1) if len(parts) == 2: text_p = parts[0].strip() weight_s = parts[1].strip() if not text_p: # Prompt text should not be empty print(f"Error: Empty prompt text in segment '{segment_str_raw}'. Skipping this segment.") malformed_segment_exists = True continue # Skip this malformed segment try: weight_f = float(weight_s) # Weights are floats parsed_prompts.append(types.WeightedPrompt(text=text_p, weight=weight_f)) except ValueError: print(f"Error: Invalid weight '{weight_s}' in segment '{segment_str_raw}'. Must be a number. Skipping this segment.") malformed_segment_exists = True continue # Skip this malformed segment else: # This segment is not in "text:weight" format. print(f"Error: Segment '{segment_str_raw}' is not in 'text:weight' format. Skipping this segment.") malformed_segment_exists = True continue # Skip this malformed segment if parsed_prompts: # If at least one prompt was successfully parsed. prompt_repr = [f"'{p.text}':{p.weight}" for p in parsed_prompts] if malformed_segment_exists: print(f"Partially sending {len(parsed_prompts)} valid weighted prompt(s) due to errors in other segments: {', '.join(prompt_repr)}") else: print(f"Sending multiple weighted prompts: {', '.join(prompt_repr)}") await session.set_weighted_prompts(prompts=parsed_prompts) else: # No valid prompts were parsed from the input string that contained ":" print("Error: Input contained ':' suggesting multi-prompt format, but no valid 'text:weight' segments were successfully parsed. No action taken.") continue # If none of the above, treat as a regular single text prompt print(f"Sending single text prompt: \"{prompt_str}\"") await session.set_weighted_prompts( prompts=[types.WeightedPrompt(text=prompt_str, weight=1.0)] ) print("Starting with some piano") await session.set_weighted_prompts( prompts=[types.WeightedPrompt(text="Piano", weight=1.0)] ) # Set initial BPM and Scale config.bpm = 120 config.scale = types.Scale.A_FLAT_MAJOR_F_MINOR # Example initial scale print(f"Setting initial BPM to {config.bpm} and scale to {config.scale.name}") await session.set_music_generation_config(config=config) print(f"Let's get the party started!") await session.play() send_task = asyncio.create_task(send()) receive_task = asyncio.create_task(receive()) # Don't quit the loop until tasks are done await asyncio.gather(send_task, receive_task) # Clean up PyAudio p.terminate() asyncio.run(main())