// sources/Google.Solutions.Platform/Dispatch/Win32Job.cs (374 lines of code) (raw):
//
// Copyright 2023 Google LLC
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
using Google.Solutions.Common.Interop;
using Google.Solutions.Common.Runtime;
using Google.Solutions.Common.Util;
using Google.Solutions.Platform.Interop;
using Microsoft.Win32.SafeHandles;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
namespace Google.Solutions.Platform.Dispatch
{
/// <summary>
/// A Win32 job, i.e. a set of processes that can be tracked
/// and waited on as a group.
/// </summary>
public interface IWin32Job : IWin32ProcessSet, IDisposable
{
/// <summary>
/// Handle of the underlying job object.
/// </summary>
SafeHandle Handle { get; }
/// <summary>
/// Add a process to the job.
/// </summary>
/// <param name="process">Process to add to the job.</param>
void Add(IWin32Process process);
/// <summary>
/// Return the IDs of processes in this job.
/// </summary>
IEnumerable<uint> ProcessIds { get; }
/// <summary>
/// Wait for all processes to terminate.
/// </summary>
/// <param name="timeout">Time to wait before giving up.</param>
/// <param name="cancellationToken">Token to cancel the wait.</param>
/// <returns>
/// Task that completes once the processes in the job have terminated.
/// </returns>
Task WaitForProcessesAsync(
TimeSpan timeout,
CancellationToken cancellationToken);
}
public class Win32Job : DisposableBase, IWin32Job
{
/// <summary>
/// Handle of the job object. Assigned once by the constructor and
/// released in Dispose.
/// </summary>
private readonly SafeJobHandle handle;

/// <summary>
/// Create a new, unnamed job object.
/// </summary>
/// <param name="killOnJobClose">
/// If true, configure the job with the KILL_ON_JOB_CLOSE limit so
/// that member processes are killed when the job is closed.
/// </param>
/// <remarks>
/// Throws the exception produced by
/// <c>DispatchException.FromLastWin32Error</c> if the job cannot be
/// created or configured.
/// </remarks>
public Win32Job(bool killOnJobClose)
{
    var securityAttributes = new NativeMethods.SECURITY_ATTRIBUTES()
    {
        nLength = Marshal.SizeOf<NativeMethods.SECURITY_ATTRIBUTES>(),

        //
        // Don't inherit the handle to child processes, otherwise they
        // might be able to modify the job.
        //
        bInheritHandle = false
    };

    var job = NativeMethods.CreateJobObject(ref securityAttributes, null);
    if (job.IsInvalid)
    {
        throw DispatchException.FromLastWin32Error("Creating job object failed");
    }

    if (killOnJobClose)
    {
        //
        // Configure the job so that it kills all member processes
        // when it's closed.
        //
        var jobLimits = new NativeMethods.JOBOBJECT_EXTENDED_LIMIT_INFORMATION()
        {
            BasicLimitInformation = new NativeMethods.JOBOBJECT_BASIC_LIMIT_INFORMATION()
            {
                LimitFlags = NativeMethods.JOB_OBJECT_LIMIT.KILL_ON_JOB_CLOSE
            }
        };

        if (!NativeMethods.SetInformationJobObject(
            job,
            NativeMethods.JOBOBJECTINFOCLASS.JobObjectExtendedLimitInformation,
            ref jobLimits,
            (uint)Marshal.SizeOf<NativeMethods.JOBOBJECT_EXTENDED_LIMIT_INFORMATION>()))
        {
            //
            // Capture the error first: closing the handle may issue
            // further native calls that overwrite the last error.
            //
            var error = DispatchException.FromLastWin32Error(
                "Configuring job limits failed");

            job.Close();
            throw error;
        }
    }

    this.handle = job;
}
//---------------------------------------------------------------------
// IWin32Job.
//---------------------------------------------------------------------
/// <summary>
/// Handle of the underlying job object.
/// </summary>
public SafeHandle Handle
{
    get { return this.handle; }
}
/// <summary>
/// Assign a process to this job.
/// </summary>
/// <param name="process">Process to assign; must not be null.</param>
public void Add(IWin32Process process)
{
    process.ExpectNotNull(nameof(process));

    var assigned = NativeMethods.AssignProcessToJobObject(
        this.handle,
        process.Handle);
    if (assigned)
    {
        return;
    }

    throw DispatchException.FromLastWin32Error(
        $"Assigining the process {process} to the job failed");
}
/// <summary>
/// Check whether a process is a member of this job.
/// </summary>
/// <param name="process">Process to check; must not be null.</param>
/// <returns>true if the process is in this job.</returns>
public bool Contains(IWin32Process process)
{
    process.ExpectNotNull(nameof(process));

    var succeeded = NativeMethods.IsProcessInJob(
        process.Handle,
        this.handle,
        out var isMember);
    if (succeeded)
    {
        return isMember;
    }

    throw DispatchException.FromLastWin32Error(
        $"Checking if the process {process} is in the job failed");
}
/// <summary>
/// Check whether the process with the given ID is a member of this job.
/// </summary>
/// <param name="processId">ID of the process to check.</param>
/// <returns>true if the process is in this job.</returns>
public bool Contains(uint processId)
{
    //
    // Open the process with the least access that is sufficient for
    // the membership check; the handle is released when the method exits.
    //
    using var processHandle = NativeMethods.OpenProcess(
        NativeMethods.PROCESS_QUERY_LIMITED_INFORMATION,
        false,
        processId);

    if (processHandle.IsInvalid)
    {
        throw DispatchException.FromLastWin32Error(
            $"Accessing the process with PID {processId} failed");
    }

    var succeeded = NativeMethods.IsProcessInJob(
        processHandle,
        this.handle,
        out var isMember);
    if (succeeded)
    {
        return isMember;
    }

    throw DispatchException.FromLastWin32Error(
        $"Checking if the process with PID {processId} " +
        "is in the job failed");
}
/// <summary>
/// Return the IDs of processes in this job.
/// </summary>
/// <remarks>
/// The list is queried into a buffer that is grown and re-queried
/// until the call succeeds. Throws a <see cref="Win32Exception"/> if
/// the query still fails once the buffer has reached its size cap.
/// </remarks>
public IEnumerable<uint> ProcessIds
{
    get
    {
        var headerSize = Marshal.SizeOf<NativeMethods.JOBOBJECT_BASIC_PROCESS_ID_LIST>();

        //
        // The struct declares a single pointer-sized array element at
        // its end; the actual array starts at that element.
        //
        var arrayOffset = headerSize - UIntPtr.Size;
        var size = (uint)headerSize;

        while (true)
        {
            using (var listHandle = GlobalAllocSafeHandle.GlobalAlloc(size))
            {
                var listPtr = listHandle.DangerousGetHandle();

                if (NativeMethods.QueryInformationJobObject(
                    this.handle,
                    NativeMethods.JOBOBJECTINFOCLASS.JobObjectBasicProcessIdList,
                    listPtr,
                    size,
                    out var requiredLength))
                {
                    var list = Marshal.PtrToStructure<NativeMethods.JOBOBJECT_BASIC_PROCESS_ID_LIST>(listPtr);

                    var pids = new uint[list.NumberOfProcessIdsInList];
                    for (var i = 0; i < pids.Length; i++)
                    {
                        //
                        // Entries are pointer-sized. Truncate to 32 bit
                        // without risking an OverflowException on 64-bit.
                        //
                        pids[i] = unchecked((uint)Marshal
                            .ReadIntPtr(listPtr, arrayOffset + i * UIntPtr.Size)
                            .ToInt64());
                    }

                    return pids;
                }
                else if (requiredLength > size)
                {
                    //
                    // Try again with proper size. Only trust the reported
                    // length if it actually grows the buffer, otherwise
                    // we'd retry forever with a buffer that's known to
                    // be insufficient.
                    //
                    size = requiredLength;
                }
                else if (size < ushort.MaxValue)
                {
                    //
                    // QueryInformationJobObject sometimes fails without
                    // setting a (usable) required length.
                    //
                    size *= 2;
                }
                else
                {
                    throw new Win32Exception(
                        Marshal.GetLastWin32Error(),
                        "Querying process list failed");
                }
            }
        }
    }
}
/// <summary>
/// Wait for all processes to terminate.
/// </summary>
/// <param name="timeout">
/// Time to wait for a job notification before failing with a
/// <see cref="TimeoutException"/>. The timeout applies to each wait
/// for a notification. <see cref="Timeout.InfiniteTimeSpan"/> waits
/// without a time limit.
/// </param>
/// <param name="cancellationToken">
/// Token to cancel the wait. The token is only observed when a
/// job notification arrives.
/// </param>
/// <returns>
/// A completed task if the job has no processes; otherwise a task
/// that completes once the job reports that its active process
/// count has dropped to zero.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException">
/// The timeout is negative and not <see cref="Timeout.InfiniteTimeSpan"/>.
/// </exception>
/// <exception cref="Win32Exception">
/// Creating or associating the completion port failed.
/// </exception>
public Task WaitForProcessesAsync(
    TimeSpan timeout,
    CancellationToken cancellationToken)
{
    if (!this.ProcessIds.Any())
    {
        return Task.CompletedTask;
    }

    //
    // Convert the timeout to the millisecond value that
    // GetQueuedCompletionStatus expects. 0xFFFFFFFF means INFINITE,
    // so finite timeouts must stay below that value.
    //
    const uint Infinite = uint.MaxValue;

    uint timeoutInMs;
    if (timeout == Timeout.InfiniteTimeSpan)
    {
        timeoutInMs = Infinite;
    }
    else if (timeout < TimeSpan.Zero)
    {
        throw new ArgumentOutOfRangeException(nameof(timeout));
    }
    else
    {
        timeoutInMs = (uint)Math.Min(
            timeout.TotalMilliseconds,
            (double)(Infinite - 1));
    }

    //
    // NB. Job objects aren't signaled when all processes
    // have exited. We have to listen for a completion port
    // notification instead.
    //

    //
    // Create and associate a completion port while we're still
    // on the calling thread. That way, we ensure that by the time
    // the method returns, we're ready to receive notifications.
    //
    var completionPort = NativeMethods.CreateIoCompletionPort(
        NativeMethods.INVALID_HANDLE_VALUE,
        IntPtr.Zero,
        IntPtr.Zero,
        1);
    if (completionPort.IsInvalid)
    {
        throw new Win32Exception(
            Marshal.GetLastWin32Error(),
            "Creating completion port failed");
    }

    var associationInfo = new NativeMethods.JOBOBJECT_ASSOCIATE_COMPLETION_PORT()
    {
        CompletionKey = this.handle,
        CompletionPort = completionPort
    };

    if (!NativeMethods.SetInformationJobObject(
        this.handle,
        NativeMethods.JOBOBJECTINFOCLASS.JobObjectAssociateCompletionPortInformation,
        ref associationInfo,
        (uint)Marshal.SizeOf<NativeMethods.JOBOBJECT_ASSOCIATE_COMPLETION_PORT>()))
    {
        //
        // Read the error before disposing the port as closing the
        // handle might overwrite it.
        //
        var associationError = Marshal.GetLastWin32Error();
        completionPort.Dispose();

        throw new Win32Exception(
            associationError,
            "Quering job completion status failed");
    }

    //
    // There's no great way to await a completion port notification,
    // so we have to sacrifice a thread for it.
    //
    return Task.Factory.StartNew(() =>
    {
        using (completionPort)
        {
            while (NativeMethods.GetQueuedCompletionStatus(
                completionPort,
                out var completionCode,
                out _,
                out _,
                timeoutInMs))
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    throw new TaskCanceledException();
                }
                else if (completionCode == NativeMethods.JOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO)
                {
                    //
                    // Active process count has dropped to 0.
                    //
                    return;
                }

                //
                // Some other job notification, keep going.
                //
            }

            var lastError = Marshal.GetLastWin32Error();
            if (lastError == NativeMethods.WAIT_TIMEOUT)
            {
                throw new TimeoutException();
            }
            else
            {
                throw new Win32Exception(
                    lastError,
                    "Quering job completion status failed");
            }
        }
    },
    TaskCreationOptions.LongRunning);
}
/// <summary>
/// Release the job handle.
/// </summary>
/// <param name="disposing">Passed through to the base class.</param>
protected override void Dispose(bool disposing)
{
    base.Dispose(disposing);

    //
    // The handle is null if the constructor threw before assigning
    // it. Guard against that so that disposing such a partially
    // constructed instance doesn't fail with a NullReferenceException.
    // NOTE(review): this matters most if DisposableBase calls
    // Dispose from a finalizer -- confirm.
    //
    this.handle?.Dispose();
}
//---------------------------------------------------------------------
// P/Invoke.
//---------------------------------------------------------------------
/// <summary>
/// P/Invoke declarations for the kernel32 job object and
/// I/O completion port APIs used by this class.
/// </summary>
/// <remarks>
/// All functions whose failure is reported by reading the last
/// Win32 error are declared with SetLastError = true; without it,
/// Marshal.GetLastWin32Error() doesn't return a reliable value.
/// </remarks>
private static class NativeMethods
{
    internal static readonly IntPtr INVALID_HANDLE_VALUE = new IntPtr(-1);

