Path: blob/master/site/ko/js/guide/train_models.md
25118 views
ํ๋ จ ๋ชจ๋ธ
์ด ๊ฐ์ด๋๋ ์ด๋ฏธ ๋ชจ๋ธ ๋ฐ ๋ ์ด์ด ๊ฐ์ด๋๋ฅผ ์ฝ์๋ค๋ ๊ฐ์ ํ์ ์ฐ์์ต๋๋ค.
TensorFlow.js์๋ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ๋ จํ๋ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
<a href="
https://js.tensorflow.org/api/latest/#tf.Model.fit
" data-md-type="link">LayersModel.fit()</a>
๋๋<a href="
https://js.tensorflow.org/api/latest/#tf.Model.fitDataset
" data-md-type="link">LayersModel.fitDataset()</a>
์ ํจ๊ป Layers API ์ฌ์ฉํ๊ธฐ<a href="
https://js.tensorflow.org/api/latest/#tf.train.Optimizer.minimize
" data-md-type="link">Optimizer.minimize()</a>
์ ํจ๊ป Core API ์ฌ์ฉํ๊ธฐ
๋จผ์ ๋ชจ๋ธ ๋น๋ ๋ฐ ํ๋ จ์ ์ํ ์์ ์์ค API ์ธ Layers API๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ๊ทธ๋ฐ ๋ค์ Core API๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ ๋ชจ๋ธ์ ํ๋ จํ๋ ๋ฐฉ๋ฒ์ ๋ณด๊ฒ ์ต๋๋ค.
์๊ฐ
๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ ๋ ฅ์ ์ํ๋ ์ถ๋ ฅ์ ๋งคํํ๋ ํ์ต ๊ฐ๋ฅํ ๋งค๊ฐ๋ณ์๊ฐ ์๋ ํจ์์ ๋๋ค. ์ต์ ์ ๋งค๊ฐ๋ณ์๋ ๋ฐ์ดํฐ์์ ๋ชจ๋ธ์ ํ๋ จํ์ฌ ์ป์ต๋๋ค.
ํ๋ จ์๋ ์ฌ๋ฌ ๋จ๊ณ๊ฐ ํฌํจ๋ฉ๋๋ค.
๋ชจ๋ธ์ ๋ฐ์ดํฐ ๋ฐฐ์น ๊ฐ์ ธ์ค๊ธฐ
๋ชจ๋ธ์ ์์ธก๊ฐ ์์ฒญํ๊ธฐ
ํด๋น ์์ธก๊ฐ์ '์ฐธ'๊ฐ๊ณผ ๋น๊ตํ๊ธฐ
๋ชจ๋ธ์ด ํฅํ ํด๋น ๋ฐฐ์น์ ๋ํด ๋ ๋์ ์์ธก๊ฐ์ ๋ด๋๋ก ๊ฐ ๋งค๊ฐ๋ณ์์ ๋ณ๊ฒฝ ๋ฒ์ ๊ฒฐ์ ํ๊ธฐ
์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ ๋ ฅ์์ ์ํ๋ ์ถ๋ ฅ์ผ๋ก ์ ํํ ๋งคํ์ ์ ๊ณตํฉ๋๋ค.
๋ชจ๋ธ ๋งค๊ฐ๋ณ์
Layers API๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ๋จํ 2๋ ์ด์ด ๋ชจ๋ธ์ ์ ์ํด๋ณด๊ฒ ์ต๋๋ค.
๋ด๋ถ์ ์ผ๋ก ๋ชจ๋ธ์๋ ๋ฐ์ดํฐ ํ์ต์ ํตํด ํ๋ จํ ์ ์๋ ๋งค๊ฐ๋ณ์(์ข ์ข ๊ฐ์ค์น ๋ผ๊ณ ํจ)๊ฐ ์์ต๋๋ค. ์ด ๋ชจ๋ธ ๋ฐ ํ์๊ณผ ๊ด๋ จ๋ ๊ฐ์ค์น์ ์ด๋ฆ์ ์ถ๋ ฅํด๋ด ๋๋ค.
๋ค์๊ณผ ๊ฐ์ ์ถ๋ ฅ์ด ํ์๋ฉ๋๋ค.
์ด 4๊ฐ์ ๊ฐ์ค์น๊ฐ ์์ผ๋ฉฐ ๋ฐ์ง ๋ ์ด์ด๋น 2๊ฐ์
๋๋ค. ์ด๋ฌํ ๊ฒฐ๊ณผ๊ฐ ์์๋๋ ์ด์ ๋ ๋ฐ์ง ๋ ์ด์ด๋ ์์ y = Ax + b
๋ฅผ ํตํด ์
๋ ฅ ํ
์ x
๋ฅผ ์ถ๋ ฅ ํ
์ y
๋ก ๋งคํํ๋ ํจ์๋ฅผ ๋ํ๋ด๊ธฐ ๋๋ฌธ์
๋๋ค. ์ฌ๊ธฐ์ A
(์ปค๋) ๋ฐ b
(๋ฐ์ด์ด์ค)๋ ๋ฐ์ง ๋ ์ด์ด์ ๋งค๊ฐ๋ณ์์
๋๋ค.
์ฐธ๊ณ : ๊ธฐ๋ณธ์ ์ผ๋ก ๋ฐ์ง ๋ ์ด์ด์๋ ๋ฐ์ด์ด์ค๊ฐ ํฌํจ๋์ง๋ง, ๋ฐ์ง ๋ ์ด์ด๋ฅผ ๋ง๋ค ๋ ์ต์ ์์
{useBias: false}
๋ฅผ ์ง์ ํ์ฌ ๋ฐ์ด์ด์ค๋ฅผ ์ ์ธํ ์ ์์ต๋๋ค.
model.summary()
๋ ๋ชจ๋ธ์ ๊ฐ์์ ์ด ๋งค๊ฐ๋ณ์ ์๋ฅผ ํ์ธํ๋ ค๋ ๊ฒฝ์ฐ ์ ์ฉํ ๋ฉ์๋์
๋๋ค.
๋ ์ด์ด(์ ํ) | ์ถ๋ ฅ ํ์ | ๋งค๊ฐ๋ณ์ ๋ฒํธ |
density_Dense1(๋ฐ๋) | [null,32] | 25120 |
density_Dense2(๋ฐ๋) | [null,10] | 330 |
์ด ๋งค๊ฐ๋ณ์: 25450 ํ๋ จ ๊ฐ๋ฅํ ๋งค๊ฐ๋ณ์: 25450 ํ๋ จํ ์ ์๋ ๋งค๊ฐ๋ณ์: 0 |
๋ชจ๋ธ์ ๊ฐ ๊ฐ์ค์น๋ <a href="
https://js.tensorflow.org/api/0.14.2/#class:Variable
" data-md-type="link">Variable</a>
๊ฐ์ฒด์ ๋ฐฑ์๋์
๋๋ค. TensorFlow.js์์ Variable
์ ๊ฐ์ ์
๋ฐ์ดํธํ๋ ๋ฐ ์ฌ์ฉ๋๋ ํ๋์ ์ถ๊ฐ ๋ฉ์๋ assign()
์ด ์๋ ๋ถ๋ ์์์ Tensor
์
๋๋ค. Layers API๋ ๋ชจ๋ฒ ์ฌ๋ก๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ค์น๋ฅผ ์๋์ผ๋ก ์ด๊ธฐํํฉ๋๋ค. ๋ฐ๋ชจ๋ฅผ ์ํด ๊ธฐ๋ณธ ๋ณ์์ ๋ํด assign()
์ ํธ์ถํ์ฌ ๊ฐ์ค์น๋ฅผ ๋ฎ์ด์ธ ์ ์์ต๋๋ค.
์ตํฐ๋ง์ด์ , ์์ค ๋ฐ ๋ฉํธ๋ฆญ
ํ๋ จ์ ์์ํ๊ธฐ ์ ์ ๋ค์ ์ธ ๊ฐ์ง๋ฅผ ๊ฒฐ์ ํด์ผ ํฉ๋๋ค.
์ตํฐ๋ง์ด์ : ์ตํฐ๋ง์ด์ ๋ ํ์ฌ ๋ชจ๋ธ์ ์์ธก๊ฐ์ด ์ฃผ์ด์ก์ ๋ ๋ชจ๋ธ์ ๊ฐ ๋งค๊ฐ๋ณ์๋ฅผ ์ผ๋ง๋ ๋ณ๊ฒฝํ ๊ฒ์ธ์ง ๊ฒฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค. Layers API๋ฅผ ์ฌ์ฉํ ๋ ๊ธฐ์กด ์ตํฐ๋ง์ด์ ์ ๋ฌธ์์ด ์๋ณ์์ธ(์:
'sgd'
๋๋'adam'
) ๋๋<a href="
https://js.tensorflow.org/api/latest/#Training-Optimizers
" data-md-type="link">Optimizer</a>
ํด๋์ค์ ์ธ์คํด์ค๋ฅผ ์ ๊ณตํ ์ ์์ต๋๋ค.์์ค ํจ์: ๋ชจ๋ธ์ ์ต์ํ๋ฅผ ๋ชฉํ๋ก ํฉ๋๋ค. ๋ชจ๋ธ์ ์์ธก๊ฐ์ด '์ผ๋ง๋ ์๋ชป๋์๋์ง'์ ๋ํ ๋จ์ผ ์ซ์๋ฅผ ์ ๊ณตํ๋ ๊ฒ์ ๋๋ค. ์์ค์ ๋ชจ๋ธ์ด ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํ ์ ์๋๋ก ๋ชจ๋ ๋ฐ์ดํฐ ๋ฐฐ์น์์ ๊ณ์ฐ๋ฉ๋๋ค. Layers API๋ฅผ ์ฌ์ฉํ ๋ ๊ธฐ์กด ์์ค ํจ์์ ๋ฌธ์์ด ์๋ณ์(์:
'categoricalCrossentropy'
) ๋๋ ์์ธก๊ฐ๊ณผ ์ฐธ๊ฐ์ ๊ฐ์ ธ์ ์์ค์ ๋ฐํํ๋ ๋ชจ๋ ํจ์๋ฅผ ์ ๊ณตํ ์ ์์ต๋๋ค. API ์ค๋ช ์์์ ์ฌ์ฉ ๊ฐ๋ฅํ ์์ค ๋ชฉ๋ก์ ์ฐธ์กฐํ์ธ์.๋ฉํธ๋ฆญ ๋ชฉ๋ก: ์์ค๊ณผ ์ ์ฌํ๊ฒ ๋ฉํธ๋ฆญ์ ๋จ์ผ ์ซ์๋ฅผ ๊ณ์ฐํ์ฌ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์์ฝํฉ๋๋ค. ๋ฉํธ๋ฆญ์ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ epoch๊ฐ ๋๋ ๋ ์ ์ฒด ๋ฐ์ดํฐ์ ๋ํด ๊ณ์ฐ๋ฉ๋๋ค. ์ต์ํ ์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ ์์ค์ด ๊ฐ์ํ๊ณ ์๋์ง ๋ชจ๋ํฐ๋งํด์ผ ํฉ๋๋ค๋ง, ์ ํ์ฑ๊ณผ ๊ฐ์ ๋ณด๋ค ์ธ๊ฐ ์นํ์ ์ธ ๋ฉํธ๋ฆญ์ ์ํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. Layers API๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ ๊ธฐ์กด ๋ฉํธ๋ฆญ์ ๋ฌธ์์ด ์๋ณ์(์:
'accuracy'
) ๋๋ ์์ธก๊ฐ ๋ฐ ์ฐธ๊ฐ์ ๊ฐ์ ธ์ ์ ์๋ฅผ ๋ฐํํ๋ ๋ชจ๋ ํจ์๋ฅผ ์ ๊ณตํ ์ ์์ต๋๋ค. API ์ค๋ช ์์์ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ฉํธ๋ฆญ ๋ชฉ๋ก์ ์ฐธ์กฐํ์ธ์.
๊ฒฐ์ ํ์ผ๋ฉด ์ ๊ณต๋ ์ต์
์ผ๋ก model.compile()
์ ํธ์ถํ์ฌ LayersModel
์ ์ปดํ์ผํฉ๋๋ค.
์ปดํ์ผํ๋ ๋์ ๋ชจ๋ธ์ ์ ํํ ์ต์ ์ด ์๋ก ํธํ๋๋์ง๋ฅผ ํ์ธํ๋ ๋ช ๊ฐ์ง ๊ฒ์ฆ์ ์ํํฉ๋๋ค.
ํ๋ จ
LayersModel
์ ํ๋ จํ๋ ๋ฐฉ๋ฒ์๋ ๋ ๊ฐ์ง๊ฐ ์์ต๋๋ค.
model.fit()
๋ฅผ ์ฌ์ฉํ๊ณ ๋ฐ์ดํฐ๋ฅผ ํ๋์ ํฐ ํ ์๋ก ์ ๊ณตํ๊ธฐmodel.fitDataset()
๋ฐDataset
๊ฐ์ฒด๋ฅผ ํตํด ๋ฐ์ดํฐ ์ ๊ณตํ๊ธฐ
model.fit()
๋ฐ์ดํฐ์ธํธ๊ฐ ์ฃผ ๋ฉ๋ชจ๋ฆฌ์ ์ ํฉํ๊ฒ ๋ง๊ณ ๋จ์ผ ํ
์๋ก ์ฌ์ฉํ ์์๋ ๊ฒฝ์ฐ fit()
๋ฉ์๋๋ฅผ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ํ๋ จํ ์ ์์ต๋๋ค.
๋ด๋ถ์ ์ผ๋ก model.fit()
๋ ๋ง์ ์ผ์ ํ ์ ์์ต๋๋ค.
๋ฐ์ดํฐ๋ฅผ ํ๋ จ ๋ฐ ๊ฒ์ฆ ์ธํธ๋ก ๋ถํ ํ๊ณ ๊ฒ์ฆ ์ธํธ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ ์ค ์งํ ์ํฉ์ ์ธก์ ํฉ๋๋ค.
๋ถํ ํ์ ๋ฐ์ดํฐ๋ฅผ ์ ํํฉ๋๋ค. ์์ ์ ์ํด ๋ฐ์ดํฐ๋ฅผ
fit()
๋ก ์ ๋ฌํ๊ธฐ ์ ์ ๋ฏธ๋ฆฌ ์ ํํด์ผ ํฉ๋๋ค.ํฐ ๋ฐ์ดํฐ ํ ์๋ฅผ
batchSize.
ํฌ๊ธฐ์ ๋ ์์ ํ ์๋ก ๋ถํ ํฉ๋๋ค.๋ฐ์ดํฐ ๋ฐฐ์น์ ๊ด๋ จํ์ฌ ๋ชจ๋ธ ์์ค์ ๊ณ์ฐํ๋ ๋์
optimizer.minimize()
๋ฅผ ํธ์ถํฉ๋๋ค.๊ฐ epoch ๋๋ ๋ฐฐ์น์ ์์๊ณผ ๋์ ์๋ ค์ค ์ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ์๋ ๋ชจ๋ ๋ฐฐ์น๊ฐ ๋๋ ๋
callbacks.onBatchEnd
์ต์ ์ ์ฌ์ฉํ์ฌ ์๋ฆผ์ ๋ฐ์ต๋๋ค. ๋ค๋ฅธ ์ต์ ์ผ๋ก๋onTrainBegin
,onTrainEnd
,onEpochBegin
,onEpochEnd
๋ฐonBatchBegin
์ด ์์ต๋๋ค.JS ์ด๋ฒคํธ ๋ฃจํ์ ๋๊ธฐ ์ค์ธ ์์ ์ ์ ์์ ์ฒ๋ฆฌํ ์ ์๋๋ก ์ฃผ ์ค๋ ๋์ ์๋ณดํฉ๋๋ค.
์์ธํ ๋ด์ฉ์ fit()
์ค๋ช
์๋ฅผ ์ฐธ์กฐํ์ธ์. Core API๋ฅผ ์ฌ์ฉํ๊ธฐ๋ก ์ ํํ ๊ฒฝ์ฐ ์ด ๋ก์ง์ ์ง์ ๊ตฌํํด์ผ ํฉ๋๋ค.
model.fitDataset()
๋ฐ์ดํฐ๊ฐ ๋ฉ๋ชจ๋ฆฌ์ ์์ ํ ๋ง์ง ์๊ฑฐ๋ ์คํธ๋ฆฌ๋ฐ๋๋ ๊ฒฝ์ฐ Dataset
๊ฐ์ฒด๋ฅผ ์ฌ์ฉํ๋ fitDataset()
๋ฅผ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ํ๋ จํ ์ ์์ต๋๋ค. ๋ค์์ ๊ฐ์ ํ๋ จ ์ฝ๋์ด์ง๋ง ์์ฑ๊ธฐ ํจ์๋ฅผ ๋ํํ๋ ๋ฐ์ดํฐ์ธํธ๊ฐ ์์ต๋๋ค.
๋ฐ์ดํฐ์ธํธ์ ๋ํ ์์ธํ ๋ด์ฉ์ model.fitDataset()
์ค๋ช
์๋ฅผ ์ฐธ์กฐํ์ธ์.
์๋ก์ด ๋ฐ์ดํฐ ์์ธกํ๊ธฐ
ํ๋ฒ ํ๋ จ์ ๊ฑฐ์น๋ฉด ๋ชจ๋ธ์ด model.predict()
๋ฅผ ํธ์ถํ์ฌ ๋ณด์ด์ง ์๋ ๋ฐ์ดํฐ์ ์์ธก๊ฐ์ ๋ผ ์ ์์ต๋๋ค.
์ฐธ๊ณ : ๋ชจ๋ธ ๋ฐ ๋ ์ด์ด ๊ฐ์ด๋์์ ์ธ๊ธ๋์๋ฏ์ดLayersModel
์ ์
๋ ฅ์ ๊ฐ์ฅ ๋ฐ๊นฅ์ชฝ ์ฐจ์์ด ๋ฐฐ์น ํฌ๊ธฐ์ผ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์์ ์์์ ๋ฐฐ์น ํฌ๊ธฐ๋ 3์
๋๋ค.
Core API
์์ TensorFlow.js์์ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ๋ จํ๋ ๋ฐฉ๋ฒ์๋ ๋ ๊ฐ์ง๊ฐ ์๋ค๊ณ ์ธ๊ธํ์ต๋๋ค.
์ผ๋ฐ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก๋ ๋จผ์ ์ ์ฑํ๋ Keras API๋ฅผ ๋ชจ๋ธ๋ก ํ๋ Layers API๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. Layers API๋ ๊ฐ์ค์น ์ด๊ธฐํ, ๋ชจ๋ธ ์ง๋ ฌํ, ๋ชจ๋ํฐ๋ง ํ๋ จ, ์ด์์ฑ ๋ฐ ์์ ๊ฒ์ฌ์ ๊ฐ์ ๋ค์ํ ๊ธฐ์ฑ ์๋ฃจ์ ๋ ์ ๊ณตํฉ๋๋ค.
๋ค์๊ณผ ๊ฐ์ ๊ฒฝ์ฐ Core API๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ต๋ํ์ ์ ์ฐ์ฑ ๋๋ ์ ์ด๊ฐ ํ์ํฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ง๋ ฌํ๊ฐ ํ์ํ์ง ์๊ฑฐ๋ ์์ฒด์ ์ผ๋ก ์ง๋ ฌํ ๋ ผ๋ฆฌ๋ฅผ ๊ตฌํํ ์ ์์ต๋๋ค.
์์ธํ ๋ด์ฉ์ ๋ชจ๋ธ ๋ฐ ๋ ์ด์ด ๊ฐ์ด๋์ 'Core API'์น์ ์ ์ฐธ์กฐํ์ธ์.
Core API๋ฅผ ์ฌ์ฉํ์ฌ ์์ฑ๋ ์์ ๋์ผํ ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Layers API ์ธ์๋ Data API๋ Core API์ ์ํํ๊ฒ ๋์ํฉ๋๋ค. ์ ํ ๋ฐ ์ผ๊ด ์ฒ๋ฆฌ๋ฅผ ์ํํ๋ model.fitDataset () ์น์ ์์ ์ด์ ์ ์ ์ํ ๋ฐ์ดํฐ์ธํธ๋ฅผ ๋ค์ ์ฌ์ฉํด๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ธ์ ํ๋ จํด๋ณด๊ฒ ์ต๋๋ค.
์์ ์ฝ๋๋ Core API๋ก ๋ชจ๋ธ์ ํ๋ จํ ๋ ์ฐ์ด๋ ํ์ค ๋ ์ํผ์ ๋๋ค.
epoch ์๋ฅผ ๋ฐ๋ณตํฉ๋๋ค.
๊ฐ epoch ๋ด์์ ๋ฐ์ดํฐ ๋ฐฐ์น๋ฅผ ๋ฐ๋ณตํฉ๋๋ค.
Dataset
๋ฅผ ์ฌ์ฉํ ๋<a href="
https://js.tensorflow.org/api/0.15.1/#tf.data.Dataset.forEachAsync
" data-md-type="link">dataset.forEachAsync()</a>
๋ ๋ฐฐ์น๋ฅผ ๋ฐ๋ณตํ๋ ํธ๋ฆฌํ ๋ฐฉ๋ฒ์ ๋๋ค.๊ฐ ๋ฐฐ์น์ ๋ํด
<a href="
https://js.tensorflow.org/api/latest/#tf.train.Optimizer.minimize
" data-md-type="link">optimizer.minimize(f)</a>
๋ฅผ ํธ์ถํ๋ฉดf
๋ฅผ ์คํํ๊ณ ์์ ์ ์ํ 4๊ฐ์ ๋ณ์์ ๋ํ ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ์ฌ ์ถ๋ ฅ์ ์ต์ํํฉ๋๋ค.f
๋ ์์ค์ ๊ณ์ฐํฉ๋๋ค. ๋ชจ๋ธ์ ์์ธก๊ฐ๊ณผ ์ค์ ๊ฐ์ ์ฌ์ฉํ์ฌ ๋ฏธ๋ฆฌ ์ ์๋ ์์ค ํจ์ ์ค ํ๋๋ฅผ ํธ์ถํฉ๋๋ค.