/*
 * ne_set_user_memory_region_ioctl() — extracted from
 * drivers/virt/nitro_enclaves/ne_misc_dev.c [999:1143]
 */

/**
 * ne_set_user_memory_region_ioctl() - Pin a user space memory region and add
 *				       it to the slot associated with the
 *				       current enclave.
 * @ne_enclave :	Private data associated with the current enclave.
 * @mem_region :	User space memory region to be associated with the
 *			given slot.
 *
 * Pins the user pages backing @mem_region one page at a time, merges them
 * into physically contiguous ranges and issues one SLOT_ADD_MEM request per
 * contiguous range. On full success the region is tracked in the enclave's
 * mem_regions_list, keeping the pages pinned for the enclave's lifetime.
 *
 * Context: Process context. NOTE(review): presumably called with the
 *	    enclave mutex held, as it reads/writes ne_enclave state without
 *	    locking here — confirm against the ioctl caller.
 *
 * Return:
 * * 0 on success.
 * * Negative return value on failure (kernel errno or NE_ERR_* code).
 */
static int ne_set_user_memory_region_ioctl(struct ne_enclave *ne_enclave,
					   struct ne_user_memory_region mem_region)
{
	long gup_rc = 0;
	unsigned long i = 0;
	unsigned long max_nr_pages = 0;
	unsigned long memory_size = 0;
	struct ne_mem_region *ne_mem_region = NULL;
	struct pci_dev *pdev = ne_devs.ne_pci_dev->pdev;
	struct ne_phys_contig_mem_regions phys_contig_mem_regions = {};
	int rc = -EINVAL;

	rc = ne_sanity_check_user_mem_region(ne_enclave, mem_region);
	if (rc < 0)
		return rc;

	ne_mem_region = kzalloc(sizeof(*ne_mem_region), GFP_KERNEL);
	if (!ne_mem_region)
		return -ENOMEM;

	/*
	 * Upper bound on the number of pinned pages, assuming every backing
	 * page is at least NE_MIN_MEM_REGION_SIZE — presumably enforced by
	 * the per-page sanity check in the loop below; verify.
	 */
	max_nr_pages = mem_region.memory_size / NE_MIN_MEM_REGION_SIZE;

	ne_mem_region->pages = kcalloc(max_nr_pages, sizeof(*ne_mem_region->pages),
				       GFP_KERNEL);
	if (!ne_mem_region->pages) {
		rc = -ENOMEM;

		goto free_mem_region;
	}

	/* Worst case one contiguous range per page, so same upper bound. */
	phys_contig_mem_regions.regions = kcalloc(max_nr_pages,
						  sizeof(*phys_contig_mem_regions.regions),
						  GFP_KERNEL);
	if (!phys_contig_mem_regions.regions) {
		rc = -ENOMEM;

		goto free_mem_region;
	}

	/*
	 * Pin the user pages one at a time, accumulating them into the pages
	 * array and merging physically adjacent pages into contiguous ranges,
	 * until the whole requested memory size is covered.
	 */
	do {
		i = ne_mem_region->nr_pages;

		if (i == max_nr_pages) {
			dev_err_ratelimited(ne_misc_dev.this_device,
					    "Reached max nr of pages in the pages data struct\n");

			rc = -ENOMEM;

			goto put_pages;
		}

		/* Pin exactly one page at the current offset into the region. */
		gup_rc = get_user_pages_unlocked(mem_region.userspace_addr + memory_size, 1,
						 ne_mem_region->pages + i, FOLL_GET);

		if (gup_rc < 0) {
			rc = gup_rc;

			dev_err_ratelimited(ne_misc_dev.this_device,
					    "Error in get user pages [rc=%d]\n", rc);

			goto put_pages;
		}

		rc = ne_sanity_check_user_mem_region_page(ne_enclave, ne_mem_region->pages[i]);
		if (rc < 0)
			goto put_pages;

		/* Extend the current contiguous range or start a new one. */
		rc = ne_merge_phys_contig_memory_regions(&phys_contig_mem_regions,
							 page_to_phys(ne_mem_region->pages[i]),
							 page_size(ne_mem_region->pages[i]));
		if (rc < 0)
			goto put_pages;

		memory_size += page_size(ne_mem_region->pages[i]);

		ne_mem_region->nr_pages++;
	} while (memory_size < mem_region.memory_size);

	/*
	 * Each contiguous range becomes one device memory region; make sure
	 * the enclave's region budget is not exceeded before issuing requests.
	 */
	if ((ne_enclave->nr_mem_regions + phys_contig_mem_regions.num) >
	    ne_enclave->max_mem_regions) {
		dev_err_ratelimited(ne_misc_dev.this_device,
				    "Reached max memory regions %lld\n",
				    ne_enclave->max_mem_regions);

		rc = -NE_ERR_MEM_MAX_REGIONS;

		goto put_pages;
	}

	/* Validate every contiguous range before any SLOT_ADD_MEM is sent. */
	for (i = 0; i < phys_contig_mem_regions.num; i++) {
		u64 phys_region_addr = phys_contig_mem_regions.regions[i].start;
		u64 phys_region_size = range_len(&phys_contig_mem_regions.regions[i]);

		rc = ne_sanity_check_phys_mem_region(phys_region_addr, phys_region_size);
		if (rc < 0)
			goto put_pages;
	}

	ne_mem_region->memory_size = mem_region.memory_size;
	ne_mem_region->userspace_addr = mem_region.userspace_addr;

	/*
	 * Track the region (and its pinned pages) on the enclave before
	 * issuing requests, so a partial failure below does not unpin pages
	 * the device may already reference.
	 */
	list_add(&ne_mem_region->mem_region_list_entry, &ne_enclave->mem_regions_list);

	for (i = 0; i < phys_contig_mem_regions.num; i++) {
		struct ne_pci_dev_cmd_reply cmd_reply = {};
		struct slot_add_mem_req slot_add_mem_req = {};

		slot_add_mem_req.slot_uid = ne_enclave->slot_uid;
		slot_add_mem_req.paddr = phys_contig_mem_regions.regions[i].start;
		slot_add_mem_req.size = range_len(&phys_contig_mem_regions.regions[i]);

		rc = ne_do_request(pdev, SLOT_ADD_MEM,
				   &slot_add_mem_req, sizeof(slot_add_mem_req),
				   &cmd_reply, sizeof(cmd_reply));
		if (rc < 0) {
			dev_err_ratelimited(ne_misc_dev.this_device,
					    "Error in slot add mem [rc=%d]\n", rc);

			kfree(phys_contig_mem_regions.regions);

			/*
			 * Exit here without put pages as memory regions may
			 * already been added.
			 */
			return rc;
		}

		ne_enclave->mem_size += slot_add_mem_req.size;
		ne_enclave->nr_mem_regions++;
	}

	kfree(phys_contig_mem_regions.regions);

	return 0;

	/* Error paths: unpin everything pinned so far, then free bookkeeping. */
put_pages:
	for (i = 0; i < ne_mem_region->nr_pages; i++)
		put_page(ne_mem_region->pages[i]);
free_mem_region:
	kfree(phys_contig_mem_regions.regions);
	kfree(ne_mem_region->pages);
	kfree(ne_mem_region);

	return rc;
}