NTSTATUS PerformAdvancedPacketInjectionAtInboundTransport()

in network/trans/WFPSampler/sys/ClassifyFunctions_AdvancedPacketInjectionCallouts.cpp [2284:2835]


NTSTATUS PerformAdvancedPacketInjectionAtInboundTransport(_In_ CLASSIFY_DATA** ppClassifyData,
                                                          _In_ INJECTION_DATA** ppInjectionData,
                                                          _In_ BOOLEAN isInline = FALSE)
{
#if DBG

   DbgPrintEx(DPFLTR_IHVNETWORK_ID,
              DPFLTR_INFO_LEVEL,
              " ---> PerformAdvancedPacketInjectionAtInboundTransport()\n");

#endif /// DBG

   NT_ASSERT(ppClassifyData);
   NT_ASSERT(ppInjectionData);
   NT_ASSERT(*ppClassifyData);
   NT_ASSERT(*ppInjectionData);

   NTSTATUS                                   status              = STATUS_SUCCESS;
   FWPS_INCOMING_VALUES*                      pClassifyValues     = (FWPS_INCOMING_VALUES*)(*ppClassifyData)->pClassifyValues;
   FWPS_INCOMING_METADATA_VALUES*             pMetadata           = (FWPS_INCOMING_METADATA_VALUES*)(*ppClassifyData)->pMetadataValues;
   PC_ADVANCED_PACKET_INJECTION_DATA*         pData               = (PC_ADVANCED_PACKET_INJECTION_DATA*)(*ppClassifyData)->pFilter->providerContext->dataBuffer->data;
   COMPARTMENT_ID                             compartmentID       = DEFAULT_COMPARTMENT_ID;
   IF_INDEX                                   interfaceIndex      = 0;
   IF_INDEX                                   subInterfaceIndex   = 0;
   UINT32                                     flags               = 0;
   NET_BUFFER_LIST*                           pNetBufferList      = 0;
   UINT32                                     size                = 0;
   ADVANCED_PACKET_INJECTION_COMPLETION_DATA* pCompletionData     = 0;
   UINT32                                     ipHeaderSize        = 0;
   UINT32                                     transportHeaderSize = 0;
   UINT32                                     bytesRetreated      = 0;
   IPPROTO                                    protocol            = IPPROTO_MAX;
   FWP_VALUE*                                 pProtocol           = 0;
   FWP_VALUE*                                 pInterfaceIndex     = 0;
   FWP_VALUE*                                 pSubInterfaceIndex  = 0;
   FWP_VALUE*                                 pFlags              = 0;
   FWPS_PACKET_LIST_INFORMATION*              pPacketInformation  = 0;
   BOOLEAN                                    bypassInjection     = FALSE;
   BYTE*                                      pSourceAddress      = 0;
   BYTE*                                      pDestinationAddress = 0;
   NDIS_TCP_IP_CHECKSUM_PACKET_INFO           checksumInfo        = {0};

#if DBG

   KIRQL                                      irql                = KeGetCurrentIrql();
   HANDLE                                     injectionHandle     = (*ppInjectionData)->injectionHandle;

#endif /// DBG

#pragma warning(push)
#pragma warning(disable: 6014) /// pCompletionData will be freed in completionFn using AdvancedPacketInjectionCompletionDataDestroy

   HLPR_NEW(pCompletionData,
            ADVANCED_PACKET_INJECTION_COMPLETION_DATA,
            WFPSAMPLER_CALLOUT_DRIVER_TAG);
   HLPR_BAIL_ON_ALLOC_FAILURE(pCompletionData,
                              status);

#pragma warning(pop)

   KeInitializeSpinLock(&(pCompletionData->spinLock));

   pCompletionData->performedInline = isInline;
   pCompletionData->pClassifyData   = *ppClassifyData;
   pCompletionData->pInjectionData  = *ppInjectionData;

   /// Responsibility for freeing this memory has been transferred to the pCompletionData
   *ppClassifyData = 0;

   *ppInjectionData = 0;

   HLPR_NEW(pPacketInformation,
            FWPS_PACKET_LIST_INFORMATION,
            WFPSAMPLER_CALLOUT_DRIVER_TAG);
   HLPR_BAIL_ON_ALLOC_FAILURE(pPacketInformation,
                              status);
   pInterfaceIndex = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                               &FWPM_CONDITION_INTERFACE_INDEX);
   if(pInterfaceIndex &&
      pInterfaceIndex->type == FWP_UINT32)
      interfaceIndex = (IF_INDEX)pInterfaceIndex->uint32;

   pSubInterfaceIndex = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                                  &FWPM_CONDITION_SUB_INTERFACE_INDEX);
   if(pSubInterfaceIndex &&
      pSubInterfaceIndex->type == FWP_UINT32)
      subInterfaceIndex = (IF_INDEX)pSubInterfaceIndex->uint32;

   pFlags = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                      &FWPM_CONDITION_FLAGS);
   if(pFlags &&
      pFlags->type == FWP_UINT32)
      flags = pFlags->uint32;

   if(pClassifyValues->layerId == FWPS_LAYER_INBOUND_ICMP_ERROR_V4 ||
      pClassifyValues->layerId == FWPS_LAYER_OUTBOUND_ICMP_ERROR_V4)
      protocol = IPPROTO_ICMP;
   else if(pClassifyValues->layerId == FWPS_LAYER_INBOUND_ICMP_ERROR_V6 ||
           pClassifyValues->layerId == FWPS_LAYER_OUTBOUND_ICMP_ERROR_V6)
      protocol = IPPROTO_ICMPV6;

