def torchvision()

in dags/pytorch_xla/r2_7.py [0:0]


def torchvision():
  """Creates the PyTorch/XLA 2.7 torchvision test tasks and chains them.

  TPU tests are built from JSonnet configs via
  `test_config.JSonnetTpuVmTest.from_pytorch` and scheduled with
  `task.run_queued_resource_test`. GPU tests are built via
  `test_config.GpuGkeTest.from_pytorch` and launched with
  `task.GpuGkeTask(...).run()`.

  Dependency wiring (via `>>`):
    * MNIST on v2-8 -> ResNet-50 on v2-8, all v4-8 variants, v4-32 and
      v5litepod-4.
    * ResNet-50 on v2-8 -> both v3-8 ResNet-50 variants.
    * ResNet-50 model-parallel on V100 -> ResNet-50 SPMD batch on V100.
  """

  def queued_tpu_test(test_name, zone, **config_kwargs):
    # Builds the TPU VM config for `test_name` (forwarding any extra keyword
    # options such as `reserved` or network settings) and schedules it.
    config = test_config.JSonnetTpuVmTest.from_pytorch(
        test_name, **config_kwargs
    )
    return task.run_queued_resource_test(config, zone)

  def gke_gpu_test(test_name):
    # Builds the GPU GKE config for `test_name` and runs it in US_CENTRAL1
    # with "gpu-uc1" as the third GpuGkeTask argument (presumably the
    # cluster name -- confirm against task.GpuGkeTask).
    config = test_config.GpuGkeTest.from_pytorch(test_name)
    return task.GpuGkeTask(config, US_CENTRAL1, "gpu-uc1").run()

  # --- TPU tests ---------------------------------------------------------
  mnist_v2_8 = queued_tpu_test(
      "pt-2-7-mnist-pjrt-func-v2-8-1vm", US_CENTRAL1_C
  )
  resnet_v2_8 = queued_tpu_test(
      "pt-2-7-resnet50-pjrt-fake-v2-8-1vm", US_CENTRAL1_C, reserved=True
  )
  v3_8_names = (
      "pt-2-7-resnet50-pjrt-fake-v3-8-1vm",
      "pt-2-7-resnet50-pjrt-ddp-fake-v3-8-1vm",
  )
  resnet_v3_8_tests = [
      queued_tpu_test(name, US_EAST1_D, reserved=True) for name in v3_8_names
  ]
  v4_8_names = (
      "pt-2-7-resnet50-pjrt-fake-v4-8-1vm",
      "pt-2-7-resnet50-pjrt-ddp-fake-v4-8-1vm",
      "pt-2-7-resnet50-spmd-batch-fake-v4-8-1vm",
      "pt-2-7-resnet50-spmd-spatial-fake-v4-8-1vm",
  )
  resnet_v4_8_tests = [
      queued_tpu_test(name, US_CENTRAL2_B) for name in v4_8_names
  ]
  resnet_v4_32 = queued_tpu_test(
      "pt-2-7-resnet50-pjrt-fake-v4-32-1vm", US_CENTRAL2_B
  )
  # v5litepod tests need the dedicated V5 network/subnetwork settings.
  resnet_v5lp_4 = queued_tpu_test(
      "pt-2-7-resnet50-pjrt-fake-v5litepod-4-1vm",
      US_EAST1_C,
      network=V5_NETWORKS,
      subnetwork=V5E_SUBNETWORKS,
      reserved=True,
  )

  # The MNIST smoke test runs first; the v3-8 tests wait on the v2-8 ResNet.
  after_mnist = (
      resnet_v2_8,
      *resnet_v4_8_tests,
      resnet_v4_32,
      resnet_v5lp_4,
  )
  mnist_v2_8 >> after_mnist
  resnet_v2_8 >> resnet_v3_8_tests

  # --- GPU tests ---------------------------------------------------------
  resnet_v100_2x2 = gke_gpu_test("pt-2-7-resnet50-mp-fake-v100-x2x2")
  resnet_v100_2x2_spmd = gke_gpu_test(
      "pt-2-7-resnet50-spmd-batch-fake-v100-x2x2"
  )
  resnet_v100_2x2 >> resnet_v100_2x2_spmd