in #U57fa#U7840#U6559#U7a0b/A2-#U795e#U7ecf#U7f51#U7edc#U57fa#U672c#U539f#U7406/src/Data/ch18_color.py [0:0]
def _render_shape_set(draw_shape, count):
    """Render ``count`` images of one shape and return ``(images, labels)``.

    Each image is a 28x28 RGB canvas with a black background. For every
    image an integer in ``[0, 6)`` is drawn from the global NumPy RNG and
    passed as the second argument to ``draw_shape``; the same integer is
    stored as that image's label.
    NOTE(review): the integer is presumably a color-class index -- confirm
    against the drawing helpers.

    Args:
        draw_shape: Callable invoked as ``draw_shape(draw_obj, index)`` where
            ``draw_obj`` is the ``ImageDraw.Draw`` handle for the canvas.
        count: Number of images to render.

    Returns:
        Tuple ``(images, labels)`` with shapes ``(count, 3, 28, 28)`` and
        ``(count, 1)``, both in NumPy's default float dtype.
    """
    images = np.empty((count, 3, 28, 28))
    labels = np.empty((count, 1))
    indices = np.random.randint(0, 6, count)
    for i in range(count):
        img = Image.new("RGB", [28, 28], "black")
        draw_shape(ImageDraw.Draw(img), indices[i])
        # The image converts to (row, column, channel); store channel-first.
        images[i] = np.array(img).transpose(2, 0, 1)
    labels[:, 0] = indices
    return images, labels


def color():
    """Generate the shape image dataset and save it as two ``.npz`` files.

    For each of five shapes (triangle, circle, rectangle, line, diamond)
    1200 images are rendered. The first 1000 of each shape go to the training
    split (5000 samples) and the remaining 200 to the test split (1000
    samples). Each split is shuffled, keeping images and labels aligned, and
    written with ``np.savez`` to ``train_data_name`` / ``test_data_name``
    under the keys ``data`` and ``label``.

    Returns:
        None. The result is the two files written to disk.
    """
    per_shape = 1200
    train_per_shape = 1000

    # Shapes are rendered in this order; note that they are stacked in a
    # different order further down.
    triangle_set = _render_shape_set(triangle, per_shape)
    circle_set = _render_shape_set(circle, per_shape)
    rectangle_set = _render_shape_set(rectangle, per_shape)
    line_set = _render_shape_set(line, per_shape)
    diamond_set = _render_shape_set(diamond, per_shape)

    stacking_order = (
        circle_set,
        rectangle_set,
        triangle_set,
        diamond_set,
        line_set,
    )

    images_train = np.concatenate(
        [imgs[:train_per_shape] for imgs, _ in stacking_order]
    )
    label_train = np.concatenate(
        [lbls[:train_per_shape] for _, lbls in stacking_order]
    )
    images_test = np.concatenate(
        [imgs[train_per_shape:] for imgs, _ in stacking_order]
    )
    label_test = np.concatenate(
        [lbls[train_per_shape:] for _, lbls in stacking_order]
    )

    # Shuffle with one explicit permutation per split so images and labels
    # stay paired by construction, without re-seeding the global RNG.
    train_order = np.random.permutation(len(images_train))
    images_train = images_train[train_order]
    label_train = label_train[train_order]

    test_order = np.random.permutation(len(images_test))
    images_test = images_test[test_order]
    label_test = label_test[test_order]

    np.savez(train_data_name, data=images_train, label=label_train)
    np.savez(test_data_name, data=images_test, label=label_test)