#if(NTDDI_VERSION >= NTDDI_WIN7)

   else if(pClassifyValues->layerId == FWPS_LAYER_STREAM_PACKET_V4 ||
           pClassifyValues->layerId == FWPS_LAYER_STREAM_PACKET_V6)
      protocol = IPPROTO_TCP;

#endif /// (NTDDI_VERSION >= NTDDI_WIN7)

   else
   {
      pProtocol = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                            &FWPM_CONDITION_IP_PROTOCOL);
      HLPR_BAIL_ON_NULL_POINTER(pProtocol);

      protocol = (IPPROTO)pProtocol->uint8;
   }

   if(pClassifyValues->layerId == FWPS_LAYER_ALE_FLOW_ESTABLISHED_V4)
   {
      ipHeaderSize = IPV4_HEADER_MIN_SIZE;
   
      if(protocol == IPPROTO_ICMP)
         transportHeaderSize = ICMP_HEADER_MIN_SIZE;
      else if(protocol == IPPROTO_TCP)
         transportHeaderSize = TCP_HEADER_MIN_SIZE;
      else if(protocol == IPPROTO_UDP)
         transportHeaderSize = UDP_HEADER_MIN_SIZE;
   }
   else if(pClassifyValues->layerId == FWPS_LAYER_ALE_FLOW_ESTABLISHED_V6)
   {
      ipHeaderSize = IPV6_HEADER_MIN_SIZE;

      if(protocol == IPPROTO_ICMPV6)
         transportHeaderSize = ICMP_HEADER_MIN_SIZE;
      else if(protocol == IPPROTO_TCP)
         transportHeaderSize = TCP_HEADER_MIN_SIZE;
      else if(protocol == IPPROTO_UDP)
         transportHeaderSize = UDP_HEADER_MIN_SIZE;
   }

   if(FWPS_IS_METADATA_FIELD_PRESENT(pMetadata,
                                     FWPS_METADATA_FIELD_COMPARTMENT_ID))
      compartmentID = (COMPARTMENT_ID)pMetadata->compartmentId;

   if(FWPS_IS_METADATA_FIELD_PRESENT(pMetadata,
                                     FWPS_METADATA_FIELD_IP_HEADER_SIZE) &&
      pMetadata->ipHeaderSize)
      ipHeaderSize = pMetadata->ipHeaderSize;

   if(FWPS_IS_METADATA_FIELD_PRESENT(pMetadata,
                                     FWPS_METADATA_FIELD_TRANSPORT_HEADER_SIZE) &&
      pMetadata->transportHeaderSize)
      transportHeaderSize = pMetadata->transportHeaderSize;

   bytesRetreated = ipHeaderSize;

   if(protocol != IPPROTO_ICMP &&
      protocol != IPPROTO_ICMPV6)
   {
      if(!isInline &&
         protocol != IPPROTO_TCP &&
         !(protocol == IPPROTO_UDP &&
         flags & FWP_CONDITION_FLAG_IS_RAW_ENDPOINT) &&
         (pClassifyValues->layerId == FWPS_LAYER_ALE_AUTH_RECV_ACCEPT_V4 ||
         pClassifyValues->layerId == FWPS_LAYER_ALE_AUTH_RECV_ACCEPT_V6))
      {
         /// For asynchronous execution, the drop will cause the stack to continue processing on the 
         /// NBL for auditing purposes.  This processing retreats the NBL Offset to the Transport header.
         /// We need to take this into account because we only took a reference on the NBL.
      }
      else
         bytesRetreated += transportHeaderSize;
   }
   else
   {
      if(pClassifyValues->layerId == FWPS_LAYER_DATAGRAM_DATA_V4 ||
         pClassifyValues->layerId == FWPS_LAYER_DATAGRAM_DATA_V6 ||
         pClassifyValues->layerId == FWPS_LAYER_INBOUND_ICMP_ERROR_V4 ||
         pClassifyValues->layerId == FWPS_LAYER_INBOUND_ICMP_ERROR_V6)
      {
         if(FWPS_IS_METADATA_FIELD_PRESENT(pMetadata,
                                           FWPS_METADATA_FIELD_TRANSPORT_HEADER_SIZE))
            bytesRetreated += pMetadata->transportHeaderSize;
      }
   }

   /// Query to see if IPsec has applied tunnel mode SA's to this NET_BUFFER_LIST ...
   status = FwpsGetPacketListSecurityInformation((NET_BUFFER_LIST*)pCompletionData->pClassifyData->pPacket,
                                                 FWPS_PACKET_LIST_INFORMATION_QUERY_ALL_INBOUND,
                                                 pPacketInformation);
   if(status != STATUS_SUCCESS)
   {
      DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                 DPFLTR_ERROR_LEVEL,
                 " !!!! PerformAdvancedPacketInjectionAtInboundTransport: FwpsGetPacketListSecurityInformation() [status: %#x]\n",
                 status);

      HLPR_BAIL;
   }

   /// ... if it has, then bypass the injection until the NET_BUFFER_LIST has come out of the tunnel
   if((pPacketInformation->ipsecInformation.inbound.isTunnelMode &&
      !(pPacketInformation->ipsecInformation.inbound.isDeTunneled)) ||
      pPacketInformation->ipsecInformation.inbound.isSecure)
   {
      bypassInjection = TRUE;

      HLPR_BAIL;
   }

   /// Initial offset is at the data, so retreat the size of the IP Header and Transport Header ...
   /// For ICMP, offset is at the ICMP Header, so retreat the size of the IP Header ...
   status = NdisRetreatNetBufferDataStart(NET_BUFFER_LIST_FIRST_NB((NET_BUFFER_LIST*)pCompletionData->pClassifyData->pPacket),
                                          bytesRetreated,
                                          0,
                                          0);
   if(status != STATUS_SUCCESS)
   {
      DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                 DPFLTR_ERROR_LEVEL,
                 " !!!! PerformAdvancedPacketInjectionAtInboundTransport: NdisRetreatNetBufferDataStart() [status: %#x]\n",
                 status);

      HLPR_BAIL;
   }

   /// ... create a new NET_BUFFER_LIST based on the original NET_BUFFER_LIST ...
   pNetBufferList = KrnlHlprNBLCreateNew(g_pNDISPoolData->nblPoolHandle,
                                         (NET_BUFFER_LIST*)pCompletionData->pClassifyData->pPacket,
                                         &(pCompletionData->pAllocatedBuffer),
                                         &size,
                                         &(pCompletionData->pAllocatedMDL),
                                         pData->additionalBytes,
                                         FALSE);

   /// ... and advance the offset back to the original position.
   NdisAdvanceNetBufferDataStart(NET_BUFFER_LIST_FIRST_NB((NET_BUFFER_LIST*)pCompletionData->pClassifyData->pPacket),
                                 bytesRetreated,
                                 FALSE,
                                 0);

   if(!pNetBufferList)
   {
      DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                 DPFLTR_ERROR_LEVEL,
                 " !!!! PerformAdvancedPacketInjectionAtInboundTransport: KrnlHlprNBLCreateNew() [pNetBufferList: %#p]\n",
                 pNetBufferList);

      HLPR_BAIL;
   }

