core/cmd_interface/cerberus_protocol_required_commands.c (360 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include <stdint.h> #include <stdlib.h> #include <string.h> #include "cerberus_protocol.h" #include "cerberus_protocol_required_commands.h" #include "cmd_background.h" #include "cmd_interface.h" #include "cmd_logging.h" #include "device_manager.h" #include "session_manager.h" #include "attestation/attestation_responder.h" #include "common/buffer_util.h" #include "common/certificate.h" #include "common/common_math.h" #include "mctp/mctp_logging.h" /** * Populate the message payload with a Cerberus protocol error status response with the specified * error information. * * @param message The response message to populate with error details. If this is null, nothing * will be done. * @param error_code The Cerberus error code to report in the response. * @param error_data The detailed error code to provide with the response. * @param cmd_set The value to assign to the rq bit in the response header, corresponding to the * type of message that generated the error. * @param command_code The Cerberus command that generated the error. If the command code is unknown * or unavailable, set this to 0. */ void cerberus_protocol_build_error_response (struct cmd_interface_msg *message, uint8_t error_code, uint32_t error_data, uint8_t cmd_set, uint8_t command_code) { struct cerberus_protocol_error *error_msg; if (message == NULL) { return; } error_msg = (struct cerberus_protocol_error*) message->payload; memset (error_msg, 0, sizeof (struct cerberus_protocol_error)); /* TODO: Don't populate the MCTP header. 
*/ error_msg->header.msg_type = MCTP_BASE_PROTOCOL_MSG_TYPE_VENDOR_DEF; buffer_unaligned_write16 (&error_msg->header.pci_vendor_id, CERBERUS_PROTOCOL_MSFT_PCI_VID); error_msg->header.rq = cmd_set; error_msg->header.command = CERBERUS_PROTOCOL_ERROR; error_msg->error_code = error_code; error_msg->error_data = error_data; cmd_interface_msg_set_message_payload_length (message, sizeof (*error_msg)); if (error_code != CERBERUS_PROTOCOL_NO_ERROR) { debug_log_create_entry (DEBUG_LOG_SEVERITY_ERROR, DEBUG_LOG_COMPONENT_CMD_INTERFACE, CMD_LOGGING_CERBERUS_REQUEST_FAIL, ((error_code << 24) | (command_code << 16) | (message->source_eid << 8) | message->channel_id), error_data); } } /** * Process FW version request * * @param fw_version The firmware version data * @param request FW version request to process * * @return 0 if request processed successfully or an error code. */ int cerberus_protocol_get_fw_version (const struct cmd_interface_fw_version *fw_version, struct cmd_interface_msg *request) { struct cerberus_protocol_get_fw_version *rq = (struct cerberus_protocol_get_fw_version*) request->data; struct cerberus_protocol_get_fw_version_response *rsp = (struct cerberus_protocol_get_fw_version_response*) request->data; uint8_t area; if (request->length != sizeof (struct cerberus_protocol_get_fw_version)) { return CMD_HANDLER_BAD_LENGTH; } if (rq->area >= fw_version->count) { return CMD_HANDLER_UNSUPPORTED_INDEX; } area = rq->area; memset (&rsp->version, 0, sizeof (rsp->version)); if (fw_version->id[area] != NULL) { strncpy (rsp->version, fw_version->id[area], sizeof (rsp->version)); rsp->version[sizeof (rsp->version) - 1] = '\0'; } request->length = sizeof (struct cerberus_protocol_get_fw_version_response); return 0; } /** * Process get certificate digest request * * @param attestation Attestation responder instance to utilize * @param session Session manager instance to utilize * @param request Get certificate digest request to process * * @return 0 if input processed 
successfully or an error code. */ int cerberus_protocol_get_certificate_digest (struct attestation_responder *attestation, struct session_manager *session, struct cmd_interface_msg *request) { struct cerberus_protocol_get_certificate_digest *rq = (struct cerberus_protocol_get_certificate_digest*) request->data; struct cerberus_protocol_get_certificate_digest_response *rsp = (struct cerberus_protocol_get_certificate_digest_response*) request->data; uint8_t num_cert = 0; int status = 0; request->crypto_timeout = true; if (request->length != sizeof (struct cerberus_protocol_get_certificate_digest)) { return CMD_HANDLER_BAD_LENGTH; } if (rq->slot_num > ATTESTATION_MAX_SLOT_NUM) { return CMD_HANDLER_OUT_OF_RANGE; } if (rq->key_alg >= NUM_ATTESTATION_KEY_EXCHANGE_ALGORITHMS) { return CMD_HANDLER_UNSUPPORTED_INDEX; } if (rq->key_alg != ATTESTATION_KEY_EXCHANGE_NONE) { if (session == NULL) { return CMD_HANDLER_UNSUPPORTED_OPERATION; } if (rq->header.crypt == 0) { session->reset_session (session, request->source_eid, NULL, 0); } } attestation->key_exchange_algorithm = rq->key_alg; status = attestation->get_digests (attestation, rq->slot_num, cerberus_protocol_certificate_digests (rsp), CERBERUS_PROTOCOL_MAX_CERT_DIGESTS (request), &num_cert); if (!ROT_IS_ERROR (status)) { rsp->capabilities = 1; rsp->num_digests = num_cert; request->length = cerberus_protocol_get_certificate_digest_response_length (rsp); status = 0; } else if ((status == ATTESTATION_INVALID_SLOT_NUM) || (status == ATTESTATION_CERT_NOT_AVAILABLE)) { rsp->capabilities = 1; rsp->num_digests = 0; request->length = cerberus_protocol_get_certificate_digest_response_length (rsp); status = 0; } return status; } /** * Process get certificate request * * @param attestation Attestation responder instance to utilize * @param request Get certificate request to process * * @return 0 if request processed successfully or an error code. 
*/
int cerberus_protocol_get_certificate (struct attestation_responder *attestation,
	struct cmd_interface_msg *request)
{
	struct cerberus_protocol_get_certificate *rq =
		(struct cerberus_protocol_get_certificate*) request->data;
	struct cerberus_protocol_get_certificate_response *rsp =
		(struct cerberus_protocol_get_certificate_response*) request->data;
	struct der_cert cert;
	uint8_t slot;
	uint8_t cert_idx;
	uint16_t req_offset;
	uint16_t req_length;
	uint16_t data_len = 0;
	int status;

