Xây Dựng NN Model Phân Loại Với Tập Dữ Liệu MNIST - SuNT's Blog
Có thể bạn quan tâm
Bài này, ta sẽ nâng độ khó hơn 1 chút so với bài trước. Yêu cầu đề bài như sau:
- Xây dựng một NN model phân loại các hình ảnh trong tập dữ liệu MNIST
- Trong quá trình train, khi độ chính xác của model đạt đến 99% thì dừng train.
Mục đích của bài này là giúp ta làm quen với dữ liệu “thực tế” hơn 1 chút so với bài trước và cách sử dụng hàm callback để điều khiển quá trình train model.
Môi trường thực hành của bài này giống hệt bài dự đoán giá nhà trước đó.
Ok, hãy cùng bắt đầu!
Đầu tiên, import tensorflow:
1 import tensorflow as tfTiếp theo, ta định nghĩa hàm callback. Hàm này sẽ được gọi mỗi khi model train xong một epoch.
2 class CustomCallback(keras.callbacks.Callback): 3 def on_epoch_end(self, epoch, logs={}): 4 if logs.get('acc') > 0.99: 5 print('Reached to 99%, stop training!') 6 self.model.stop_training = TrueỞ đây, ta sẽ cho model dừng train khi độ chính xác đạt đến 99% như yêu cầu đề bài.
Dataset được load như code của hàm sau:
7 def load_data(): 8 (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() 9 x_train = x_train/255 10 x_test = x_test/255 11 return x_train, y_train, x_test, y_testMNIST có thể coi là bộ dataset kinh điển mà hầu như bất kỳ ai cũng sử dụng khi mới học AI. Có lẽ vì thế mà nó được tích hợp sẵn trong thư viện tensorflow.
Có một chú ý ở hàm load_data() là ta cũng scale down giá trị của x_train, x_test bằng cách chia cho 255. Mục đích của việc làm này cũng giống như mình đã trình bày trong bài trước.
Phần chính của chúng ta là định nghĩa model:
12 def create_model(): 13 model = tf.keras.models.Sequential([ 14 tf.keras.layers.Flatten(), 15 tf.keras.layers.Dense(128, activation='relu', input_shape=(28,28)), 16 tf.keras.layers.Dense(10, activation='softmax') 17 ]) 18 model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['acc']) 19 return modelModel của chúng ta hôm nay gồm 2 lớp: 1 lớp input (128 nodes) và 1 lớp output (10 nodes), không có lớp ẩn (hidden layer).
Kích thước của dữ liệu đầu vào là (28,28), bằng với kích thước của mỗi bức ảnh trong tập MNIST.
Số node của lớp output là 10, bằng với số lớp của tập MNIST mà ta cần phân loại. Hàm kích hoạt Softmax sử dụng ở lớp này sẽ cho ta biết chính xác xác suất của hình ảnh thuộc về mỗi lớp. Lớp nào có xác suất lớn nhất sẽ được lấy làm kết quả cuối cùng.
Model được compile với thuật toán tối ưu SGD, hàm loss là sparse_categorical_crossentropy, và metric là accuracy trên tập train.
Bây giờ ta sẽ tiến hành train model:
20 x_train, y_train, x_test, y_test = load_data() 21 model = create_model() 22 23 history = model.fit(x_train, y_train, epochs=100, verbose=1, callbacks=[CustomCallback()])Model được train tối đa 100 epochs, hàm callback mà ta định nghĩa bên trên được truyền vào như 1 tham số.
Output:
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 1s 0us/step Epoch 1/100 1875/1875 - 1s - loss: 0.2573 - acc: 0.9266 Epoch 2/100 1875/1875 - 1s - loss: 0.1128 - acc: 0.9666 Epoch 3/100 1875/1875 - 1s - loss: 0.0771 - acc: 0.9765 Epoch 4/100 1875/1875 - 1s - loss: 0.0573 - acc: 0.9825 Epoch 5/100 1875/1875 - 1s - loss: 0.0445 - acc: 0.9856 Epoch 6/100 1875/1875 - 1s - loss: 0.0345 - acc: 0.9895 Epoch 7/100 Reached to 99%, stop training! 1875/1875 - 1s - loss: 0.0283 - acc: 0.9913Đầu tiên, tập MNIST sẽ được download về local, sau đó model sẽ được train. Quá trình train dừng lại sau 7 epochs vì độ chính xác đã đạt đến 99% như định nghĩa ở hàm callback.
Như vậy là chúng ta đã giải quyết xong yêu cầu đặt ra lúc đầu. Qua bài này ta đã biết:
- Cách load dataset được tích hợp trong tensorflow.
- Cách xây dựng và sử dụng hàm callback khi train model.
- Các tạo và train model với 2 lớp NN sử dụng tensorflow.
Source code của bài này, các bạn có thể tham khảo trên github cá nhân của mình tại đây.
Bài tiếp theo, chúng ta sẽ xây dựng model sử dụng lớp CONV trong tensorflow để nâng cao độ chính xác cũng như hiệu năng của model. Mời các bạn đón đọc.
Tham khảo
- Coursera
Từ khóa » Bộ Dữ Liệu Mnist
-
3.5. Bộ Dữ Liệu Phân Loại Ảnh (Fashion-MNIST)
-
MNIST Handwritten Digit Database, Yann LeCun, Corinna Cortes ...
-
Cơ Sở Dữ Liệu MNIST Là Gì? Chi Tiết Về Cơ Sở Dữ ... - LADIGI Academy
-
Bộ Dữ Liệu MNIST - Tìm Hiểu Và Nâng Cao Hiệu Quả Nhận Dạng Chữ ...
-
Dữ Liệu Trong Deep Learning - TEK4
-
Cơ Sở Dữ Liệu MNIST - Wikimedia Tiếng Việt
-
Nhận Diện Chữ Viết Với PyTorch | Deep Learning Viet Nam
-
Mnist Dataset Là Gì
-
Cơ Sở Dữ Liệu MNIST
-
Hiểu MNIST Và Xây Dựng Mô Hình Phân Loại Với Bộ Dữ ... - Chickgolden
-
Mnist Dataset Là Gì - LIVESHAREWIKI
-
Mnist.ipynb - Gists · GitHub
-
Load Thư Viện Mnist Là Gì
-
Giới Thiệu Keras Và Bài Toán Phân Loại ảnh. - Deep Learning Cơ Bản
-
Làm Cách Nào để Tạo Tập Dữ Liệu Hình ảnh Giống Như Tập Dữ Liệu ...