#if DBG
   
      AdvancedPacketInjectionCountersIncrement(injectionHandle,
                                               &g_apiOutstandingNewNBLs);
   
#endif /// DBG

   checksumInfo.Value = (ULONG)(ULONG_PTR)NET_BUFFER_LIST_INFO((NET_BUFFER_LIST*)pCompletionData->pClassifyData->pPacket,
                                                               TcpIpChecksumNetBufferListInfo);

   /// Handle if the packet was IPsec secured
   if(pCompletionData->pInjectionData->isIPsecSecured)
   {
      /// For performance reasons, IPsec leaves the original ESP / AH information in the IP Header ...
      UINT32     headerIncludeSize   = 0;
      UINT64     endpointHandle      = 0;
      UINT32     ipv4Address         = 0;
      UINT32     addressSize         = 0;
      FWP_VALUE* pRemoteAddressValue = 0;
      FWP_VALUE* pLocalAddressValue  = 0;
      FWP_VALUE* pProtocolValue      = 0;

      pRemoteAddressValue = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                                      &FWPM_CONDITION_IP_REMOTE_ADDRESS);
      if(pRemoteAddressValue)
      {
         if(pRemoteAddressValue->type == FWP_BYTE_ARRAY16_TYPE)
            addressSize = IPV6_ADDRESS_SIZE;
         else
            addressSize = IPV4_ADDRESS_SIZE;               

         HLPR_NEW_ARRAY(pSourceAddress,
                        BYTE,
                        addressSize,
                        WFPSAMPLER_CALLOUT_DRIVER_TAG);
         HLPR_BAIL_ON_ALLOC_FAILURE(pSourceAddress,
                                    status);

         if(pRemoteAddressValue->type == FWP_BYTE_ARRAY16_TYPE)
            RtlCopyMemory(pSourceAddress,
                          pRemoteAddressValue->byteArray16->byteArray16,
                          addressSize);
         else
         {
            ipv4Address = htonl(pRemoteAddressValue->uint32);

            RtlCopyMemory(pSourceAddress,
                          &ipv4Address,
                          addressSize);
         }
      }

      pLocalAddressValue = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                                     &FWPM_CONDITION_IP_LOCAL_ADDRESS);
      if(pLocalAddressValue)
      {
         if(pLocalAddressValue->type == FWP_BYTE_ARRAY16_TYPE)
            addressSize = IPV6_ADDRESS_SIZE;
         else
            addressSize = IPV4_ADDRESS_SIZE;

         HLPR_NEW_ARRAY(pDestinationAddress,
                        BYTE,
                        addressSize,
                        WFPSAMPLER_CALLOUT_DRIVER_TAG);
         HLPR_BAIL_ON_ALLOC_FAILURE(pDestinationAddress,
                                    status);

         if(pLocalAddressValue->type == FWP_BYTE_ARRAY16_TYPE)
            RtlCopyMemory(pDestinationAddress,
                          pLocalAddressValue->byteArray16->byteArray16,
                          addressSize);
         else
         {
            ipv4Address = htonl(pLocalAddressValue->uint32);

            RtlCopyMemory(pDestinationAddress,
                          &ipv4Address,
                          addressSize);
         }            
      }

      pProtocolValue = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                                 &FWPM_CONDITION_IP_PROTOCOL);
      if(pProtocolValue &&
         pProtocolValue->type == FWP_UINT8)
         protocol = (IPPROTO)pProtocolValue->uint8;
      else
         protocol = IPPROTO_MAX;

      NT_ASSERT(protocol != IPPROTO_MAX);