	if (request->length != sizeof (struct cerberus_protocol_get_certificate)) {
		return CMD_HANDLER_BAD_LENGTH;
	}

	/* The response overwrites the request buffer, so every request field is captured up front. */
	slot = rq->slot_num;
	cert_idx = rq->cert_num;
	req_length = rq->length;
	req_offset = rq->offset;

	if (slot > ATTESTATION_MAX_SLOT_NUM) {
		return CMD_HANDLER_OUT_OF_RANGE;
	}

	status = attestation->get_certificate (attestation, slot, cert_idx, &cert);
	if (status == 0) {
		/* An offset at or past the end of the certificate returns no data. */
		if (req_offset < cert.length) {
			/* A zero or oversized length request is clamped to the largest chunk that fits. */
			data_len = req_length;
			if ((data_len == 0) || (data_len > CERBERUS_PROTOCOL_MAX_CERT_DATA (request))) {
				data_len = CERBERUS_PROTOCOL_MAX_CERT_DATA (request);
			}

			data_len = min (data_len, cert.length - req_offset);
			memcpy (cerberus_protocol_certificate (rsp), &cert.cert[req_offset], data_len);
		}
	}
	else if ((status == ATTESTATION_INVALID_SLOT_NUM) || (status == ATTESTATION_INVALID_CERT_NUM) ||
		(status == ATTESTATION_CERT_NOT_AVAILABLE)) {
		/* An unavailable certificate is not a failure; an empty response is sent instead. */
		debug_log_create_entry (DEBUG_LOG_SEVERITY_INFO, DEBUG_LOG_COMPONENT_CMD_INTERFACE,
			CMD_LOGGING_NO_CERT, (slot << 8) | cert_idx, status);
	}
	else {
		return status;
	}

	rsp->slot_num = slot;
	rsp->cert_num = cert_idx;
	request->length = cerberus_protocol_get_certificate_response_length (data_len);

	return 0;
}

/**
 * Process challenge request
 *
 * @param attestation Attestation responder instance to utilize
 * @param session Session manager instance to utilize if initialized
 * @param request Challenge request to process
 *
 * @return 0 if request completed successfully or an error code.
*/
int cerberus_protocol_get_challenge_response (struct attestation_responder *attestation,
	struct session_manager *session, struct cmd_interface_msg *request)
{
	struct cerberus_protocol_challenge *rq = (struct cerberus_protocol_challenge*) request->data;
	struct cerberus_protocol_challenge_response *rsp =
		(struct cerberus_protocol_challenge_response*) request->data;
	uint8_t rq_nonce[ATTESTATION_NONCE_LEN];
	int resp_len;

	request->crypto_timeout = true;

	if (request->length != sizeof (struct cerberus_protocol_challenge)) {
		return CMD_HANDLER_BAD_LENGTH;
	}

	/* Save the nonce from the request; building the response overwrites the request buffer. */
	memcpy (rq_nonce, rq->challenge.nonce, sizeof (rq_nonce));

	resp_len = attestation->challenge_response (attestation, (uint8_t*) &rq->challenge,
		request->max_response - CERBERUS_PROTOCOL_MIN_MSG_LEN);
	if (ROT_IS_ERROR (resp_len)) {
		return resp_len;
	}

