export default function MNIST()

in react-native-pytorch-core/example/src/toolbox/models/MNIST.tsx [196:320]


/**
 * Screen that lets the user draw on a square canvas with their finger and
 * shows the labels for the first two entries of the inference `result`.
 *
 * The canvas is sized to the smaller of the container's width and height.
 * Touch-move events append points to a trail that is repainted at most once
 * per animation frame; on touch-end the canvas context is handed to
 * `classify` and the result labels are shown.
 */
export default function MNIST() {
  const [canvasSize, setCanvasSize] = useState<number>(0);

  // `ctx` is the drawing context used to draw shapes; it is provided by the
  // <Canvas> component through `onContext2D`.
  const [ctx, setCtx] = useState<CanvasRenderingContext2D>();

  const {classify, result} = useMNISTCanvasInference(canvasSize);

  // Points of the stroke being drawn. Kept in a ref so that adding a point
  // does not trigger a React re-render.
  const trailRef = useRef<TrailPoint[]>([]);
  const [drawingDone, setDrawingDone] = useState(false);
  // Handle of the animation frame that is currently scheduled, or `null`
  // when no repaint is pending.
  const animationHandleRef = useRef<number | null>(null);

  const draw = useCallback(() => {
    // Coalesce repaint requests: at most one frame is pending at a time, and
    // nothing can be painted before the canvas context is available.
    if (animationHandleRef.current != null || ctx == null) {
      return;
    }
    animationHandleRef.current = requestAnimationFrame(() => {
      // Release the handle first, so that an exception thrown by one of the
      // canvas calls below cannot block every future repaint.
      animationHandleRef.current = null;

      const trail = trailRef.current;

      // fill background by drawing a rect
      ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
      ctx.fillRect(0, 0, canvasSize, canvasSize);

      // Draw the trail
      ctx.strokeStyle = COLOR_TRAIL_STROKE;
      ctx.lineWidth = 25;
      ctx.lineJoin = 'round';
      ctx.lineCap = 'round';
      ctx.miterLimit = 1;

      if (trail.length > 0) {
        ctx.beginPath();
        ctx.moveTo(trail[0].x, trail[0].y);
        for (let i = 1; i < trail.length; i++) {
          ctx.lineTo(trail[i].x, trail[i].y);
        }
        // Stroke only when a fresh path was started above; otherwise the
        // context's previous path would be stroked again.
        ctx.stroke();
      }
      // Need to include this at the end, for now.
      ctx.invalidate();
    });
  }, [animationHandleRef, canvasSize, ctx, trailRef]);

  // handlers for touch events
  const handleMove = useCallback(
    event => {
      const position: TrailPoint = {
        x: event.nativeEvent.locationX,
        y: event.nativeEvent.locationY,
      };
      const trail = trailRef.current;
      if (trail.length > 0) {
        const lastPosition = trail[trail.length - 1];
        const dx = position.x - lastPosition.x;
        const dy = position.y - lastPosition.y;
        // add a point to trail if distance from last point > 5
        // (compared as squared distance: 5 * 5 = 25)
        if (dx * dx + dy * dy > 25) {
          trail.push(position);
        }
      } else {
        trail.push(position);
      }
      draw();
    },
    [trailRef, draw],
  );

  const handleStart = useCallback(() => {
    setDrawingDone(false);
    // Start a new stroke with an empty trail.
    trailRef.current = [];
  }, [trailRef, setDrawingDone]);

  const handleEnd = useCallback(() => {
    setDrawingDone(true);
    if (ctx != null) {
      classify(ctx, true);
    }
  }, [setDrawingDone, classify, ctx]);

  useEffect(() => {
    // Repaint whenever `draw` changes, i.e. when `ctx` or `canvasSize` change.
    draw();
    return () => {
      // Cancel a frame scheduled by this (now outdated) `draw` so it neither
      // paints with a stale context/size nor blocks the next `draw` call, and
      // so nothing runs after the component has unmounted.
      if (animationHandleRef.current != null) {
        cancelAnimationFrame(animationHandleRef.current);
        animationHandleRef.current = null;
      }
    };
  }, [draw]);

  // Guard each entry separately so that a result with fewer than two entries
  // does not crash the render.
  const firstGuess = result?.[0];
  const secondGuess = result?.[1];

  return (
    <View
      style={styles.container}
      onLayout={event => {
        const {layout} = event.nativeEvent;
        // Keep the canvas square: use the smaller of the two dimensions.
        setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
      }}>
      <View style={styles.instruction}>
        <Text style={styles.label}>Write a number</Text>
        <Text style={styles.label}>
          Let's see if the AI model will get it right
        </Text>
      </View>
      <Canvas
        style={{
          height: canvasSize,
          width: canvasSize,
        }}
        onContext2D={setCtx}
        onTouchMove={handleMove}
        onTouchStart={handleStart}
        onTouchEnd={handleEnd}
      />
      {drawingDone && (
        <View style={[styles.resultView]} pointerEvents="none">
          <Text style={[styles.label, styles.secondary]}>
            {firstGuess != null &&
              `${numLabels[firstGuess.num].asciiSymbol} it looks like ${
                numLabels[firstGuess.num].english
              }`}
          </Text>
          <Text style={[styles.label, styles.secondary]}>
            {secondGuess != null &&
              `${numLabels[secondGuess.num].asciiSymbol} or it might be ${
                numLabels[secondGuess.num].english
              }`}
          </Text>
        </View>
      )}
    </View>
  );
}