#if (NTDDI_VERSION >= NTDDI_WIN6SP1)

      if(FWPS_IS_METADATA_FIELD_PRESENT(pMetadata,
                                        FWPS_METADATA_FIELD_TRANSPORT_HEADER_INCLUDE_HEADER))
         headerIncludeSize = pMetadata->headerIncludeHeaderLength;

#endif // (NTDDI_VERSION >= NTDDI_WIN6SP1)

      if(FWPS_IS_METADATA_FIELD_PRESENT(pMetadata,
                                        FWPS_METADATA_FIELD_TRANSPORT_ENDPOINT_HANDLE))
         endpointHandle = pMetadata->transportEndpointHandle;

      if(pSourceAddress == 0 ||
         pDestinationAddress == 0)
      {
         status = STATUS_INVALID_MEMBER;

         DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                    DPFLTR_ERROR_LEVEL,
                    " !!!! PerformAdvancedPacketModificationAtInboundTransport() [status: %#x][pSourceAddress: %#p][pDestinationAddress: %#p]\n",
                    status,
                    pSourceAddress,
                    pDestinationAddress);

         HLPR_BAIL;
      }

      /// ... so we must re-construct the IPHeader with the appropriate information
      status = FwpsConstructIpHeaderForTransportPacket(pNetBufferList,
                                                       headerIncludeSize,
                                                       pCompletionData->pInjectionData->addressFamily,
                                                       pSourceAddress,
                                                       pDestinationAddress,
                                                       protocol,
                                                       endpointHandle,
                                                       (const WSACMSGHDR*)pCompletionData->pInjectionData->pControlData,
                                                       pCompletionData->pInjectionData->controlDataLength,
                                                       0,
                                                       0,
                                                       interfaceIndex,
                                                       subInterfaceIndex);
      if(status != STATUS_SUCCESS)
      {
         DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                    DPFLTR_ERROR_LEVEL,
                    " !!!! PerformAdvancedPacketInjectionAtInboundTransport: FwpsConstructIpHeaderForTransportPacket() [status: %#x]\n",
                    status);

         HLPR_BAIL;
      }
   }
   /// Handle if this packet had the IP or Transport checksums offloaded or if it's loopback
   else if(checksumInfo.Receive.NdisPacketIpChecksumSucceeded ||
           checksumInfo.Receive.NdisPacketTcpChecksumSucceeded ||
           checksumInfo.Receive.NdisPacketUdpChecksumSucceeded ||
           flags & FWP_CONDITION_FLAG_IS_LOOPBACK)
   {
      /// Prevent TCP/IP Zone crossing and recalculate the checksums
      if(flags & FWP_CONDITION_FLAG_IS_LOOPBACK)
      {
         FWP_VALUE* pLocalAddress    = 0;
         FWP_VALUE* pRemoteAddress   = 0;
         FWP_VALUE* pLoopbackAddress = 0;

         pLocalAddress = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                                    &FWPM_CONDITION_IP_REMOTE_ADDRESS);
         if(pLocalAddress &&
            ((pLocalAddress->type == FWP_UINT32 &&
            RtlCompareMemory(&(pLocalAddress->uint32),
                             IPV4_LOOPBACK_ADDRESS,
                             IPV4_ADDRESS_SIZE)) ||
            (pLocalAddress->type == FWP_BYTE_ARRAY16_TYPE &&
            RtlCompareMemory(pLocalAddress->byteArray16->byteArray16,
                             IPV6_LOOPBACK_ADDRESS,
                             IPV6_ADDRESS_SIZE))))
            pLoopbackAddress = pLocalAddress;

         if(!pLoopbackAddress)
         {
            pRemoteAddress = KrnlHlprFwpValueGetFromFwpsIncomingValues(pClassifyValues,
                                                                       &FWPM_CONDITION_IP_REMOTE_ADDRESS);
            if(pRemoteAddress &&
               ((pRemoteAddress->type == FWP_UINT32 &&
               RtlCompareMemory(&(pRemoteAddress->uint32),
                                IPV4_LOOPBACK_ADDRESS,
                                IPV4_ADDRESS_SIZE)) ||
               (pRemoteAddress->type == FWP_BYTE_ARRAY16_TYPE &&
               RtlCompareMemory(pRemoteAddress->byteArray16->byteArray16,
                                IPV6_LOOPBACK_ADDRESS,
                                IPV6_ADDRESS_SIZE))))
               pLoopbackAddress = pRemoteAddress;
         }

         if(pLoopbackAddress)
         {
            status = KrnlHlprIPHeaderModifyLoopbackToLocal(pMetadata,
                                                           pLoopbackAddress,
                                                           ipHeaderSize,
                                                           pNetBufferList,
                                                           (const WSACMSGHDR*)pCompletionData->pInjectionData->pControlData,
                                                           pCompletionData->pInjectionData->controlDataLength);
            if(status != STATUS_SUCCESS)
            {
               DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                          DPFLTR_ERROR_LEVEL,
                          " !!!! PerformAdvancedPacketInjectionAtInboundTransport: KrnlHlprIPHeaderModifyLoopbackToLocal() [status: %#x]\n",
                          status);

               HLPR_BAIL;
            }
         }
      }
      else
      {
         /// Recalculate the checksum
         if(pCompletionData->pInjectionData->addressFamily == AF_INET)
            KrnlHlprIPHeaderCalculateV4Checksum(pNetBufferList,
                                                ipHeaderSize);
      }
   }

   pCompletionData->refCount = KrnlHlprNBLGetRequiredRefCount(pNetBufferList);

   status = FwpsInjectTransportReceiveAsync(pCompletionData->pInjectionData->injectionHandle,
                                            pCompletionData->pInjectionData->injectionContext,
                                            0,
                                            0,
                                            pCompletionData->pInjectionData->addressFamily,
                                            compartmentID,
                                            interfaceIndex,
                                            subInterfaceIndex,
                                            pNetBufferList,
                                            CompleteAdvancedPacketInjection,
                                            pCompletionData);

   NT_ASSERT(irql == KeGetCurrentIrql());

   if(status != STATUS_SUCCESS)
   {
      DbgPrintEx(DPFLTR_IHVNETWORK_ID,
                 DPFLTR_ERROR_LEVEL,
                 " !!!! PerformAdvancedPacketInjectionAtInboundTransport: FwpsInjectTransportReceiveAsync() [status: %#x]\n",
                 status);

#if DBG

      AdvancedPacketInjectionCountersIncrement(injectionHandle,
                                               &g_apiTotalFailedInjectionCalls);

#endif /// DBG

   }