	/* On success, the responder returns the number of payload bytes it generated. */
	request->length = CERBERUS_PROTOCOL_MIN_MSG_LEN + resp_len;

	if ((session != NULL) &&
		(attestation->key_exchange_algorithm == ATTESTATION_ECDHE_KEY_EXCHANGE)) {
		session->add_session (session, request->source_eid, rq_nonce, rsp->challenge.nonce);
	}

	return 0;
}

/**
 * Process a CSR request
 *
 * @param riot RIoT key manager to utilize
 * @param request Export CSR request to process
 *
 * @return 0 if processing completed successfully or an error code.
*/
int cerberus_protocol_export_csr (const struct riot_key_manager *riot,
	struct cmd_interface_msg *request)
{
	struct cerberus_protocol_export_csr *rq = (struct cerberus_protocol_export_csr*) request->data;
	struct cerberus_protocol_export_csr_response *rsp =
		(struct cerberus_protocol_export_csr_response*) request->data;
	struct der_cert csr;
	int status;

	if (request->length != sizeof (struct cerberus_protocol_export_csr)) {
		return CMD_HANDLER_BAD_LENGTH;
	}

	status = riot_key_manager_get_csr (riot, rq->index, &csr);
	if (status != 0) {
		/* An unknown CSR index is reported as an unsupported index; other failures pass through. */
		return (status == RIOT_KEY_MANAGER_UNKNOWN_CSR) ? CMD_HANDLER_UNSUPPORTED_INDEX : status;
	}

	/* Distinguish a CSR larger than the local buffer from one larger than this response allows. */
	if (csr.length > CERBERUS_PROTOCOL_LOCAL_MAX_CSR_DATA) {
		return CMD_HANDLER_BUF_TOO_SMALL;
	}

	if (csr.length > CERBERUS_PROTOCOL_MAX_CSR_DATA (request)) {
		return CMD_HANDLER_RESPONSE_TOO_SMALL;
	}

	memcpy (&rsp->csr, csr.cert, csr.length);
	request->length = cerberus_protocol_export_csr_response_length (csr.length);

	return 0;
}

/**
 * Import a signed certificate
 *
 * @param riot RIoT key manager to utilize
 * @param background Background handler context for certificate authentication
 * @param request Import certificate request to process
 *
 * @return 0 if processing completed successfully or an error code.
*/
int cerberus_protocol_import_ca_signed_cert (const struct riot_key_manager *riot,
	const struct cmd_background *background, struct cmd_interface_msg *request)
{
	struct cerberus_protocol_import_certificate *rq =
		(struct cerberus_protocol_import_certificate*) request->data;
	int header_len =
		sizeof (struct cerberus_protocol_import_certificate) - sizeof (rq->certificate);
	int status;

	request->crypto_timeout = true;

	/* The request must carry a non-empty certificate whose declared length exactly matches the
	 * number of bytes that follow the fixed fields.  Short-circuit evaluation ensures cert_length is
	 * only read once the fixed fields are known to be present. */
	if ((request->length < sizeof (struct cerberus_protocol_import_certificate)) ||
		(rq->cert_length == 0) || ((int) request->length != (header_len + rq->cert_length))) {
		return CMD_HANDLER_BAD_LENGTH;
	}

	if (rq->index == 0) {
		status = riot_key_manager_store_signed_device_id (riot, &rq->certificate, rq->cert_length);
	}
	else if (rq->index == 1) {
		status = riot_key_manager_store_root_ca (riot, &rq->certificate, rq->cert_length);
	}
	else if (rq->index == 2) {
		status = riot_key_manager_store_intermediate_ca (riot, &rq->certificate, rq->cert_length);
	}
	else {
		return CMD_HANDLER_UNSUPPORTED_INDEX;
	}

	/* Only trigger authentication of the certificate chain once the store has succeeded. */
	if (status == 0) {
		status = background->authenticate_riot_certs (background);
	}

	if (status != 0) {
		return status;
	}

	request->length = 0;

	return 0;
}

