FST pack()

in oak-lucene/src/main/java/org/apache/lucene/util/fst/FST.java [1490:1836]


  /**
   * Returns a new, packed copy of this FST.
   *
   * <p>Packing works in two phases:
   * <ol>
   *   <li>Up to {@code maxDerefNodes} nodes having at least {@code minInCountDeref}
   *       incoming arcs are selected (highest incoming-arc counts preferred) and given
   *       small ids {@code 0..n-1}. Their final addresses are stored in the returned
   *       FST's {@code nodeRefToAddress}, so an arc pointing at one of them may encode
   *       that small id instead of an address.</li>
   *   <li>All nodes are rewritten in reverse node order. An arc's encoded target
   *       depends on node addresses from the previous pass, so the pass is repeated
   *       until no node's address changes.</li>
   * </ol>
   *
   * @param minInCountDeref minimum number of incoming arcs a node needs before it is
   *     considered for dereferencing
   * @param maxDerefNodes maximum number of nodes to dereference; {@code 0} disables
   *     dereferencing (intentionally an {@code int}: more than 2.1B deref nodes are
   *     not supported)
   * @param acceptableOverheadRatio memory/speed trade-off handed to the packed-ints
   *     structures that hold node addresses
   * @return the packed FST
   * @throws IllegalArgumentException if this FST was not built with
   *     {@code willPackFST=true}
   * @throws IOException presumably propagated from the underlying byte reader/writer
   *     calls
   */
  FST<T> pack(int minInCountDeref, int maxDerefNodes, float acceptableOverheadRatio) throws IOException {

    // NOTE: maxDerefNodes is intentionally int: we cannot
    // support > 2.1B deref nodes

    // TODO: other things to try
    //   - renumber the nodes to get more next / better locality?
    //   - allow multiple input labels on an arc, so
    //     singular chain of inputs can take one arc (on
    //     wikipedia terms this could save another ~6%)
    //   - in the ord case, the output '1' is presumably
    //     very common (after NO_OUTPUT)... maybe use a bit
    //     for it..?
    //   - use spare bits in flags.... for top few labels /
    //     outputs / targets

    if (nodeAddress == null) {
      throw new IllegalArgumentException("this FST was not built with willPackFST=true");
    }

    Arc<T> arc = new Arc<T>();

    final BytesReader r = getBytesReader();

    final int topN = Math.min(maxDerefNodes, inCounts.size());

    // Find top nodes with highest number of incoming arcs:
    NodeQueue q = new NodeQueue(topN);

    // TODO: we could use more RAM efficient selection algo here...
    // Only run the selection when at least one node may be dereferenced: with
    // topN == 0 the "q.size() == topN" check below can never succeed after an
    // add(), so bottom would stay null and every qualifying node would be
    // add()ed to a queue sized for zero entries.
    if (topN > 0) {
      NodeAndInCount bottom = null;
      for(int node=0; node<inCounts.size(); node++) {
        if (inCounts.get(node) >= minInCountDeref) {
          if (bottom == null) {
            q.add(new NodeAndInCount(node, (int) inCounts.get(node)));
            if (q.size() == topN) {
              bottom = q.top();
            }
          } else if (inCounts.get(node) > bottom.count) {
            q.insertWithOverflow(new NodeAndInCount(node, (int) inCounts.get(node)));
          }
        }
      }
    }

    // Free up RAM:
    inCounts = null;

    // Maps old node ord -> deref id in [0, q.size()):
    final Map<Integer,Integer> topNodeMap = new HashMap<Integer,Integer>();
    for(int downTo=q.size()-1;downTo>=0;downTo--) {
      NodeAndInCount n = q.pop();
      topNodeMap.put(n.node, downTo);
      //System.out.println("map node=" + n.node + " inCount=" + n.count + " to newID=" + downTo);
    }

    // +1 because node ords start at 1 (0 is reserved as stop node):
    final GrowableWriter newNodeAddress = new GrowableWriter(
                       PackedInts.bitsRequired(this.bytes.getPosition()), (int) (1 + nodeCount), acceptableOverheadRatio);

    // Fill initial coarse guess:
    for(int node=1;node<=nodeCount;node++) {
      newNodeAddress.set(node, 1 + this.bytes.getPosition() - nodeAddress.get(node));
    }

    // Debug statistics (only referenced by commented-out printlns):
    int absCount;
    int deltaCount;
    int topCount;
    int nextCount;

    FST<T> fst;

    // Iterate until we converge:
    while(true) {

      //System.out.println("\nITER");
      boolean changed = false;

      // for assert:
      boolean negDelta = false;

      fst = new FST<T>(inputType, outputs, bytes.getBlockBits());

      final BytesStore writer = fst.bytes;

      // Skip 0 byte since 0 is reserved target:
      writer.writeByte((byte) 0);

      fst.arcWithOutputCount = 0;
      fst.nodeCount = 0;
      fst.arcCount = 0;

      absCount = deltaCount = topCount = nextCount = 0;

      int changedCount = 0;

      // Most recently observed shift between a node's actual address in this
      // pass and its address from the previous pass; added below to the
      // newNodeAddress value of arc targets as a correction.
      long addressError = 0;

      //int totWasted = 0;

      // Since we re-reverse the bytes, we now write the
      // nodes backwards, so that BIT_TARGET_NEXT is
      // unchanged:
      for(int node=(int)nodeCount;node>=1;node--) {
        fst.nodeCount++;
        final long address = writer.getPosition();

        //System.out.println("  node: " + node + " address=" + address);
        if (address != newNodeAddress.get(node)) {
          addressError = address - newNodeAddress.get(node);
          //System.out.println("    change: " + (address - newNodeAddress[node]));
          changed = true;
          newNodeAddress.set(node, address);
          changedCount++;
        }

        int nodeArcCount = 0;
        int bytesPerArc = 0;

        boolean retry = false;

        // for assert:
        boolean anyNegDelta = false;

        // Retry loop: possibly iterate more than once, if
        // this is an array'd node and bytesPerArc changes:
        while(true) { // retry writing this node

          //System.out.println("  cycle: retry");
          readFirstRealTargetArc(node, arc, r);

          final boolean useArcArray = arc.bytesPerArc != 0;
          if (useArcArray) {
            // Write false first arc:
            if (bytesPerArc == 0) {
              bytesPerArc = arc.bytesPerArc;
            }
            writer.writeByte(ARCS_AS_FIXED_ARRAY);
            writer.writeVInt(arc.numArcs);
            writer.writeVInt(bytesPerArc);
            //System.out.println("node " + node + ": " + arc.numArcs + " arcs");
          }

          int maxBytesPerArc = 0;
          //int wasted = 0;
          while(true) {  // iterate over all arcs for this node
            //System.out.println("    cycle next arc");

            final long arcStartPos = writer.getPosition();
            nodeArcCount++;

            byte flags = 0;

            if (arc.isLast()) {
              flags += BIT_LAST_ARC;
            }
            /*
            if (!useArcArray && nodeUpto < nodes.length-1 && arc.target == nodes[nodeUpto+1]) {
              flags += BIT_TARGET_NEXT;
            }
            */
            if (!useArcArray && node != 1 && arc.target == node-1) {
              flags += BIT_TARGET_NEXT;
              if (!retry) {
                nextCount++;
              }
            }
            if (arc.isFinal()) {
              flags += BIT_FINAL_ARC;
              if (arc.nextFinalOutput != NO_OUTPUT) {
                flags += BIT_ARC_HAS_FINAL_OUTPUT;
              }
            } else {
              assert arc.nextFinalOutput == NO_OUTPUT;
            }
            if (!targetHasArcs(arc)) {
              flags += BIT_STOP_NODE;
            }

            if (arc.output != NO_OUTPUT) {
              flags += BIT_ARC_HAS_OUTPUT;
            }

            final long absPtr;
            final boolean doWriteTarget = targetHasArcs(arc) && (flags & BIT_TARGET_NEXT) == 0;
            if (doWriteTarget) {

              // arc.target is not an int (it is cast with (int) wherever it
              // indexes newNodeAddress). topNodeMap is keyed by Integer, so the
              // lookup key must be narrowed to int: passing it unconverted boxes it
              // to a Long, which never equals an Integer key and silently disables
              // dereferencing.
              final Integer ptr = topNodeMap.get((int) arc.target);
              if (ptr != null) {
                // Target is one of the top nodes: encode its small deref id.
                absPtr = ptr;
              } else {
                // Absolute address, offset past the deref id range.
                absPtr = topNodeMap.size() + newNodeAddress.get((int) arc.target) + addressError;
              }

              // Estimated delta, computed before the flags byte, label and outputs
              // are written. The delta actually written below is measured from a
              // later position; the "- 2" presumably accounts for the flags byte
              // plus at least one label byte.
              long delta = newNodeAddress.get((int) arc.target) + addressError - writer.getPosition() - 2;
              if (delta < 0) {
                //System.out.println("neg: " + delta);
                anyNegDelta = true;
                delta = 0;
              }

              // Choose the numerically smaller encoding: delta vs. deref id / absolute.
              if (delta < absPtr) {
                flags |= BIT_TARGET_DELTA;
              }
            } else {
              absPtr = 0;
            }

            assert flags != ARCS_AS_FIXED_ARRAY;
            writer.writeByte(flags);

            fst.writeLabel(writer, arc.label);

            if (arc.output != NO_OUTPUT) {
              outputs.write(arc.output, writer);
              if (!retry) {
                fst.arcWithOutputCount++;
              }
            }
            if (arc.nextFinalOutput != NO_OUTPUT) {
              outputs.writeFinalOutput(arc.nextFinalOutput, writer);
            }

            if (doWriteTarget) {

              long delta = newNodeAddress.get((int) arc.target) + addressError - writer.getPosition();
              if (delta < 0) {
                anyNegDelta = true;
                //System.out.println("neg: " + delta);
                delta = 0;
              }

              if (flag(flags, BIT_TARGET_DELTA)) {
                //System.out.println("        delta");
                writer.writeVLong(delta);
                if (!retry) {
                  deltaCount++;
                }
              } else {
                /*
                if (ptr != null) {
                  System.out.println("        deref");
                } else {
                  System.out.println("        abs");
                }
                */
                writer.writeVLong(absPtr);
                if (!retry) {
                  if (absPtr >= topNodeMap.size()) {
                    absCount++;
                  } else {
                    topCount++;
                  }
                }
              }
            }

            if (useArcArray) {
              final int arcBytes = (int) (writer.getPosition() - arcStartPos);
              //System.out.println("  " + arcBytes + " bytes");
              maxBytesPerArc = Math.max(maxBytesPerArc, arcBytes);
              // NOTE: this may in fact go "backwards", if
              // somehow (rarely, possibly never) we use
              // more bytesPerArc in this rewrite than the
              // incoming FST did... but in this case we
              // will retry (below) so it's OK to overwrite
              // bytes:
              //wasted += bytesPerArc - arcBytes;
              writer.skipBytes((int) (arcStartPos + bytesPerArc - writer.getPosition()));
            }

            if (arc.isLast()) {
              break;
            }

            readNextRealArc(arc, r);
          }

          if (useArcArray) {
            if (maxBytesPerArc == bytesPerArc || (retry && maxBytesPerArc <= bytesPerArc)) {
              // converged
              //System.out.println("  bba=" + bytesPerArc + " wasted=" + wasted);
              //totWasted += wasted;
              break;
            }
          } else {
            break;
          }

          //System.out.println("  retry this node maxBytesPerArc=" + maxBytesPerArc + " vs " + bytesPerArc);

          // Retry: rewind to this node's start and rewrite it with the larger
          // per-arc size observed in this attempt.
          bytesPerArc = maxBytesPerArc;
          writer.truncate(address);
          nodeArcCount = 0;
          retry = true;
          anyNegDelta = false;
        }

        negDelta |= anyNegDelta;

        fst.arcCount += nodeArcCount;
      }

      if (!changed) {
        // We don't renumber the nodes (just reverse their
        // order) so nodes should only point forward to
        // other nodes because we only produce acyclic FSTs
        // w/ nodes only pointing "forwards":
        assert !negDelta;
        //System.out.println("TOT wasted=" + totWasted);
        // Converged!
        break;
      }
      //System.out.println("  " + changedCount + " of " + fst.nodeCount + " changed; retry");
    }

    // Build the deref table: deref id -> final address of that top node.
    long maxAddress = 0;
    for (int key : topNodeMap.keySet()) {
      maxAddress = Math.max(maxAddress, newNodeAddress.get(key));
    }

    PackedInts.Mutable nodeRefToAddressIn = PackedInts.getMutable(topNodeMap.size(),
        PackedInts.bitsRequired(maxAddress), acceptableOverheadRatio);
    for(Map.Entry<Integer,Integer> ent : topNodeMap.entrySet()) {
      nodeRefToAddressIn.set(ent.getValue(), newNodeAddress.get(ent.getKey()));
    }
    fst.nodeRefToAddress = nodeRefToAddressIn;

    fst.startNode = newNodeAddress.get((int) startNode);
    //System.out.println("new startNode=" + fst.startNode + " old startNode=" + startNode);

    if (emptyOutput != null) {
      fst.setEmptyOutput(emptyOutput);
    }

    assert fst.nodeCount == nodeCount: "fst.nodeCount=" + fst.nodeCount + " nodeCount=" + nodeCount;
    assert fst.arcCount == arcCount;
    assert fst.arcWithOutputCount == arcWithOutputCount: "fst.arcWithOutputCount=" + fst.arcWithOutputCount + " arcWithOutputCount=" + arcWithOutputCount;

    fst.bytes.finish();
    fst.cacheRootArcs();

    //final int size = fst.sizeInBytes();
    //System.out.println("nextCount=" + nextCount + " topCount=" + topCount + " deltaCount=" + deltaCount + " absCount=" + absCount);

    return fst;
  }