bool TargetLowering::SimplifyDemandedBits()

in lib/CodeGen/SelectionDAG/TargetLowering.cpp [376:1121]


/// Look at Op and try to simplify it given that only the bits set in
/// DemandedMask are used by its (single) user.
///
/// \param Op            The value being examined.
/// \param DemandedMask  Bits of Op that the caller cares about.  Its width
///                      must equal the scalar bit width of Op's type (this is
///                      asserted below).
/// \param KnownZero     [out] Reset to zero on entry, then filled with the bits
///                      of Op determined to be zero.
/// \param KnownOne      [out] Reset to zero on entry, then filled with the bits
///                      of Op determined to be one.
/// \param TLO           Provides the DAG and the CombineTo /
///                      ShrinkDemandedConstant / ShrinkDemandedOp helpers
///                      through which every rewrite is made.
/// \param Depth         Current recursion depth; each recursive call passes
///                      Depth+1 and the search stops at depth 6.
///
/// \returns the result of the TLO helper that performed a rewrite (callers in
/// this function treat a true result as "stop and propagate"), or false when
/// nothing was changed.  When false is returned via the normal path,
/// KnownZero/KnownOne describe Op; on the early-exit paths they may still be
/// the all-zero "nothing known" value.
bool TargetLowering::SimplifyDemandedBits(SDValue Op,
                                          const APInt &DemandedMask,
                                          APInt &KnownZero,
                                          APInt &KnownOne,
                                          TargetLoweringOpt &TLO,
                                          unsigned Depth) const {
  unsigned BitWidth = DemandedMask.getBitWidth();
  assert(Op.getValueType().getScalarType().getSizeInBits() == BitWidth &&
         "Mask size mismatches value type size!");
  // Working copy of the demanded mask; it is widened to all ones below when
  // the root node has more than one use.
  APInt NewMask = DemandedMask;
  SDLoc dl(Op);
  auto &DL = TLO.DAG.getDataLayout();

  // Don't know anything.
  KnownZero = KnownOne = APInt(BitWidth, 0);

  // Other users may use these bits.
  // NOTE: this is a single if / else-if chain, so the "no bits demanded" and
  // depth-limit checks are only reached for nodes with exactly one use.
  if (!Op.getNode()->hasOneUse()) {
    if (Depth != 0) {
      // If not at the root, Just compute the KnownZero/KnownOne bits to
      // simplify things downstream.
      TLO.DAG.computeKnownBits(Op, KnownZero, KnownOne, Depth);
      return false;
    }
    // If this is the root being simplified, allow it to have multiple uses,
    // just set the NewMask to all bits.
    NewMask = APInt::getAllOnesValue(BitWidth);
  } else if (DemandedMask == 0) {
    // Not demanding any bits from Op.
    // Replace it with UNDEF unless it already is one (which would otherwise
    // be replaced by itself).
    if (Op.getOpcode() != ISD::UNDEF)
      return TLO.CombineTo(Op, TLO.DAG.getUNDEF(Op.getValueType()));
    return false;
  } else if (Depth == 6) {        // Limit search depth.
    return false;
  }

  // Scratch storage: known bits of the second operand examined
  // (KnownZero2/KnownOne2) and, for XOR, the result bits that are committed
  // only at the end of that case (KnownZeroOut/KnownOneOut).
  APInt KnownZero2, KnownOne2, KnownZeroOut, KnownOneOut;
  switch (Op.getOpcode()) {
  case ISD::Constant:
    // We know all of the bits for a constant!
    KnownOne = cast<ConstantSDNode>(Op)->getAPIntValue();
    KnownZero = ~KnownOne;
    return false;   // Don't fall through, will infinitely loop.
  case ISD::AND:
    // If the RHS is a constant, check to see if the LHS would be zero without
    // using the bits from the RHS.  Below, we use knowledge about the RHS to
    // simplify the LHS, here we're using information from the LHS to simplify
    // the RHS.
    if (ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      APInt LHSZero, LHSOne;
      // Do not increment Depth here; that can cause an infinite loop.
      TLO.DAG.computeKnownBits(Op.getOperand(0), LHSZero, LHSOne, Depth);
      // If the LHS already has zeros where RHSC does, this and is dead.
      if ((LHSZero & NewMask) == (~RHSC->getAPIntValue() & NewMask))
        return TLO.CombineTo(Op, Op.getOperand(0));
      // If any of the set bits in the RHS are known zero on the LHS, shrink
      // the constant.
      if (TLO.ShrinkDemandedConstant(Op, ~LHSZero & NewMask))
        return true;
    }

    // Simplify the RHS first (results in KnownZero/KnownOne), then the LHS
    // (results in KnownZero2/KnownOne2).  Bits the RHS is known to clear are
    // not demanded from the LHS.
    if (SimplifyDemandedBits(Op.getOperand(1), NewMask, KnownZero,
                             KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    if (SimplifyDemandedBits(Op.getOperand(0), ~KnownZero & NewMask,
                             KnownZero2, KnownOne2, TLO, Depth+1))
      return true;
    assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");

    // If all of the demanded bits are known one on one side, return the other.
    // These bits cannot contribute to the result of the 'and'.
    if ((NewMask & ~KnownZero2 & KnownOne) == (~KnownZero2 & NewMask))
      return TLO.CombineTo(Op, Op.getOperand(0));
    if ((NewMask & ~KnownZero & KnownOne2) == (~KnownZero & NewMask))
      return TLO.CombineTo(Op, Op.getOperand(1));
    // If all of the demanded bits in the inputs are known zeros, return zero.
    if ((NewMask & (KnownZero|KnownZero2)) == NewMask)
      return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, Op.getValueType()));
    // If the RHS is a constant, see if we can simplify it.
    if (TLO.ShrinkDemandedConstant(Op, ~KnownZero2 & NewMask))
      return true;
    // If the operation can be done in a smaller type, do so.
    if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
      return true;

    // Output known-1 bits are only known if set in both the LHS & RHS.
    KnownOne &= KnownOne2;
    // Output known-0 are known to be clear if zero in either the LHS | RHS.
    KnownZero |= KnownZero2;
    break;
  case ISD::OR:
    // Same operand order as AND: RHS first, then LHS.  Bits the RHS is known
    // to set are not demanded from the LHS.
    if (SimplifyDemandedBits(Op.getOperand(1), NewMask, KnownZero,
                             KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    if (SimplifyDemandedBits(Op.getOperand(0), ~KnownOne & NewMask,
                             KnownZero2, KnownOne2, TLO, Depth+1))
      return true;
    assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'or'.
    if ((NewMask & ~KnownOne2 & KnownZero) == (~KnownOne2 & NewMask))
      return TLO.CombineTo(Op, Op.getOperand(0));
    if ((NewMask & ~KnownOne & KnownZero2) == (~KnownOne & NewMask))
      return TLO.CombineTo(Op, Op.getOperand(1));
    // If all of the potentially set bits on one side are known to be set on
    // the other side, just use the 'other' side.
    if ((NewMask & ~KnownZero & KnownOne2) == (~KnownZero & NewMask))
      return TLO.CombineTo(Op, Op.getOperand(0));
    if ((NewMask & ~KnownZero2 & KnownOne) == (~KnownZero2 & NewMask))
      return TLO.CombineTo(Op, Op.getOperand(1));
    // If the RHS is a constant, see if we can simplify it.
    if (TLO.ShrinkDemandedConstant(Op, NewMask))
      return true;
    // If the operation can be done in a smaller type, do so.
    if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
      return true;

    // Output known-0 bits are only known if clear in both the LHS & RHS.
    KnownZero &= KnownZero2;
    // Output known-1 are known to be set if set in either the LHS | RHS.
    KnownOne |= KnownOne2;
    break;
  case ISD::XOR:
    // Both operands are simplified with the full demanded mask: unlike AND/OR,
    // knowledge about one side does not reduce what is demanded of the other.
    if (SimplifyDemandedBits(Op.getOperand(1), NewMask, KnownZero,
                             KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    if (SimplifyDemandedBits(Op.getOperand(0), NewMask, KnownZero2,
                             KnownOne2, TLO, Depth+1))
      return true;
    assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'xor'.
    if ((KnownZero & NewMask) == NewMask)
      return TLO.CombineTo(Op, Op.getOperand(0));
    if ((KnownZero2 & NewMask) == NewMask)
      return TLO.CombineTo(Op, Op.getOperand(1));
    // If the operation can be done in a smaller type, do so.
    if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
      return true;

    // If all of the unknown bits are known to be zero on one side or the other
    // (but not both) turn this into an *inclusive* or.
    //    e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
    if ((NewMask & ~KnownZero & ~KnownZero2) == 0)
      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, Op.getValueType(),
                                               Op.getOperand(0),
                                               Op.getOperand(1)));

    // The result bits are computed into temporaries here because the RHS
    // values in KnownZero/KnownOne are still consulted by the checks below;
    // they are committed at the end of this case.
    // Output known-0 bits are known if clear or set in both the LHS & RHS.
    KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2);
    // Output known-1 are known to be set if set in only one of the LHS, RHS.
    KnownOneOut = (KnownZero & KnownOne2) | (KnownOne & KnownZero2);

    // If all of the demanded bits on one side are known, and all of the set
    // bits on that side are also known to be set on the other side, turn this
    // into an AND, as we know the bits will be cleared.
    //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
    // NB: it is okay if more bits are known than are requested
    if ((NewMask & (KnownZero|KnownOne)) == NewMask) { // all known on one side
      if (KnownOne == KnownOne2) { // set bits are the same on both sides
        EVT VT = Op.getValueType();
        SDValue ANDC = TLO.DAG.getConstant(~KnownOne & NewMask, dl, VT);
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT,
                                                 Op.getOperand(0), ANDC));
      }
    }

    // If the RHS is a constant, see if we can simplify it.
    // for XOR, we prefer to force bits to 1 if they will make a -1.
    // if we can't force bits, try to shrink constant
    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      // Set every non-demanded bit of the constant; if that yields all ones
      // the xor can use -1 as its constant operand.
      APInt Expanded = C->getAPIntValue() | (~NewMask);
      // if we can expand it to have all bits set, do it
      if (Expanded.isAllOnesValue()) {
        if (Expanded != C->getAPIntValue()) {
          EVT VT = Op.getValueType();
          SDValue New = TLO.DAG.getNode(Op.getOpcode(), dl,VT, Op.getOperand(0),
                                        TLO.DAG.getConstant(Expanded, dl, VT));
          return TLO.CombineTo(Op, New);
        }
        // if it already has all the bits set, nothing to change
        // but don't shrink either!
      } else if (TLO.ShrinkDemandedConstant(Op, NewMask)) {
        return true;
      }
    }

    KnownZero = KnownZeroOut;
    KnownOne  = KnownOneOut;
    break;
  case ISD::SELECT:
    // Only the two selected values (operands 1 and 2) are simplified; operand
    // 0 is not visited here.
    if (SimplifyDemandedBits(Op.getOperand(2), NewMask, KnownZero,
                             KnownOne, TLO, Depth+1))
      return true;
    if (SimplifyDemandedBits(Op.getOperand(1), NewMask, KnownZero2,
                             KnownOne2, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");

    // If the operands are constants, see if we can simplify them.
    if (TLO.ShrinkDemandedConstant(Op, NewMask))
      return true;

    // Only known if known in both the LHS and RHS.
    KnownOne &= KnownOne2;
    KnownZero &= KnownZero2;
    break;
  case ISD::SELECT_CC:
    // Same as SELECT, but the two selected values are operands 2 and 3;
    // operands 0 and 1 are not visited here.
    if (SimplifyDemandedBits(Op.getOperand(3), NewMask, KnownZero,
                             KnownOne, TLO, Depth+1))
      return true;
    if (SimplifyDemandedBits(Op.getOperand(2), NewMask, KnownZero2,
                             KnownOne2, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");

    // If the operands are constants, see if we can simplify them.
    if (TLO.ShrinkDemandedConstant(Op, NewMask))
      return true;

    // Only known if known in both the LHS and RHS.
    KnownOne &= KnownOne2;
    KnownZero &= KnownZero2;
    break;
  case ISD::SHL:
    // Only shifts by a constant amount are handled; for a variable amount
    // KnownZero/KnownOne stay at "nothing known".
    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      unsigned ShAmt = SA->getZExtValue();
      SDValue InOp = Op.getOperand(0);

      // If the shift count is an invalid immediate, don't do anything.
      if (ShAmt >= BitWidth)
        break;

      // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
      // single shift.  We can do this if the bottom bits (which are shifted
      // out) are never demanded.
      if (InOp.getOpcode() == ISD::SRL &&
          isa<ConstantSDNode>(InOp.getOperand(1))) {
        if (ShAmt && (NewMask & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) {
          unsigned C1= cast<ConstantSDNode>(InOp.getOperand(1))->getZExtValue();
          unsigned Opc = ISD::SHL;
          // The sign of the difference selects the direction of the single
          // remaining shift; its magnitude is the new shift amount.
          int Diff = ShAmt-C1;
          if (Diff < 0) {
            Diff = -Diff;
            Opc = ISD::SRL;
          }

          SDValue NewSA =
            TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType());
          EVT VT = Op.getValueType();
          return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT,
                                                   InOp.getOperand(0), NewSA));
        }
      }

      // Result bit i comes from input bit i-ShAmt, so the bits demanded of the
      // input are the demanded mask shifted right by ShAmt.
      if (SimplifyDemandedBits(InOp, NewMask.lshr(ShAmt),
                               KnownZero, KnownOne, TLO, Depth+1))
        return true;

      // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
      // are not demanded. This will likely allow the anyext to be folded away.
      if (InOp.getNode()->getOpcode() == ISD::ANY_EXTEND) {
        SDValue InnerOp = InOp.getNode()->getOperand(0);
        EVT InnerVT = InnerOp.getValueType();
        unsigned InnerBits = InnerVT.getSizeInBits();
        if (ShAmt < InnerBits && NewMask.lshr(InnerBits) == 0 &&
            isTypeDesirableForOp(ISD::SHL, InnerVT)) {
          EVT ShTy = getShiftAmountTy(InnerVT, DL);
          // Fall back to the narrow value type for the shift amount if ShAmt
          // does not fit in the target's shift-amount type.
          if (!APInt(BitWidth, ShAmt).isIntN(ShTy.getSizeInBits()))
            ShTy = InnerVT;
          SDValue NarrowShl =
            TLO.DAG.getNode(ISD::SHL, dl, InnerVT, InnerOp,
                            TLO.DAG.getConstant(ShAmt, dl, ShTy));
          return
            TLO.CombineTo(Op,
                          TLO.DAG.getNode(ISD::ANY_EXTEND, dl, Op.getValueType(),
                                          NarrowShl));
        }
        // Repeat the SHL optimization above in cases where an extension
        // intervenes: (shl (anyext (shr x, c1)), c2) to
        // (shl (anyext x), c2-c1).  This requires that the bottom c1 bits
        // aren't demanded (as above) and that the shifted upper c1 bits of
        // x aren't demanded.
        if (InOp.hasOneUse() &&
            InnerOp.getOpcode() == ISD::SRL &&
            InnerOp.hasOneUse() &&
            isa<ConstantSDNode>(InnerOp.getOperand(1))) {
          uint64_t InnerShAmt = cast<ConstantSDNode>(InnerOp.getOperand(1))
            ->getZExtValue();
          if (InnerShAmt < ShAmt &&
              InnerShAmt < InnerBits &&
              NewMask.lshr(InnerBits - InnerShAmt + ShAmt) == 0 &&
              NewMask.trunc(ShAmt) == 0) {
            SDValue NewSA =
              TLO.DAG.getConstant(ShAmt - InnerShAmt, dl,
                                  Op.getOperand(1).getValueType());
            EVT VT = Op.getValueType();
            SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
                                             InnerOp.getOperand(0));
            return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT,
                                                     NewExt, NewSA));
          }
        }
      }

      // No rewrite happened: translate the input's known bits to the result
      // by shifting them left by the shift amount.
      KnownZero <<= SA->getZExtValue();
      KnownOne  <<= SA->getZExtValue();
      // low bits known zero.
      KnownZero |= APInt::getLowBitsSet(BitWidth, SA->getZExtValue());
    }
    break;
  case ISD::SRL:
    // Only shifts by a constant amount are handled; for a variable amount
    // KnownZero/KnownOne stay at "nothing known".
    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      EVT VT = Op.getValueType();
      unsigned ShAmt = SA->getZExtValue();
      unsigned VTSize = VT.getSizeInBits();
      SDValue InOp = Op.getOperand(0);

      // If the shift count is an invalid immediate, don't do anything.
      if (ShAmt >= BitWidth)
        break;

      // Result bit i comes from input bit i+ShAmt.
      APInt InDemandedMask = (NewMask << ShAmt);

      // If the shift is exact, then it does demand the low bits (and knows that
      // they are zero).
      if (cast<BinaryWithFlagsSDNode>(Op)->Flags.hasExact())
        InDemandedMask |= APInt::getLowBitsSet(BitWidth, ShAmt);

      // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
      // single shift.  We can do this if the top bits (which are shifted out)
      // are never demanded.
      if (InOp.getOpcode() == ISD::SHL &&
          isa<ConstantSDNode>(InOp.getOperand(1))) {
        if (ShAmt && (NewMask & APInt::getHighBitsSet(VTSize, ShAmt)) == 0) {
          unsigned C1= cast<ConstantSDNode>(InOp.getOperand(1))->getZExtValue();
          unsigned Opc = ISD::SRL;
          // The sign of the difference selects the direction of the single
          // remaining shift; its magnitude is the new shift amount.
          int Diff = ShAmt-C1;
          if (Diff < 0) {
            Diff = -Diff;
            Opc = ISD::SHL;
          }

          SDValue NewSA =
            TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType());
          return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT,
                                                   InOp.getOperand(0), NewSA));
        }
      }

      // Compute the new bits that are at the top now.
      if (SimplifyDemandedBits(InOp, InDemandedMask,
                               KnownZero, KnownOne, TLO, Depth+1))
        return true;
      assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
      KnownZero = KnownZero.lshr(ShAmt);
      KnownOne  = KnownOne.lshr(ShAmt);

      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
      KnownZero |= HighBits;  // High bits known zero.
    }
    break;
  case ISD::SRA:
    // If this is an arithmetic shift right and only the low-bit is set, we can
    // always convert this into a logical shr, even if the shift amount is
    // variable.  The low bit of the shift cannot be an input sign bit unless
    // the shift amount is >= the size of the datatype, which is undefined.
    if (NewMask == 1)
      return TLO.CombineTo(Op,
                           TLO.DAG.getNode(ISD::SRL, dl, Op.getValueType(),
                                           Op.getOperand(0), Op.getOperand(1)));

    // The remaining handling requires a constant shift amount.
    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
      EVT VT = Op.getValueType();
      unsigned ShAmt = SA->getZExtValue();

      // If the shift count is an invalid immediate, don't do anything.
      if (ShAmt >= BitWidth)
        break;

      // Result bit i comes from input bit i+ShAmt.
      APInt InDemandedMask = (NewMask << ShAmt);

      // If the shift is exact, then it does demand the low bits (and knows that
      // they are zero).
      if (cast<BinaryWithFlagsSDNode>(Op)->Flags.hasExact())
        InDemandedMask |= APInt::getLowBitsSet(BitWidth, ShAmt);

      // If any of the demanded bits are produced by the sign extension, we also
      // demand the input sign bit.
      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
      if (HighBits.intersects(NewMask))
        InDemandedMask |= APInt::getSignBit(VT.getScalarType().getSizeInBits());

      if (SimplifyDemandedBits(Op.getOperand(0), InDemandedMask,
                               KnownZero, KnownOne, TLO, Depth+1))
        return true;
      assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
      KnownZero = KnownZero.lshr(ShAmt);
      KnownOne  = KnownOne.lshr(ShAmt);

      // Handle the sign bit, adjusted to where it is now in the mask.
      APInt SignBit = APInt::getSignBit(BitWidth).lshr(ShAmt);

      // If the input sign bit is known to be zero, or if none of the top bits
      // are demanded, turn this into an unsigned shift right.
      // The exact flag of the original shift is carried over to the new node.
      if (KnownZero.intersects(SignBit) || (HighBits & ~NewMask) == HighBits) {
        SDNodeFlags Flags;
        Flags.setExact(cast<BinaryWithFlagsSDNode>(Op)->Flags.hasExact());
        return TLO.CombineTo(Op,
                             TLO.DAG.getNode(ISD::SRL, dl, VT, Op.getOperand(0),
                                             Op.getOperand(1), &Flags));
      }

      // exactLogBase2() is non-negative only when exactly one bit is demanded.
      // Since the SRL rewrite above did not fire, that bit lies within
      // HighBits, i.e. it is a copy of the input sign bit; a logical shift by
      // BitWidth-1-Log2 moves the sign bit to that position.
      int Log2 = NewMask.exactLogBase2();
      if (Log2 >= 0) {
        // The bit must come from the sign.
        SDValue NewSA =
          TLO.DAG.getConstant(BitWidth - 1 - Log2, dl,
                              Op.getOperand(1).getValueType());
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT,
                                                 Op.getOperand(0), NewSA));
      }

      if (KnownOne.intersects(SignBit))
        // New bits are known one.
        KnownOne |= HighBits;
    }
    break;
  case ISD::SIGN_EXTEND_INREG: {
    // ExVT is the type the value is sign-extended from (carried in operand 1).
    EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();

    APInt MsbMask = APInt::getHighBitsSet(BitWidth, 1);
    // If we only care about the highest bit, don't bother shifting right.
    if (MsbMask == NewMask) {
      unsigned ShAmt = ExVT.getScalarType().getSizeInBits();
      SDValue InOp = Op.getOperand(0);
      unsigned VTBits = Op->getValueType(0).getScalarType().getSizeInBits();
      bool AlreadySignExtended =
        TLO.DAG.ComputeNumSignBits(InOp) >= VTBits-ShAmt+1;
      // However if the input is already sign extended we expect the sign
      // extension to be dropped altogether later and do not simplify.
      if (!AlreadySignExtended) {
        // Compute the correct shift amount type, which must be getShiftAmountTy
        // for scalar types after legalization.
        EVT ShiftAmtTy = Op.getValueType();
        if (TLO.LegalTypes() && !ShiftAmtTy.isVector())
          ShiftAmtTy = getShiftAmountTy(ShiftAmtTy, DL);

        // Shifting left by (BitWidth - width of ExVT) places the sign bit of
        // the narrow value in the top bit of the result.
        SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ShAmt, dl,
                                               ShiftAmtTy);
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl,
                                                 Op.getValueType(), InOp,
                                                 ShiftAmt));
      }
    }

    // Sign extension.  Compute the demanded bits in the result that are not
    // present in the input.
    APInt NewBits =
      APInt::getHighBitsSet(BitWidth,
                            BitWidth - ExVT.getScalarType().getSizeInBits());

    // If none of the extended bits are demanded, eliminate the sextinreg.
    if ((NewBits & NewMask) == 0)
      return TLO.CombineTo(Op, Op.getOperand(0));

    // InSignBit: the sign bit of the narrow type, positioned within the full
    // width.  InputDemandedBits: the demanded bits that lie inside the narrow
    // type.
    APInt InSignBit =
      APInt::getSignBit(ExVT.getScalarType().getSizeInBits()).zext(BitWidth);
    APInt InputDemandedBits =
      APInt::getLowBitsSet(BitWidth,
                           ExVT.getScalarType().getSizeInBits()) &
      NewMask;

    // Since the sign extended bits are demanded, we know that the sign
    // bit is demanded.
    InputDemandedBits |= InSignBit;

    if (SimplifyDemandedBits(Op.getOperand(0), InputDemandedBits,
                             KnownZero, KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");

    // If the sign bit of the input is known set or clear, then we know the
    // top bits of the result.

    // If the input sign bit is known zero, convert this into a zero extension.
    if (KnownZero.intersects(InSignBit))
      return TLO.CombineTo(Op,
                          TLO.DAG.getZeroExtendInReg(Op.getOperand(0),dl,ExVT));

    // Whatever the operand reported for the extended bit positions is
    // discarded: those bits of the result are either all ones or unknown.
    if (KnownOne.intersects(InSignBit)) {    // Input sign bit known set
      KnownOne |= NewBits;
      KnownZero &= ~NewBits;
    } else {                       // Input sign bit unknown
      KnownZero &= ~NewBits;
      KnownOne &= ~NewBits;
    }
    break;
  }
  case ISD::BUILD_PAIR: {
    // Operand 0 supplies the low half of the result and operand 1 the high
    // half; each is simplified against its half of the demanded mask.
    EVT HalfVT = Op.getOperand(0).getValueType();
    unsigned HalfBitWidth = HalfVT.getScalarSizeInBits();

    APInt MaskLo = NewMask.getLoBits(HalfBitWidth).trunc(HalfBitWidth);
    APInt MaskHi = NewMask.getHiBits(HalfBitWidth).trunc(HalfBitWidth);

    APInt KnownZeroLo, KnownOneLo;
    APInt KnownZeroHi, KnownOneHi;

    if (SimplifyDemandedBits(Op.getOperand(0), MaskLo, KnownZeroLo,
                             KnownOneLo, TLO, Depth + 1))
      return true;

    if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownZeroHi,
                             KnownOneHi, TLO, Depth + 1))
      return true;

    // Reassemble the full-width known bits from the two halves.
    KnownZero = KnownZeroLo.zext(BitWidth) |
                KnownZeroHi.zext(BitWidth).shl(HalfBitWidth);

    KnownOne = KnownOneLo.zext(BitWidth) |
               KnownOneHi.zext(BitWidth).shl(HalfBitWidth);
    break;
  }
  case ISD::ZERO_EXTEND: {
    unsigned OperandBitWidth =
      Op.getOperand(0).getValueType().getScalarType().getSizeInBits();
    APInt InMask = NewMask.trunc(OperandBitWidth);

    // If none of the top bits are demanded, convert this into an any_extend.
    // NewBits is already restricted to NewMask, so the intersects() test below
    // is simply asking whether NewBits is non-zero.
    APInt NewBits =
      APInt::getHighBitsSet(BitWidth, BitWidth - OperandBitWidth) & NewMask;
    if (!NewBits.intersects(NewMask))
      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl,
                                               Op.getValueType(),
                                               Op.getOperand(0)));

    if (SimplifyDemandedBits(Op.getOperand(0), InMask,
                             KnownZero, KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    KnownZero = KnownZero.zext(BitWidth);
    KnownOne = KnownOne.zext(BitWidth);
    // Only the demanded extended bits are recorded as known zero.
    KnownZero |= NewBits;
    break;
  }
  case ISD::SIGN_EXTEND: {
    // InMask covers the bits present in the operand, InSignBit is the
    // operand's sign bit (both expressed at the result width), and NewBits are
    // the demanded bits produced by the extension.
    EVT InVT = Op.getOperand(0).getValueType();
    unsigned InBits = InVT.getScalarType().getSizeInBits();
    APInt InMask    = APInt::getLowBitsSet(BitWidth, InBits);
    APInt InSignBit = APInt::getBitsSet(BitWidth, InBits - 1, InBits);
    APInt NewBits   = ~InMask & NewMask;

    // If none of the top bits are demanded, convert this into an any_extend.
    if (NewBits == 0)
      return TLO.CombineTo(Op,TLO.DAG.getNode(ISD::ANY_EXTEND, dl,
                                              Op.getValueType(),
                                              Op.getOperand(0)));

    // Since some of the sign extended bits are demanded, we know that the sign
    // bit is demanded.
    APInt InDemandedBits = InMask & NewMask;
    InDemandedBits |= InSignBit;
    InDemandedBits = InDemandedBits.trunc(InBits);

    if (SimplifyDemandedBits(Op.getOperand(0), InDemandedBits, KnownZero,
                             KnownOne, TLO, Depth+1))
      return true;
    // Widen the operand's known bits to the result width; the positions above
    // InBits are zero in both after this, which the asserts below rely on.
    KnownZero = KnownZero.zext(BitWidth);
    KnownOne = KnownOne.zext(BitWidth);

    // If the sign bit is known zero, convert this to a zero extend.
    if (KnownZero.intersects(InSignBit))
      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl,
                                               Op.getValueType(),
                                               Op.getOperand(0)));

    // If the sign bit is known one, the top bits match.
    if (KnownOne.intersects(InSignBit)) {
      KnownOne |= NewBits;
      assert((KnownZero & NewBits) == 0);
    } else {   // Otherwise, top bits aren't known.
      assert((KnownOne & NewBits) == 0);
      assert((KnownZero & NewBits) == 0);
    }
    break;
  }
  case ISD::ANY_EXTEND: {
    // The extended bits are unspecified, so only the low part of the demanded
    // mask is passed down and nothing is recorded for the high bits.
    unsigned OperandBitWidth =
      Op.getOperand(0).getValueType().getScalarType().getSizeInBits();
    APInt InMask = NewMask.trunc(OperandBitWidth);
    if (SimplifyDemandedBits(Op.getOperand(0), InMask,
                             KnownZero, KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    KnownZero = KnownZero.zext(BitWidth);
    KnownOne = KnownOne.zext(BitWidth);
    break;
  }
  case ISD::TRUNCATE: {
    // Simplify the input, using demanded bit information, and compute the known
    // zero/one bits live out.
    unsigned OperandBitWidth =
      Op.getOperand(0).getValueType().getScalarType().getSizeInBits();
    // None of the operand bits above the truncated width are demanded.
    APInt TruncMask = NewMask.zext(OperandBitWidth);
    if (SimplifyDemandedBits(Op.getOperand(0), TruncMask,
                             KnownZero, KnownOne, TLO, Depth+1))
      return true;
    KnownZero = KnownZero.trunc(BitWidth);
    KnownOne = KnownOne.trunc(BitWidth);

    // If the input is only used by this truncate, see if we can shrink it based
    // on the known demanded bits.
    if (Op.getOperand(0).getNode()->hasOneUse()) {
      SDValue In = Op.getOperand(0);
      // Only an SRL input is handled; every other opcode falls to `default`.
      switch (In.getOpcode()) {
      default: break;
      case ISD::SRL:
        // Shrink SRL by a constant if none of the high bits shifted in are
        // demanded.
        if (TLO.LegalTypes() &&
            !isTypeDesirableForOp(ISD::SRL, Op.getValueType()))
          // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
          // undesirable.
          break;
        ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(In.getOperand(1));
        if (!ShAmt)
          break;
        SDValue Shift = In.getOperand(1);
        // Once types are legal, rebuild the shift amount in the shift-amount
        // type that goes with the narrow result type.
        if (TLO.LegalTypes()) {
          uint64_t ShVal = ShAmt->getZExtValue();
          Shift = TLO.DAG.getConstant(ShVal, dl,
                                      getShiftAmountTy(Op.getValueType(), DL));
        }

        // Start from the operand bits that the truncate drops, move them down
        // by the shift amount and keep what lands inside the narrow width:
        // these are the result bits that the wide shift fills from above the
        // truncated width.
        APInt HighBits = APInt::getHighBitsSet(OperandBitWidth,
                                               OperandBitWidth - BitWidth);
        HighBits = HighBits.lshr(ShAmt->getZExtValue()).trunc(BitWidth);

        if (ShAmt->getZExtValue() < BitWidth && !(HighBits & NewMask)) {
          // None of the shifted in bits are needed.  Add a truncate of the
          // shift input, then shift it.
          SDValue NewTrunc = TLO.DAG.getNode(ISD::TRUNCATE, dl,
                                             Op.getValueType(),
                                             In.getOperand(0));
          return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl,
                                                   Op.getValueType(),
                                                   NewTrunc,
                                                   Shift));
        }
        break;
      }
    }

    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
    break;
  }
  case ISD::AssertZext: {
    // AssertZext demands all of the high bits, plus any of the low bits
    // demanded by its users.
    // VT (carried in operand 1) is the type the value is asserted to be
    // zero-extended from; InMask covers the bits of that type.
    EVT VT = cast<VTSDNode>(Op.getOperand(1))->getVT();
    APInt InMask = APInt::getLowBitsSet(BitWidth,
                                        VT.getSizeInBits());
    if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | NewMask,
                             KnownZero, KnownOne, TLO, Depth+1))
      return true;
    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");

    // The demanded bits above the asserted type are recorded as known zero.
    KnownZero |= ~InMask & NewMask;
    break;
  }
  case ISD::BITCAST:
    // If this is an FP->Int bitcast and if the sign bit is the only
    // thing demanded, turn this into a FGETSIGN.
    // This is only attempted for scalar types and while
    // TLO.LegalOperations() is false.  When nothing is rewritten,
    // KnownZero/KnownOne stay at "nothing known" for this case.
    if (!TLO.LegalOperations() &&
        !Op.getValueType().isVector() &&
        !Op.getOperand(0).getValueType().isVector() &&
        NewMask == APInt::getSignBit(Op.getValueType().getSizeInBits()) &&
        Op.getOperand(0).getValueType().isFloatingPoint()) {
      bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, Op.getValueType());
      bool i32Legal  = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32);
      if ((OpVTLegal || i32Legal) && Op.getValueType().isSimple()) {
        // Prefer FGETSIGN in the result type; otherwise produce it as i32.
        EVT Ty = OpVTLegal ? Op.getValueType() : MVT::i32;
        // Make a FGETSIGN + SHL to move the sign bit into the appropriate
        // place.  We expect the SHL to be eliminated by other optimizations.
        SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Op.getOperand(0));
        unsigned OpVTSizeInBits = Op.getValueType().getSizeInBits();
        // An i32 FGETSIGN result is widened when the result type is wider
        // than 32 bits.
        // NOTE(review): a result type narrower than 32 bits is not adjusted
        // here when the i32 form is used -- confirm that cannot occur.
        if (!OpVTLegal && OpVTSizeInBits > 32)
          Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, Op.getValueType(), Sign);
        unsigned ShVal = Op.getValueType().getSizeInBits()-1;
        SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, Op.getValueType());
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl,
                                                 Op.getValueType(),
                                                 Sign, ShAmt));
      }
    }
    break;
  case ISD::ADD:
  case ISD::MUL:
  case ISD::SUB: {
    // Add, Sub, and Mul don't demand any bits in positions beyond that
    // of the highest bit demanded of them.
    APInt LoMask = APInt::getLowBitsSet(BitWidth,
                                        BitWidth - NewMask.countLeadingZeros());
    // Both operands are simplified against LoMask.  The known bits they report
    // both land in KnownZero2/KnownOne2 and are not used afterwards; the
    // result's known bits come from computeKnownBits in the default case.
    if (SimplifyDemandedBits(Op.getOperand(0), LoMask, KnownZero2,
                             KnownOne2, TLO, Depth+1))
      return true;
    if (SimplifyDemandedBits(Op.getOperand(1), LoMask, KnownZero2,
                             KnownOne2, TLO, Depth+1))
      return true;
    // See if the operation should be performed at a smaller bit width.
    if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
      return true;
  }
  // FALL THROUGH
  default:
    // Just use computeKnownBits to compute output bits.
    TLO.DAG.computeKnownBits(Op, KnownZero, KnownOne, Depth);
    break;
  }

  // If we know the value of all of the demanded bits, return this as a
  // constant.
  if ((NewMask & (KnownZero|KnownOne)) == NewMask) {
    // Avoid folding to a constant if any OpaqueConstant is involved.
    // NOTE: the loop-local `Op` (an SDNode*) shadows the SDValue parameter
    // `Op` inside the loop body only; the CombineTo call after the loop uses
    // the parameter again.
    const SDNode *N = Op.getNode();
    for (SDNodeIterator I = SDNodeIterator::begin(N),
         E = SDNodeIterator::end(N); I != E; ++I) {
      SDNode *Op = *I;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op))
        if (C->isOpaque())
          return false;
    }
    return TLO.CombineTo(Op,
                         TLO.DAG.getConstant(KnownOne, dl, Op.getValueType()));
  }

  return false;
}