    internal const uint PROCESS_QUERY_LIMITED_INFORMATION = 0x1000;
    internal const uint SYNCHRONIZE = 0x00100000;
    internal const int ERROR_NO_TOKEN = 1008;
    internal const int WAIT_TIMEOUT = 0x102;
    internal const uint JOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO = 4;

    [StructLayout(LayoutKind.Sequential)]
    internal struct SECURITY_ATTRIBUTES
    {
        public int nLength;
        public IntPtr lpSecurityDescriptor;
        public bool bInheritHandle;
    }

    [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
    internal static extern SafeJobHandle CreateJobObject(
        [In] ref SECURITY_ATTRIBUTES lpJobAttributes,
        string? lpName);

    //
    // Overload for JobObjectExtendedLimitInformation.
    //
    [DllImport("kernel32.dll", SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)]
    internal static extern bool SetInformationJobObject(
        SafeJobHandle hJob,
        JOBOBJECTINFOCLASS infoClass,
        ref JOBOBJECT_EXTENDED_LIMIT_INFORMATION lpJobObjectInfo,
        uint cbJobObjectInfoLength);

    //
    // Overload for JobObjectAssociateCompletionPortInformation.
    //
    [DllImport("kernel32.dll", SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)]
    internal static extern bool SetInformationJobObject(
        SafeJobHandle hJob,
        JOBOBJECTINFOCLASS infoClass,
        ref JOBOBJECT_ASSOCIATE_COMPLETION_PORT lpJobObjectInfo,
        uint cbJobObjectInfoLength);

    [DllImport("kernel32.dll", SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)]
    internal static extern bool AssignProcessToJobObject(
        SafeJobHandle hJob,
        SafeProcessHandle hProcess);

    [DllImport("kernel32.dll", SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)]
    internal static extern bool IsProcessInJob(
        SafeProcessHandle process,
        SafeJobHandle job, out bool result);

    [DllImport("kernel32.dll", SetLastError = true)]
    internal static extern SafeProcessHandle OpenProcess(
        uint processAccess,
        bool bInheritHandle,
        uint processId);

    [DllImport("kernel32.dll", SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)]
    internal static extern bool QueryInformationJobObject(
        SafeJobHandle hJob,
        JOBOBJECTINFOCLASS infoClass,
        IntPtr lpJobObjectInfo,
        uint cbJobObjectInfoLength,
        out uint lpReturnLength);

    [DllImport("kernel32.dll", SetLastError = true)]
    internal static extern SafeCompletionPortHandle CreateIoCompletionPort(
        IntPtr fileHandle,
        IntPtr existingCompletionPort,
        IntPtr completionKey,
        int numberOfConcurrentThreads);

    //
    // NB. dwMilliseconds is in milliseconds.
    //
    [DllImport("kernel32.dll", SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)]
    internal static extern bool GetQueuedCompletionStatus(
        SafeCompletionPortHandle completionPort,
        out int lpNumberOfBytesTransferred,
        out IntPtr lpCompletionKey,
        out IntPtr lpOverlapped,
        uint dwMilliseconds);

    internal enum JOB_OBJECT_LIMIT
    {
        KILL_ON_JOB_CLOSE = 0x00002000
    }

    internal enum JOBOBJECTINFOCLASS
    {
        JobObjectBasicProcessIdList = 3,
        JobObjectAssociateCompletionPortInformation = 7,
        JobObjectExtendedLimitInformation = 9
    }

    [StructLayout(LayoutKind.Sequential)]
    internal struct JOBOBJECT_BASIC_LIMIT_INFORMATION
    {
        public long PerProcessUserTimeLimit;
        public long PerJobUserTimeLimit;
        public JOB_OBJECT_LIMIT LimitFlags;
        public UIntPtr MinimumWorkingSetSize;
        public UIntPtr MaximumWorkingSetSize;
        public uint ActiveProcessLimit;
        public UIntPtr Affinity;
        public uint PriorityClass;
        public uint SchedulingClass;
    }

    [StructLayout(LayoutKind.Sequential)]
    internal struct IO_COUNTERS
    {
        public ulong ReadOperationCount;
        public ulong WriteOperationCount;
        public ulong OtherOperationCount;
        public ulong ReadTransferCount;
        public ulong WriteTransferCount;
        public ulong OtherTransferCount;
    }

    [StructLayout(LayoutKind.Sequential)]
    internal struct JOBOBJECT_EXTENDED_LIMIT_INFORMATION
    {
        public JOBOBJECT_BASIC_LIMIT_INFORMATION BasicLimitInformation;
        public IO_COUNTERS IoInfo;
        public UIntPtr ProcessMemoryLimit;
        public UIntPtr JobMemoryLimit;
        public UIntPtr PeakProcessMemoryUsed;
        public UIntPtr PeakJobMemoryUsed;
    }

    //
    // Header of the process ID list. ProcessIdListStart is the first
    // element of a variable-length array of pointer-sized entries.
    //
    [StructLayout(LayoutKind.Sequential)]
    internal struct JOBOBJECT_BASIC_PROCESS_ID_LIST
    {
        public uint NumberOfAssignedProcesses;
        public uint NumberOfProcessIdsInList;
        public UIntPtr ProcessIdListStart;
    }

    [StructLayout(LayoutKind.Sequential)]
    internal struct JOBOBJECT_ASSOCIATE_COMPLETION_PORT
    {
        public SafeJobHandle CompletionKey;
        public SafeCompletionPortHandle CompletionPort;
    }
}
/// <summary>
/// Safe handle for a job object.
/// </summary>
private class SafeJobHandle : Win32SafeHandle
{
    /// <summary>
    /// Wrap an existing handle.
    /// </summary>
    public SafeJobHandle(IntPtr handle, bool ownsHandle) : base(handle, ownsHandle)
    {
    }

    /// <summary>
    /// Create a handle wrapper that owns its handle. The P/Invoke
    /// declarations in this file use this type as a return type.
    /// </summary>
    public SafeJobHandle() : base(true)
    {
    }
}
/// <summary>
/// Safe handle for an I/O completion port.
/// </summary>
private class SafeCompletionPortHandle : Win32SafeHandle
{
    /// <summary>
    /// Wrap an existing handle.
    /// </summary>
    public SafeCompletionPortHandle(IntPtr handle, bool ownsHandle) : base(handle, ownsHandle)
    {
    }

    /// <summary>
    /// Create a handle wrapper that owns its handle. The P/Invoke
    /// declarations in this file use this type as a return type.
    /// </summary>
    public SafeCompletionPortHandle() : base(true)
    {
    }
}
}
}