def main()

in codex-rs/mcp-types/generate_mcp_types.py [0:0]


def _try_from_impl(
    target: str,
    source: str,
    var: str,
    trait: str,
    type_names: list[str],
    definitions: dict,
) -> str:
    """Render a Rust ``impl TryFrom<source> for target`` that dispatches on ``method``.

    Each name in *type_names* becomes one match arm. The arm's string literal is
    the ``properties.method.const`` value from that type's schema definition,
    falling back to the type name itself when the schema has no such const.
    The arm deserializes ``params`` (``null`` when absent) into
    ``<Name as trait>::Params`` and wraps it in ``target::Name``. Any other
    method string yields a ``serde_json::Error`` built from an ``InvalidData``
    I/O error.

    Args:
        target: Rust enum converted into (e.g. ``ClientRequest``).
        source: Rust JSON-RPC message type converted from.
        var: Name of the ``try_from`` argument in the generated Rust code.
        trait: Trait whose associated ``Params`` type each variant carries.
        type_names: Schema type names, one per enum variant / match arm.
        definitions: The schema's ``definitions`` mapping.

    Returns:
        The generated Rust source, ending in a single newline.
    """
    lines = [
        f"impl TryFrom<{source}> for {target} {{\n",
        "    type Error = serde_json::Error;\n",
        f"    fn try_from({var}: {source}) -> std::result::Result<Self, Self::Error> {{\n",
        f"        match {var}.method.as_str() {{\n",
    ]

    for type_name in type_names:
        method = (
            definitions[type_name]
            .get("properties", {})
            .get("method", {})
            .get("const", type_name)
        )
        payload_type = f"<{type_name} as {trait}>::Params"
        lines.append(f'            "{method}" => {{\n')
        # params may be omitted on the wire; deserialize from JSON null then.
        lines.append(
            f"                let params_json = {var}.params.unwrap_or(serde_json::Value::Null);\n"
        )
        lines.append(
            f"                let params: {payload_type} = serde_json::from_value(params_json)?;\n"
        )
        lines.append(f"                Ok({target}::{type_name}(params))\n")
        lines.append("            },\n")

    # `{{}}` renders as the literal `{}` placeholder of Rust's format! macro.
    lines.append(
        f'            _ => Err(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Unknown method: {{}}", {var}.method)))),\n'
    )
    lines.append("        }\n")
    lines.append("    }\n")
    lines.append("}\n")
    return "".join(lines)


def main() -> int:
    """Regenerate ``src/lib.rs`` from the MCP JSON schema.

    With no arguments, ``schema/<SCHEMA_VERSION>/schema.json`` next to this
    script is used; a single argument overrides the schema path. The generated
    Rust source is written to ``src/lib.rs`` and then formatted with
    ``cargo fmt`` run from the crate root.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.

    Raises:
        OSError / json.JSONDecodeError: If the schema cannot be read or parsed.
        subprocess.CalledProcessError: If ``cargo fmt`` exits non-zero.
    """
    global DEFINITIONS  # Allow helper functions to access the schema.

    script_dir = Path(__file__).resolve().parent

    num_args = len(sys.argv)
    if num_args == 1:
        schema_file = script_dir / "schema" / SCHEMA_VERSION / "schema.json"
    elif num_args == 2:
        schema_file = Path(sys.argv[1])
    else:
        # The schema path is optional; usage errors belong on stderr.
        print(f"Usage: {sys.argv[0]} [schema.json]", file=sys.stderr)
        return 1

    lib_rs = script_dir / "src" / "lib.rs"

    with schema_file.open(encoding="utf-8") as f:
        schema_json = json.load(f)

    definitions = schema_json["definitions"]
    DEFINITIONS = definitions

    out = [
        f"""
// @generated
// DO NOT EDIT THIS FILE DIRECTLY.
// Run the following in the crate root to regenerate this file:
//
// ```shell
// ./generate_mcp_types.py
// ```
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::convert::TryFrom;

pub const MCP_SCHEMA_VERSION: &str = "{SCHEMA_VERSION}";
pub const JSONRPC_VERSION: &str = "{JSONRPC_VERSION}";

/// Paired request/response types for the Model Context Protocol (MCP).
pub trait ModelContextProtocolRequest {{
    const METHOD: &'static str;
    type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
    type Result: DeserializeOwned + Serialize + Send + Sync + 'static;
}}

/// One-way message in the Model Context Protocol (MCP).
pub trait ModelContextProtocolNotification {{
    const METHOD: &'static str;
    type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
}}

fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}

"""
    ]

    # Emit every schema type. Processing the definitions is expected to
    # populate CLIENT_REQUEST_TYPE_NAMES and SERVER_NOTIFICATION_TYPE_NAMES,
    # which the TryFrom impls below iterate.
    # NOTE(review): presumably filled from the ClientRequest / ServerNotification
    # anyOf definitions via define_any_of — confirm.
    for name, definition in definitions.items():
        add_definition(name, definition, out)

    out.append(
        _try_from_impl(
            target="ClientRequest",
            source="JSONRPCRequest",
            var="req",
            trait="ModelContextProtocolRequest",
            type_names=CLIENT_REQUEST_TYPE_NAMES,
            definitions=definitions,
        )
    )
    out.append("\n")
    out.append(
        _try_from_impl(
            target="ServerNotification",
            source="JSONRPCNotification",
            var="n",
            trait="ModelContextProtocolNotification",
            type_names=SERVER_NOTIFICATION_TYPE_NAMES,
            definitions=definitions,
        )
    )

    lib_rs.write_text("".join(out), encoding="utf-8")

    # NOTE(review): stderr is discarded, presumably to hide rustfmt's warning
    # that `imports_granularity` is a nightly-only option. As a consequence a
    # failing `cargo fmt` raises CalledProcessError without showing its output.
    subprocess.check_call(
        ["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
        cwd=lib_rs.parent.parent,
        stderr=subprocess.DEVNULL,
    )

    return 0