def spreadRefundAcrossItems()

in src/main/scala/com/gu/invoicing/refund/Impl.scala [103:192]


  /** Splits a refund of `totalRefundAmount` into invoice item adjustment write requests.
    *
    * Invoice items are visited in list order. For each item, `availableAmount(item, adjustments)` gives the amount
    * that may still be refunded against it (`None` or a non-positive amount means the item is skipped). Each item
    * absorbs as much of the remaining refund as it can, until the refund is fully spread or the items run out.
    *
    * @param invoiceItems      candidate items, consumed in order
    * @param taxationItems     looked up (by `InvoiceItemId`) to find the taxation item id for the tax part of an item
    * @param adjustments       existing adjustments, passed through to `availableAmount`
    * @param totalRefundAmount amount to spread; a non-positive amount produces no adjustments
    * @param refundGuid        copied onto every adjustment produced
    * @return the adjustments to create. Adjustments of later-visited items come first; within one item the charge
    *         adjustment precedes the tax adjustment. If the items run out before the full amount is spread, the
    *         partial list is returned and a message is logged.
    * @throws RuntimeException if an item's tax portion would exceed its `TaxAmount`, or if no taxation item exists
    *                          for an item whose tax needs adjusting
    */
  def spreadRefundAcrossItems(
      invoiceItems: List[InvoiceItem],
      taxationItems: List[TaxationItem],
      adjustments: List[InvoiceItemAdjustment],
      totalRefundAmount: BigDecimal,
      refundGuid: String,
  ): List[InvoiceItemAdjustmentWrite] = {

    // Evaluated once so every adjustment belonging to this refund carries the same date,
    // even if the computation happens to straddle midnight.
    val adjustmentDate = LocalDate.now(ZoneId.of("Europe/London"))

    /** Builds the adjustment(s) that refund `amountToRefund` against a single invoice item. */
    def buildInvoiceItemAdjustments(
        invoiceItem: InvoiceItem,
        amountToRefund: BigDecimal,
    ): List[InvoiceItemAdjustmentWrite] = {
      // If the invoice item being adjusted has tax paid on it, it will need to be adjusted in two separate adjustments:
      // - one for the charge amount where the SourceType is "InvoiceDetail" and SourceId is the invoice item id
      // - one for the tax amount where the SourceType is "Tax" and SourceId is the taxation item id
      // https://www.zuora.com/developer/api-references/older-api/operation/Object_POSTInvoiceItemAdjustment/#!path=SourceType&t=request

      // The charge is refunded first; anything beyond the charge amount is taken from the tax.
      // NOTE(review): this caps by the item's full ChargeAmount, not by what is still unadjusted on the charge.
      // If `adjustments` can already contain partial charge adjustments for this item, confirm that
      // `availableAmount` keeps the charge/tax split consistent.
      val chargeAmountToRefund = invoiceItem.ChargeAmount.min(amountToRefund)
      val chargeAdjustment = InvoiceItemAdjustmentWrite(
        adjustmentDate,
        chargeAmountToRefund,
        refundGuid,
        invoiceItem.InvoiceId,
        "Credit",
        "InvoiceDetail",
        invoiceItem.Id,
      )

      val taxAmountToRefund = amountToRefund - chargeAmountToRefund
      if (taxAmountToRefund > invoiceItem.TaxAmount) {
        println(
          s"Unexpected state when trying to create InvoiceItem adjustment for $invoiceItem. " +
            s"Amount to refund was $amountToRefund, chargeAmountToRefund was $chargeAmountToRefund " +
            s"so taxAmountToRefund was $taxAmountToRefund but the tax on the invoice item was only ${invoiceItem.TaxAmount}",
        )
        throw new RuntimeException(s"Unexpected state when trying to create InvoiceItem adjustment for $invoiceItem")
      }

      val taxAdjustment =
        if (taxAmountToRefund > 0) {
          val taxationItemId = taxationItems
            .find(_.InvoiceItemId == invoiceItem.Id)
            .map(_.Id)
            .getOrElse(throw new RuntimeException(s"Missing taxation id for invoiceItem $invoiceItem"))

          List(
            InvoiceItemAdjustmentWrite(
              adjustmentDate,
              taxAmountToRefund,
              refundGuid,
              invoiceItem.InvoiceId,
              "Credit",
              "Tax",
              taxationItemId,
            ),
          )
        } else Nil

      chargeAdjustment :: taxAdjustment
    }

    @tailrec def loop(
        remainingAmountToRefund: BigDecimal,
        remainingItems: List[InvoiceItem],
        accumulatedAdjustments: List[InvoiceItemAdjustmentWrite],
    ): List[InvoiceItemAdjustmentWrite] =
      remainingItems match {
        case _ if remainingAmountToRefund <= 0 =>
          // Nothing (left) to refund: a zero- or negative-amount adjustment would be meaningless.
          accumulatedAdjustments
        case Nil =>
          // Items exhausted before the whole refund could be placed; keep the partial result but make it visible.
          println(
            s"Could not fully spread refund $refundGuid across invoice items: " +
              s"$remainingAmountToRefund of $totalRefundAmount remains unallocated",
          )
          accumulatedAdjustments
        case nextItem :: tail =>
          availableAmount(nextItem, adjustments) match {
            case Some(availableRefundableAmount) if availableRefundableAmount >= remainingAmountToRefund =>
              // This item can absorb everything that is left.
              buildInvoiceItemAdjustments(nextItem, remainingAmountToRefund) ++ accumulatedAdjustments
            case Some(availableRefundableAmount) if availableRefundableAmount > 0 =>
              // Exhaust this item and carry the rest over to the following items.
              loop(
                remainingAmountToRefund - availableRefundableAmount,
                tail,
                buildInvoiceItemAdjustments(nextItem, availableRefundableAmount) ++ accumulatedAdjustments,
              )
            case _ =>
              // No refundable amount on this item (None or non-positive): skip it.
              loop(remainingAmountToRefund, tail, accumulatedAdjustments)
          }
      }

    loop(totalRefundAmount, invoiceItems, List.empty[InvoiceItemAdjustmentWrite])
  }