def fetch_imagenet_class()

in tcav/tcav_examples/image_models/imagenet/imagenet_and_broden_fetcher.py [0:0]


def fetch_imagenet_class(path, class_name, number_of_images, imagenet_dataframe):
  """Download up to `number_of_images` ImageNet images for one class.

  Images are saved under `os.path.join(path, class_name)`, which is created if
  it does not exist. URLs containing "flickr" are skipped, and any URL whose
  download raises is logged and skipped. As a result, fewer images than
  requested may be saved; a console message reports the outcome either way.

  Args:
    path: Directory under which the per-class folder is created.
    class_name: Class to fetch; also used as the name of the output folder.
    number_of_images: Maximum number of images to download. A value <= 0
      downloads nothing.
    imagenet_dataframe: Dataframe describing the ImageNet classes, passed to
      fetch_all_urls_for_concept (the error message suggests building it with
      make_imagenet_dataframe()).

  Raises:
    tf.errors.NotFoundError: If `imagenet_dataframe` is None.
  """
  if imagenet_dataframe is None:
    raise tf.errors.NotFoundError(
        None, None,
        "Please provide a dataframe containing the imagenet classes. Easiest way to do this is by calling make_imagenet_dataframe()"
    )
  # To speed up imagenet download, we timeout image downloads at 5 seconds.
  # socket.setdefaulttimeout is process-wide, so the previous value is restored
  # in the `finally` below to avoid leaking this setting to unrelated code.
  previous_timeout = socket.getdefaulttimeout()
  socket.setdefaulttimeout(5)
  try:
    tf.compat.v1.logging.info("Fetching imagenet data for " + class_name)
    concept_path = os.path.join(path, class_name)
    tf.io.gfile.makedirs(concept_path)
    tf.compat.v1.logging.info("Saving images at " + concept_path)

    # Check to see if this class name exists. Fetch all urls if so.
    all_images = fetch_all_urls_for_concept(imagenet_dataframe, class_name)

    # Fetch number_of_images images or as many as you can.
    num_downloaded = 0
    for image_url in all_images:
      # Check the quota *before* downloading so we never fetch more images than
      # requested (e.g. number_of_images=0 must download nothing).
      if num_downloaded >= number_of_images:
        break
      if "flickr" in image_url:
        continue
      try:
        download_image(concept_path, image_url)
      except Exception as e:  # Best effort: a bad URL must not abort the class.
        tf.compat.v1.logging.info("Problem downloading imagenet image. Exception was " +
                        str(e) + " for URL " + image_url)
      else:
        num_downloaded += 1
  finally:
    socket.setdefaulttimeout(previous_timeout)

  # If we reached the end, notify the user through the console.
  if num_downloaded < number_of_images:
    print("You requested " + str(number_of_images) +
          " but we were only able to find " +
          str(num_downloaded) +
          " good images from imageNet for concept " + class_name)
  else:
    # Report the actual count rather than the requested one.
    print("Downloaded " + str(num_downloaded) + " for " + class_name)