#if DBG

   else
      AdvancedPacketInjectionCountersIncrement(injectionHandle,
                                               &g_apiTotalSuccessfulInjectionCalls);

#endif /// DBG

   HLPR_BAIL_LABEL:

   NT_ASSERT(status == STATUS_SUCCESS);

   if(status != STATUS_SUCCESS ||
      bypassInjection)
   {
      if(pNetBufferList)
      {
         FwpsFreeNetBufferList(pNetBufferList);

         pNetBufferList = 0;

#if DBG

         AdvancedPacketInjectionCountersDecrement(injectionHandle,
                                                  &g_apiOutstandingNewNBLs);

#endif

      }

      if(pCompletionData)
         AdvancedPacketInjectionCompletionDataDestroy(&pCompletionData,
                                                      TRUE);
   }

   HLPR_DELETE_ARRAY(pSourceAddress,
                     WFPSAMPLER_CALLOUT_DRIVER_TAG);

   HLPR_DELETE_ARRAY(pDestinationAddress,
                     WFPSAMPLER_CALLOUT_DRIVER_TAG);

   HLPR_DELETE(pPacketInformation,
               WFPSAMPLER_CALLOUT_DRIVER_TAG);

#if DBG

   DbgPrintEx(DPFLTR_IHVNETWORK_ID,
              DPFLTR_INFO_LEVEL,
              " <--- PerformAdvancedPacketInjectionAtInboundTransport() [status: %#x]\n",
              status);

#endif /// DBG

   return status;
}