in tcav/tcav_examples/image_models/imagenet/imagenet_and_broden_fetcher.py [0:0]
def fetch_imagenet_class(path, class_name, number_of_images, imagenet_dataframe):
    """Download up to `number_of_images` ImageNet images for one class.

    Images are saved under `os.path.join(path, class_name)`, which is created
    if it does not exist. Candidate URLs come from
    `fetch_all_urls_for_concept(imagenet_dataframe, class_name)`. URLs that
    contain "flickr" are skipped. A URL whose download raises is logged and
    skipped rather than aborting the whole fetch.

    Args:
      path: Directory under which a per-class subdirectory is created.
      class_name: Name of the ImageNet class to fetch.
      number_of_images: Maximum number of images to download. Values <= 0
        download nothing.
      imagenet_dataframe: Dataframe describing the ImageNet classes, e.g. as
        produced by make_imagenet_dataframe().

    Raises:
      tf.errors.NotFoundError: If `imagenet_dataframe` is None.
    """
    if imagenet_dataframe is None:
        raise tf.errors.NotFoundError(
            None, None,
            "Please provide a dataframe containing the imagenet classes. Easiest way to do this is by calling make_imagenet_dataframe()"
        )
    # To speed up imagenet download, we time out socket operations at 5
    # seconds. The default timeout is process-wide, so remember the previous
    # value and restore it on exit instead of leaking it to unrelated code.
    previous_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(5)
    try:
        tf.compat.v1.logging.info("Fetching imagenet data for " + class_name)
        concept_path = os.path.join(path, class_name)
        tf.io.gfile.makedirs(concept_path)
        tf.compat.v1.logging.info("Saving images at " + concept_path)
        # Check to see if this class name exists. Fetch all urls if so.
        all_images = fetch_all_urls_for_concept(imagenet_dataframe, class_name)
        # Fetch number_of_images images or as many as we can.
        num_downloaded = 0
        for image_url in all_images:
            # Check the quota before attempting a download so that a request
            # for zero (or fewer) images never downloads anything.
            if num_downloaded >= number_of_images:
                break
            if "flickr" in image_url:
                continue
            try:
                download_image(concept_path, image_url)
            except Exception as e:
                # Best effort: a single bad URL should not stop the fetch.
                tf.compat.v1.logging.info(
                    "Problem downloading imagenet image. Exception was " +
                    str(e) + " for URL " + image_url)
            else:
                num_downloaded += 1
    finally:
        socket.setdefaulttimeout(previous_timeout)
    # If we reached the end, notify the user through the console.
    if num_downloaded < number_of_images:
        print("You requested " + str(number_of_images) +
              " but we were only able to find " +
              str(num_downloaded) +
              " good images from imageNet for concept " + class_name)
    else:
        print("Downloaded " + str(number_of_images) + " for " + class_name)