experimental/piranha_playground/main.py (89 lines of code) (raw):

# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys

import openai
from flask import Flask, Response, jsonify, render_template, request, session

from piranha_playground.data_validation import (ImproveData, InferData,
                                                RefactorData, RefactorSnippet)
from piranha_playground.rule_inference.piranha_agent import (PiranhaAgent,
                                                             PiranhaAgentError)
from piranha_playground.rule_inference.rule_application import (
    CodebaseRefactorer, CodebaseRefactorerException)
from piranha_playground.rule_inference.utils.logger_formatter import \
    CustomFormatter

# Create Flask app
app = Flask(__name__)

# Address the development server binds to. These are Flask's own defaults,
# spelled out so the startup log message cannot drift from the real address.
HOST = "127.0.0.1"
PORT = 5000

logger = logging.getLogger("FlaskApp")
logger.setLevel(logging.DEBUG)
# Guard against attaching a second console handler (and thus printing every
# record twice) if this module is imported more than once in one process,
# e.g. both as ``__main__`` and as ``piranha_playground.main``.
if not logger.handlers:
    _console_handler = logging.StreamHandler()
    _console_handler.setLevel(logging.INFO)
    _console_handler.setFormatter(CustomFormatter())
    logger.addHandler(_console_handler)


def _json_payload() -> dict:
    """
    Return the current request's JSON body as a dict.

    Uses ``silent=True`` so a missing, non-JSON or malformed body yields ``None``
    instead of Flask's own HTML error page, then rejects anything that is not a
    JSON object.

    :return: the decoded JSON object.
    :raises ValueError: if the body is not a JSON object. Every route below
        already maps ``ValueError`` to an HTTP 400 JSON response, whereas
        splatting a non-dict into a model (``Model(**payload)``) would raise an
        uncaught ``TypeError`` and produce an HTTP 500.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        raise ValueError("Request body must be a JSON object.")
    return payload


def _build_agent(data) -> PiranhaAgent:
    """
    Create a PiranhaAgent for the source/target example pair carried by *data*.

    :param data: a validated request model exposing ``source_code``,
        ``target_code`` and ``language`` (``InferData`` or ``ImproveData``).
    :return: a new agent, constructed with empty hints.
    """
    return PiranhaAgent(
        data.source_code,
        data.target_code,
        language=data.language,
        hints="",
    )


@app.route("/")
def home():
    """
    The main route that returns the index.html template.
    """
    return render_template("index.html")


@app.route("/refactor_codebase", methods=["POST"])
def process_folder():
    """
    Refactor a codebase on the server using the provided rules.

    The JSON request body is validated with ``RefactorData`` and must supply
    ``language``, ``folder_path`` and ``rules``.

    :return: ``{"result": True}`` on success, or
        ``({"result": False, "error": <message>}, 400)`` when the body is
        invalid or the refactoring fails.
    """
    # NOTE(review): folder_path is a server-side path chosen by the client, and
    # this endpoint runs a refactoring over it. That is only acceptable while
    # the server is bound to the loopback interface (see main()).
    try:
        data = RefactorData(**_json_payload())
        refactorer = CodebaseRefactorer(data.language, data.folder_path, data.rules)
        refactorer.refactor_codebase(False)
        return jsonify({"result": True})
    except (ValueError, CodebaseRefactorerException) as e:
        return jsonify({"result": False, "error": str(e)}), 400


@app.route("/infer_rule_graph", methods=["POST"])
def infer_static_rule():
    """
    Infer static coding rules from a pair of source and target code examples.

    The JSON request body is validated with ``InferData`` and must supply
    ``source_code``, ``target_code`` and ``language``.

    :return: ``{"rules": <inferred rules>}`` on success, or
        ``({"error": <message>}, 400)`` when the body is invalid or inference
        fails.
    """
    try:
        data = InferData(**_json_payload())
        agent = _build_agent(data)
        static_rules = agent.infer_rules_statically()
        return jsonify({"rules": static_rules})
    except (ValueError, PiranhaAgentError) as e:
        return jsonify({"error": str(e)}), 400


@app.route("/improve_rule_graph", methods=["POST"])
def improve_rules():
    """
    Improve a previously inferred rule graph according to the user's requirements.

    The JSON request body is validated with ``ImproveData`` and must supply
    ``source_code``, ``target_code``, ``language``, ``requirements``, ``rules``
    and ``option``.

    :return: ``{"rule": <improved rules>, "gpt_output": <agent explanation>}``
        on success, or ``({"error": <message>}, 400)`` on failure.
    """
    try:
        data = ImproveData(**_json_payload())
        agent = _build_agent(data)
        rules = agent.improve_rule_graph(data.requirements, data.rules, data.option)
        return jsonify(
            {
                "rule": rules,
                "gpt_output": agent.get_explanation(),
            }
        )
    # NOTE(review): catching AttributeError is broad and can mask programming
    # errors as HTTP 400s; kept to preserve the existing status codes.
    except (ValueError, PiranhaAgentError, AttributeError) as e:
        return jsonify({"error": str(e)}), 400


@app.route("/test_rule", methods=["POST"])
def test_rule():
    """
    Apply the given rules to a code snippet and return the refactored result.

    The JSON request body is validated with ``RefactorSnippet`` and must supply
    ``source_code``, ``language`` and ``rules``.

    :return: ``{"refactored_code": <code>}`` on success, or
        ``({"error": <message>}, 400)`` when the body is invalid or the rules
        fail to apply.
    """
    try:
        data = RefactorSnippet(**_json_payload())
        refactored_code = CodebaseRefactorer.refactor_snippet(
            data.source_code, data.language, data.rules
        )
        return jsonify({"refactored_code": refactored_code})
    except (ValueError, CodebaseRefactorerException) as e:
        return jsonify({"error": str(e)}), 400


def main():
    """
    Configure the OpenAI API key from the environment and start the dev server.

    Exits the process with an error message if ``OPENAI_API_KEY`` is unset or
    empty.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    if not openai.api_key:
        sys.exit(
            "Please set the OPENAI_API_KEY environment variable to your OpenAI API key."
        )
    logger.info("Starting server. Listening at: http://%s:%d", HOST, PORT)
    # debug=True enables the Werkzeug interactive debugger, which can execute
    # arbitrary code. It is only tolerable because HOST is the loopback
    # interface; never bind this server to a publicly reachable address.
    app.run(host=HOST, port=PORT, debug=True)


if __name__ == "__main__":
    main()