/**
 * Process a request to get the current state of signed RIoT certificates.
 *
 * @param background Background context that contains the necessary state information.
 * @param request State request to process.
 *
 * @return 0 if processing completed successfully or an error code.
*/ int cerberus_protocol_get_signed_cert_state (const struct cmd_background *background, struct cmd_interface_msg *request) { struct cerberus_protocol_get_certificate_state_response *rsp = (struct cerberus_protocol_get_certificate_state_response*) request->data; if (request->length != sizeof (struct cerberus_protocol_get_certificate_state)) { return CMD_HANDLER_BAD_LENGTH; } rsp->cert_state = background->get_riot_cert_chain_state (background); request->length = sizeof (struct cerberus_protocol_get_certificate_state_response); return 0; } /** * Process get device capabilities request * * @param device_mgr Device manager instance to utilize * @param request Capabilities request to process * * @return 0 if request processing completed successfully or an error code. */ int cerberus_protocol_get_device_capabilities (struct device_manager *device_mgr, struct cmd_interface_msg *request) { struct cerberus_protocol_device_capabilities *rq = (struct cerberus_protocol_device_capabilities*) request->data; struct cerberus_protocol_device_capabilities_response *rsp = (struct cerberus_protocol_device_capabilities_response*) request->data; int device_num; int status; if (request->length != sizeof (struct cerberus_protocol_device_capabilities)) { return CMD_HANDLER_BAD_LENGTH; } device_num = device_manager_get_device_num (device_mgr, request->source_eid); if (ROT_IS_ERROR (device_num)) { return device_num; } status = device_manager_update_device_capabilities_request (device_mgr, device_num, &rq->capabilities); if (status != 0) { return status; } status = device_manager_get_device_capabilities (device_mgr, DEVICE_MANAGER_SELF_DEVICE_NUM, &rsp->capabilities); if (status != 0) { return status; } request->length = sizeof (struct cerberus_protocol_device_capabilities_response); return 0; } /** * Process device info request * * @param device The device command handler to query the device information * @param request Device info request to process * * @return 0 if request processed 
successfully or an error code. */ int cerberus_protocol_get_device_info (const struct cmd_device *device, struct cmd_interface_msg *request) { struct cerberus_protocol_get_device_info *rq = (struct cerberus_protocol_get_device_info*) request->data; struct cerberus_protocol_get_device_info_response *rsp = (struct cerberus_protocol_get_device_info_response*) request->data; int status; if (request->length != sizeof (struct cerberus_protocol_get_device_info)) { return CMD_HANDLER_BAD_LENGTH; } if (rq->info_index != 0) { return CMD_HANDLER_UNSUPPORTED_INDEX; } status = device->get_uuid (device, &rsp->info, CERBERUS_PROTOCOL_MAX_DEV_INFO_DATA (request)); if (!ROT_IS_ERROR (status)) { request->length = cerberus_protocol_get_device_info_response_length (status); status = 0; } return status; } /** * Process device ID request * * @param id Device ID data * @param request Device ID request to process * * @return 0 if request completed successfully or an error code. */ int cerberus_protocol_get_device_id (const struct cmd_interface_device_id *id, struct cmd_interface_msg *request) { struct cerberus_protocol_get_device_id_response *rsp = (struct cerberus_protocol_get_device_id_response*) request->data; if (request->length != sizeof (struct cerberus_protocol_get_device_id)) { return CMD_HANDLER_BAD_LENGTH; } rsp->vendor_id = id->vendor_id; rsp->device_id = id->device_id; rsp->subsystem_vid = id->subsystem_vid; rsp->subsystem_id = id->subsystem_id; request->length = sizeof (struct cerberus_protocol_get_device_id_response); return 0; } /** * Process reset counter request * * @param device The device command handler to query the counter data * @param request Reset counter request to process * * @return 0 if request completed successfully or an error code. 
*/ int cerberus_protocol_reset_counter (const struct cmd_device *device, struct cmd_interface_msg *request) { struct cerberus_protocol_reset_counter *rq = (struct cerberus_protocol_reset_counter*) request->data; struct cerberus_protocol_reset_counter_response *rsp = (struct cerberus_protocol_reset_counter_response*) request->data; int status; if (request->length != sizeof (struct cerberus_protocol_reset_counter)) { return CMD_HANDLER_BAD_LENGTH; } status = device->get_reset_counter (device, rq->type, rq->port, &rsp->counter); if (status != 0) { return status; } request->length = sizeof (struct cerberus_protocol_reset_counter_response); return 0; } /** * Process an error response * * @param response Error response to process * * @return 0 if response processed successfully or an error code. */ int cerberus_protocol_process_error_response (struct cmd_interface_msg *response) { struct cerberus_protocol_error *error_msg = (struct cerberus_protocol_error*) response->data; if (response->length >= sizeof (struct cerberus_protocol_error)) { debug_log_create_entry (DEBUG_LOG_SEVERITY_INFO, DEBUG_LOG_COMPONENT_CMD_INTERFACE, CMD_LOGGING_CHANNEL, response->channel_id, 0); debug_log_create_entry (DEBUG_LOG_SEVERITY_ERROR, DEBUG_LOG_COMPONENT_CMD_INTERFACE, CMD_LOGGING_ERROR_MESSAGE, (error_msg->error_code << 24 | response->source_eid << 16 | response->target_eid << 8), error_msg->error_data); } else { return CMD_HANDLER_INVALID_ERROR_MSG; } return CMD_HANDLER_ERROR_MESSAGE; }