in codex-rs/mcp-types/generate_mcp_types.py [0:0]
def _generated_file_header() -> str:
    """Return the fixed preamble written at the top of the generated lib.rs.

    Interpolates the module-level SCHEMA_VERSION and JSONRPC_VERSION constants
    into the Rust source; literal Rust braces are escaped as ``{{``/``}}``.
    """
    return f"""
// @generated
// DO NOT EDIT THIS FILE DIRECTLY.
// Run the following in the crate root to regenerate this file:
//
// ```shell
// ./generate_mcp_types.py
// ```
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::convert::TryFrom;
pub const MCP_SCHEMA_VERSION: &str = "{SCHEMA_VERSION}";
pub const JSONRPC_VERSION: &str = "{JSONRPC_VERSION}";
/// Paired request/response types for the Model Context Protocol (MCP).
pub trait ModelContextProtocolRequest {{
const METHOD: &'static str;
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
type Result: DeserializeOwned + Serialize + Send + Sync + 'static;
}}
/// One-way message in the Model Context Protocol (MCP).
pub trait ModelContextProtocolNotification {{
const METHOD: &'static str;
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
}}
fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
"""


def _try_from_impl(
    source_type: str,
    target_enum: str,
    var: str,
    trait: str,
    type_names: list[str],
    definitions: dict,
) -> list[str]:
    """Return Rust source lines for ``impl TryFrom<source_type> for target_enum``.

    The generated ``try_from`` matches on ``<var>.method``. Each name in
    *type_names* gets one match arm keyed by its schema's
    ``properties.method.const``, falling back to the type name itself when
    that path is absent. The arm deserializes ``<var>.params`` (``null`` when
    missing) into ``<TypeName as trait>::Params``. Unknown methods become an
    ``InvalidData`` ``serde_json::Error``.

    Args:
        source_type: Rust type being converted from (e.g. ``JSONRPCRequest``).
        target_enum: Rust enum being converted into (e.g. ``ClientRequest``).
        var: Name of the ``try_from`` parameter in the generated Rust code.
        trait: Trait whose associated ``Params`` type is deserialized.
        type_names: Variant/type names to emit match arms for.
        definitions: The schema's ``definitions`` mapping.

    Returns:
        A list of Rust source fragments, each ending in a newline.
    """
    lines = [
        f"impl TryFrom<{source_type}> for {target_enum} {{\n",
        " type Error = serde_json::Error;\n",
        f" fn try_from({var}: {source_type}) -> std::result::Result<Self, Self::Error> {{\n",
        f" match {var}.method.as_str() {{\n",
    ]
    for type_name in type_names:
        method_const = (
            definitions[type_name]
            .get("properties", {})
            .get("method", {})
            .get("const", type_name)
        )
        payload_type = f"<{type_name} as {trait}>::Params"
        lines.append(f' "{method_const}" => {{\n')
        # params may be omitted on the wire; deserialize from `null` in that case.
        lines.append(
            f" let params_json = {var}.params.unwrap_or(serde_json::Value::Null);\n"
        )
        lines.append(
            f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
        )
        lines.append(f" Ok({target_enum}::{type_name}(params))\n")
        lines.append(" },\n")
    lines.append(
        f' _ => Err(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Unknown method: {{}}", {var}.method)))),\n'
    )
    lines.append(" }\n")
    lines.append(" }\n")
    lines.append("}\n")
    return lines


def main() -> int:
    """Generate ``src/lib.rs`` from the MCP JSON schema and run ``cargo fmt``.

    With no arguments, reads ``schema/<SCHEMA_VERSION>/schema.json`` next to
    this script; with one argument, reads that path as the schema file.

    Returns:
        0 on success, 1 on a usage error. A failing ``cargo fmt`` raises
        ``subprocess.CalledProcessError``.
    """
    global DEFINITIONS  # Allow helper functions to access the schema.

    script_dir = Path(__file__).resolve().parent
    num_args = len(sys.argv)
    if num_args == 1:
        schema_file = script_dir / "schema" / SCHEMA_VERSION / "schema.json"
    elif num_args == 2:
        schema_file = Path(sys.argv[1])
    else:
        print("Usage: python3 generate_mcp_types.py [schema.json]", file=sys.stderr)
        return 1
    lib_rs = script_dir / "src/lib.rs"

    with schema_file.open(encoding="utf-8") as f:
        schema_json = json.load(f)
    definitions = schema_json["definitions"]
    DEFINITIONS = definitions

    out = [_generated_file_header()]
    # add_definition() also records the concrete *Request / *Notification type
    # names in CLIENT_REQUEST_TYPE_NAMES / SERVER_NOTIFICATION_TYPE_NAMES.
    # The TryFrom impls below must therefore be emitted after this loop.
    for name, definition in definitions.items():
        add_definition(name, definition, out)

    out.extend(
        _try_from_impl(
            "JSONRPCRequest",
            "ClientRequest",
            "req",
            "ModelContextProtocolRequest",
            CLIENT_REQUEST_TYPE_NAMES,
            definitions,
        )
    )
    out.append("\n")  # Blank line separating the two impls.
    out.extend(
        _try_from_impl(
            "JSONRPCNotification",
            "ServerNotification",
            "n",
            "ModelContextProtocolNotification",
            SERVER_NOTIFICATION_TYPE_NAMES,
            definitions,
        )
    )

    lib_rs.write_text("".join(out), encoding="utf-8")

    # cargo fmt runs from the crate root (the parent of src/). Its stderr is
    # discarded. NOTE(review): presumably this is to silence rustfmt warnings,
    # but it also hides diagnostics when formatting fails. Confirm this is
    # intended.
    subprocess.check_call(
        ["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
        cwd=lib_rs.parent.parent,
        stderr=subprocess.DEVNULL,
    )
    return 0