games/play.py (152 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from enum import Enum
from textwrap import dedent
from typing import Callable, Optional

import matplotlib.pyplot as plt
import params
from crewai import Agent, Crew, Task
from crewai_tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from pettingzoo.utils.wrappers.order_enforcing import OrderEnforcingWrapper

Environment = OrderEnforcingWrapper


class Move(Enum):
    """Rock-paper-scissors moves, valued with the environment's integer encoding.

    Values are used both to decode observations (``Move(observation)``) and to
    encode actions (``env.step(move.value)``). START is never played; it is
    presumably the environment's "no opponent move yet" observation.
    """

    ROCK = 0
    PAPER = 1
    SCISSORS = 2
    START = 3


class Reward(Enum):
    """Per-turn reward as reported to the agents.

    LOSS/TIE/WIN mirror the environment's reward values; START (-2) is a
    sentinel substituted by ``play_game`` on each player's opening turn.
    """

    START = -2
    LOSS = -1
    TIE = 0
    WIN = 1


@dataclass
class Player:
    """One LLM-backed player plus the history accumulated while playing."""

    # Crew whose kickoff() asks the agent to choose (and record) a move.
    crew: Crew
    # Returns the `move` argument of the most recent `play` tool call
    # (e.g. "ROCK"), or None if the tool has never been called.
    move: Callable[[], Optional[str]]
    # Opponent moves observed so far, as Move names (including START).
    observations: list[str] = field(default_factory=list)
    # Running total of rounds won; one entry appended per turn taken.
    scores: list[int] = field(default_factory=list)


def make_play():
    """Build a `play` tool together with a getter for the move it recorded.

    Returns:
        A ``(play, get_move)`` pair. ``play`` is a CrewAI tool the agent calls
        to commit to a move; ``get_move()`` returns the ``move`` argument of
        the most recent ``play`` call, or None if it has never been called.
    """
    last_move = None

    # NOTE: the docstring of `play` is the tool description shown to the LLM,
    # so implementation notes belong here rather than inside it.
    @tool("play")
    def play(
        observations: list[str],
        reward: str,
        step: int,
        game: int,
        move: str,
        rationale: str,
    ) -> str:
        """Play a move in rock-paper-scissors.

        Given observations and reward; play a move with the given rationale.

        Args:
            observations (list[str]): Previous moves from the opponent, list of
                START, ROCK, PAPER, SCISSORS.
            reward (str): Previous reward from the last turn, one of LOSS, TIE,
                WIN.
            step (int): The step number.
            game (int): The game number.
            move (str): The move to make, one of ROCK, PAPER, SCISSORS.
            rationale (str): Why we're making this move.

        Returns:
            Move made with rationale.
        """
        nonlocal last_move
        last_move = move
        return f"Played {move} because {rationale}"

    return play, lambda: last_move


def _make_player(*, role, goal, backstory, description):
    """Build a single-agent Player whose agent commits moves via a `play` tool.

    Args:
        role: Agent role text.
        goal: Agent goal text.
        backstory: Agent persona/strategy prompt.
        description: Task prompt template; ``play_game`` fills its
            {observations}, {reward}, {step} and {game} placeholders.

    Returns:
        A Player wrapping a one-agent, one-task Crew.
    """
    play, move = make_play()
    agent = Agent(
        role=role,
        goal=goal,
        backstory=backstory,
        verbose=True,
        llm=make_gemini(),
        max_iter=5,
        tools=[play],
    )
    task = Task(
        description=description,
        expected_output="The move played with rationale",
        agent=agent,
    )
    return Player(
        Crew(
            agents=[agent],
            tasks=[task],
            verbose=1,
            # The move is recorded as a side effect of calling `play`, so the
            # tool must really run every time; disable CrewAI's tool cache so
            # a repeated call is never answered without executing it.
            cache=False,
        ),
        move,
    )


def make_ares():
    """Build the Ares player, prompted to open with rock, keep a winning move
    and cycle to the next move after a loss or tie."""
    return _make_player(
        role="Ares the rock-paper-scissors player",
        goal="Play rock-paper-scissors with a brute-force heuristic",
        backstory=dedent(
            """\
            You are a Ares the god of war. You are an hilariously aggressive
            rock-paper-scissors player. You start with rock. When you win,
            you stick with your winning move. When you lose or tie, cycle
            clockwise to the next move (rock to paper to scissors to rock,
            etc.)."""
        ),
        description=dedent(
            """\
            Play an aggressive game of rock-paper-scissors; given prior
            observations {observations} and reward {reward}. This is step
            {step} of game {game}."""
        ),
    )


def make_athena():
    """Build the Athena player, prompted to spot patterns in the opponent's
    moves, counter them, and stay unpredictable."""
    return _make_player(
        role="Athena the rock-paper-scissors player",
        goal="Play rock-paper-scissors with a strategic heuristic",
        backstory=dedent(
            """\
            You are a Athena the goddess of wisdom. You are a flawlessly
            strategic rock-paper-scissors player. Attempt to observe patterns
            in your opponent's moves and counter accordingly: use paper
            against rock; scissors against paper; and rock against scissors.
            Be volatile to avoid becoming predictable."""
        ),
        description=dedent(
            """\
            Play a strategic game of rock-paper-scissors; given prior
            observations {observations} and reward {reward}. This is step
            {step} of game {game}."""
        ),
    )


def make_gemini():
    """Return a Gemini Pro chat model (temperature 1.0) keyed from `params`."""
    return ChatGoogleGenerativeAI(
        model="gemini-pro",
        google_api_key=params.GOOGLE_API_KEY,
        temperature=1.0,
    )


def plot_scores(players):
    """Plot each player's cumulative wins per turn and block until closed.

    Args:
        players: Mapping with "player_0" and "player_1" keys whose values are
            Player objects; player_0 is labelled Ares and player_1 Athena
            (assumes the caller registered them in that order).
    """
    # Global style change: applies to every later matplotlib figure too.
    plt.rcParams.update(
        {"font.size": 24, "font.family": "serif", "font.serif": ["PT Serif"]}
    )
    # Only tight layout is used: combining constrained_layout=True with the
    # tight_layout() call below makes matplotlib swap engines and warn.
    plt.figure(figsize=(10, 6))
    for agent_id, label, color in (
        ("player_0", "Ares Wins", "blue"),
        ("player_1", "Athena Wins", "red"),
    ):
        scores = players[agent_id].scores
        # Each series uses its own length, so histories of unequal length
        # cannot trigger an x/y dimension mismatch in plt.plot.
        plt.plot(
            range(len(scores)),
            scores,
            label=label,
            color=color,
            linewidth=5,  # Thicker line
        )
    plt.xlabel("Number of Turns", fontsize=26)
    plt.ylabel("Cumulative Wins", fontsize=26)
    plt.title("Cumulative Wins Over Time for Rock-Paper-Scissors", fontsize=28)
    plt.legend(fontsize=24)
    plt.grid(True, linewidth=1.2)
    plt.tight_layout(pad=0)
    plt.show(block=True)


def _parse_move(move):
    """Convert a move name recorded by the `play` tool into a playable Move.

    The name comes from LLM output, so case and surrounding whitespace are
    ignored (" rock " -> Move.ROCK).

    Raises:
        ValueError: If ``move`` is not a string naming ROCK, PAPER or
            SCISSORS (e.g. None because the tool was never called, or START).
    """
    playable = {m.name for m in Move if m is not Move.START}
    name = move.strip().upper() if isinstance(move, str) else None
    if name not in playable:
        raise ValueError(
            "Expected the agent to play one of ROCK, PAPER, SCISSORS; "
            f"got {move!r}"
        )
    return Move[name]


def play_game(game: int, env: Environment, players: dict[str, Player]):
    """Run one game, asking each player's crew for a move on every turn.

    Args:
        game: Zero-based game index (reported to the agents one-based).
        env: PettingZoo AEC environment; reset at the start and always closed
            at the end, even if a turn fails.
        players: Maps each environment agent name to its Player.

    Raises:
        ValueError: If a player's crew did not record a playable move.
    """
    env.reset()
    try:
        for step, agent in enumerate(env.agent_iter()):
            observation, reward, termination, truncation, info = env.last()
            if termination or truncation:
                break

            player = players[agent]
            player.observations.append(Move(observation).name)
            previous_total = player.scores[-1] if player.scores else 0
            player.scores.append(previous_total + (1 if reward > 0 else 0))

            # Steps 0 and 1 are the two players' opening turns (assumes two
            # alternating agents): no round has been scored yet, so report
            # the START sentinel instead of the environment reward.
            opening_turn = step in (0, 1)
            player.crew.kickoff(
                inputs={
                    "observations": player.observations,
                    "reward": Reward(-2 if opening_turn else reward).name,
                    "step": step + 1,
                    "game": game + 1,
                }
            )
            # NOTE(review): if the agent does not call `play` during this
            # kickoff, player.move() still returns its previous turn's move.
            env.step(_parse_move(player.move()).value)
    finally:
        env.close()