static int vfio_pci_zap_and_vma_lock()

in pci/vfio_pci_core.c [1240:1326]

static int vfio_pci_zap_and_vma_lock(struct vfio_pci_core_device *vdev, bool try)
{
	struct vfio_pci_mmap_vma *mmap_vma, *tmp;

	/*
	 * Lock ordering:
	 * vma_lock is nested under mmap_lock for vm_ops callback paths.
	 * The memory_lock semaphore is used by both code paths calling
	 * into this function to zap vmas and the vm_ops.fault callback
	 * to protect the memory enable state of the device.
	 *
	 * When zapping vmas we need to maintain the mmap_lock => vma_lock
	 * ordering, which requires using vma_lock to walk vma_list to
	 * acquire an mm, then dropping vma_lock to get the mmap_lock and
	 * reacquiring vma_lock.  This logic is derived from similar
	 * requirements in uverbs_user_mmap_disassociate().
	 *
	 * mmap_lock must always be the top-level lock when it is taken.
	 * Therefore we can only hold the memory_lock write lock when
	 * vma_list is empty, as we'd need to take mmap_lock to clear
	 * entries.  vma_list can only be guaranteed empty when holding
	 * vma_lock, thus memory_lock is nested under vma_lock.
	 *
	 * This enables the vm_ops.fault callback to acquire vma_lock,
	 * followed by memory_lock read lock, while already holding
	 * mmap_lock without risk of deadlock.
	 */
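	/*
	 * Return contract (added note): returns 1 with vma_lock held once
	 * vma_list is empty; returns 0 without any lock held only in the
	 * 'try' case, when a lock could not be taken without blocking.
	 */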
	while (1) {
		struct mm_struct *mm = NULL;

		if (try) {
			if (!mutex_trylock(&vdev->vma_lock))
				return 0;
		} else {
			mutex_lock(&vdev->vma_lock);
		}
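		/* Find an mm we can pin; free entries whose mm is gone. */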
		while (!list_empty(&vdev->vma_list)) {
			mmap_vma = list_first_entry(&vdev->vma_list,
						    struct vfio_pci_mmap_vma,
						    vma_next);
			mm = mmap_vma->vma->vm_mm;
			if (mmget_not_zero(mm))
				break;

			list_del(&mmap_vma->vma_next);
			kfree(mmap_vma);
			mm = NULL;
		}
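		/*
		 * No pinnable mm was found, so vma_list is now empty:
		 * return with vma_lock still held for the caller.
		 */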
		if (!mm)
			return 1;
		mutex_unlock(&vdev->vma_lock);

		if (try) {
			if (!mmap_read_trylock(mm)) {
				mmput(mm);
				return 0;
			}
		} else {
			mmap_read_lock(mm);
		}
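		/*
		 * With mmap_lock now held, retake vma_lock so the
		 * mmap_lock => vma_lock ordering is preserved.
		 */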
		if (try) {
			if (!mutex_trylock(&vdev->vma_lock)) {
				mmap_read_unlock(mm);
				mmput(mm);
				return 0;
			}
		} else {
			mutex_lock(&vdev->vma_lock);
		}
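		/* Zap and untrack every vma on the list that belongs to this mm. */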
		list_for_each_entry_safe(mmap_vma, tmp,
					 &vdev->vma_list, vma_next) {
			struct vm_area_struct *vma = mmap_vma->vma;

			if (vma->vm_mm != mm)
				continue;

			list_del(&mmap_vma->vma_next);
			kfree(mmap_vma);

			zap_vma_ptes(vma, vma->vm_start,
				     vma->vm_end - vma->vm_start);
		}
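		/* Drop locks and loop; other mms may still be listed. */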
		mutex_unlock(&vdev->vma_lock);
		mmap_read_unlock(mm);
		mmput(mm);
	}
}
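
The ordering above dictates how callers pair this function with memory_lock. The sketch below shows both sides of that contract: the write-side helper is modeled on the vfio_pci_zap_and_down_write_memory_lock() helper that accompanies this function in the driver (an assumption here, not shown above), and example_fault() is a hypothetical vm_ops.fault handler. Both bodies illustrate the locking pattern rather than quote the file.

/*
 * Write side (sketch): the zap loop above returns with vma_lock held
 * and vma_list empty, so memory_lock can be taken for write without
 * mmap_lock ever sitting beneath it.
 */
static void example_zap_and_down_write_memory_lock(struct vfio_pci_core_device *vdev)
{
	vfio_pci_zap_and_vma_lock(vdev, false);	/* blocking variant */
	down_write(&vdev->memory_lock);
	mutex_unlock(&vdev->vma_lock);
}

/*
 * Read side (sketch): vm_ops.fault already runs with mmap_lock held,
 * so taking vma_lock and then memory_lock for read nests safely.
 */
static vm_fault_t example_fault(struct vm_fault *vmf)
{
	struct vfio_pci_core_device *vdev = vmf->vma->vm_private_data;
	vm_fault_t ret = VM_FAULT_SIGBUS;

	mutex_lock(&vdev->vma_lock);
	down_read(&vdev->memory_lock);

	/* ... check the memory enable state, map PFNs, track the vma ... */

	up_read(&vdev->memory_lock);
	mutex_unlock(&vdev->vma_lock);
	return ret;
}

Either way, memory_lock is only taken for write while vma_list is known to be empty, which is exactly the invariant the lock-ordering comment relies on.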