Path: blob/master/site/ko/r1/tutorials/eager/custom_training.ipynb
25118 views
Copyright 2018 The TensorFlow Authors.
์ฌ์ฉ์ ์ ์ ํ์ต: ๊ธฐ์ด
Note: ์ด ๋ฌธ์๋ ํ ์ํ๋ก ์ปค๋ฎค๋ํฐ์์ ๋ฒ์ญํ์ต๋๋ค. ์ปค๋ฎค๋ํฐ ๋ฒ์ญ ํ๋์ ํน์ฑ์ ์ ํํ ๋ฒ์ญ๊ณผ ์ต์ ๋ด์ฉ์ ๋ฐ์ํ๊ธฐ ์ํด ๋ ธ๋ ฅํจ์๋ ๋ถ๊ตฌํ๊ณ ๊ณต์ ์๋ฌธ ๋ฌธ์์ ๋ด์ฉ๊ณผ ์ผ์นํ์ง ์์ ์ ์์ต๋๋ค. ์ด ๋ฒ์ญ์ ๊ฐ์ ํ ๋ถ๋ถ์ด ์๋ค๋ฉด tensorflow/docs ๊นํ ์ ์ฅ์๋ก ํ ๋ฆฌํ์คํธ๋ฅผ ๋ณด๋ด์ฃผ์๊ธฐ ๋ฐ๋๋๋ค. ๋ฌธ์ ๋ฒ์ญ์ด๋ ๋ฆฌ๋ทฐ์ ์ฐธ์ฌํ๋ ค๋ฉด [email protected]๋ก ๋ฉ์ผ์ ๋ณด๋ด์ฃผ์๊ธฐ ๋ฐ๋๋๋ค.
์ด์ ํํ ๋ฆฌ์ผ์์๋ ๋จธ์ ๋ฌ๋์ ์ํ ๊ธฐ๋ณธ ๊ตฌ์ฑ ์์์ธ ์๋ ๋ฏธ๋ถ(automatic differentiation)์ ์ํ ํ ์ํ๋ก API๋ฅผ ์์๋ณด์์ต๋๋ค. ์ด๋ฒ ํํ ๋ฆฌ์ผ์์๋ ์ด์ ํํ ๋ฆฌ์ผ์์ ์๊ฐ๋์๋ ํ ์ํ๋ก์ ๊ธฐ๋ณธ ์์๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ๋จํ ๋จธ์ ๋ฌ๋์ ์ํํด๋ณด๊ฒ ์ต๋๋ค.
ํ
์ํ๋ก๋ ๋ฐ๋ณต๋๋ ์ฝ๋๋ฅผ ์ค์ด๊ธฐ ์ํด ์ ์ฉํ ์ถ์ํ๋ฅผ ์ ๊ณตํ๋ ๊ณ ์์ค ์ ๊ฒฝ๋ง(neural network) API์ธ tf.keras
๋ฅผ ํฌํจํ๊ณ ์์ต๋๋ค. ์ ๊ฒฝ๋ง์ ๋ค๋ฃฐ ๋ ์ด๋ฌํ ๊ณ ์์ค์ API์ ๊ฐํ๊ฒ ์ถ์ฒํฉ๋๋ค. ์ด๋ฒ ์งง์ ํํ ๋ฆฌ์ผ์์๋ ํํํ ๊ธฐ์ด๋ฅผ ๊ธฐ๋ฅด๊ธฐ ์ํด ๊ธฐ๋ณธ์ ์ธ ์์๋ง์ผ๋ก ์ ๊ฒฝ๋ง ํ๋ จ์์ผ ๋ณด๊ฒ ์ต๋๋ค.
์ค์
๋ณ์
ํ ์ํ๋ก์ ํ ์(Tensor)๋ ์ํ๊ฐ ์๊ณ , ๋ณ๊ฒฝ์ด ๋ถ๊ฐ๋ฅํ(immutable stateless) ๊ฐ์ฒด์ ๋๋ค. ๊ทธ๋ฌ๋ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ํ๊ฐ ๋ณ๊ฒฝ๋ (stateful) ํ์๊ฐ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋ชจ๋ธ ํ์ต์์ ์์ธก์ ๊ณ์ฐํ๊ธฐ ์ํ ๋์ผํ ์ฝ๋๋ ์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ ๋ค๋ฅด๊ฒ(ํฌ๋งํ๊ฑด๋ ๋ ๋ฎ์ ์์ค๋ก ๊ฐ๋ ๋ฐฉํฅ์ผ๋ก)๋์ํด์ผ ํฉ๋๋ค. ์ด ์ฐ์ฐ ๊ณผ์ ์ ํตํด ๋ณํ๋์ด์ผ ํ๋ ์ํ๋ฅผ ํํํ๊ธฐ ์ํด ๋ช ๋ นํ ํ๋ก๊ทธ๋๋ฐ ์ธ์ด์ธ ํ์ด์ฌ์ ์ฌ์ฉ ํ ์ ์์ต๋๋ค.
ํ ์ํ๋ก๋ ์ํ๋ฅผ ๋ณ๊ฒฝํ ์ ์๋ ์ฐ์ฐ์๊ฐ ๋ด์ฅ๋์ด ์์ผ๋ฉฐ, ์ด๋ฌํ ์ฐ์ฐ์๋ ์ํ๋ฅผ ํํํ๊ธฐ ์ํ ์ ์์ค ํ์ด์ฌ ํํ๋ณด๋ค ์ฌ์ฉํ๊ธฐ๊ฐ ๋ ์ข์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋ชจ๋ธ์์ ๊ฐ์ค์น๋ฅผ ๋ํ๋ด๊ธฐ ์ํด์ ํ ์ํ๋ก ๋ณ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ํธํ๊ณ ํจ์จ์ ์ ๋๋ค.
ํ
์ํ๋ก ๋ณ์๋ ๊ฐ์ ์ ์ฅํ๋ ๊ฐ์ฒด๋ก ํ
์ํ๋ก ์ฐ์ฐ์ ์ฌ์ฉ๋ ๋ ์ ์ฅ๋ ์ด ๊ฐ์ ์ฝ์ด์ฌ ๊ฒ์
๋๋ค. tf.assign_sub
, tf.scatter_update
๋ฑ์ ํ
์ํ๋ก ๋ณ์์ ์ ์ฅ๋์๋ ๊ฐ์ ์กฐ์ํ๋ ์ฐ์ฐ์์
๋๋ค.
๋ณ์๋ฅผ ์ฌ์ฉํ ์ฐ์ฐ์ ๊ทธ๋๋์ธํธ๊ฐ ๊ณ์ฐ๋ ๋ ์๋์ ์ผ๋ก ์ถ์ ๋ฉ๋๋ค. ์๋ฒ ๋ฉ(embedding)์ ๋ํ๋ด๋ ๋ณ์์ ๊ฒฝ์ฐ ๊ธฐ๋ณธ์ ์ผ๋ก ํฌ์ ํ ์(sparse tensor)๋ฅผ ์ฌ์ฉํ์ฌ ์ ๋ฐ์ดํธ๋ฉ๋๋ค. ์ด๋ ์ฐ์ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ์ ๋์ฑ ํจ์จ์ ์ ๋๋ค.
๋ํ ๋ณ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ์ฝ๋๋ฅผ ์ฝ๋ ๋ ์์๊ฒ ์ํ๊ฐ ๋ณ๊ฒฝ๋ ์ ์๋ค๋ ๊ฒ์ ์๋ ค์ฃผ๋ ์์ฌ์ด ๋ฐฉ๋ฒ์ ๋๋ค.
์: ์ ํ ๋ชจ๋ธ ํ๋ จ
์ง๊ธ๊น์ง ๋ช ๊ฐ์ง ๊ฐ๋
์ ์ค๋ช
ํ์ต๋๋ค. ๊ฐ๋จํ ๋ชจ๋ธ์ ๊ตฌ์ถํ๊ณ ํ์ต์ํค๊ธฐ ์ํด ---Tensor
, GradientTape
, Variable
--- ๋ฑ์ ์ฌ์ฉํ์๊ณ , ์ด๋ ์ผ๋ฐ์ ์ผ๋ก ๋ค์์ ๊ณผ์ ์ ํฌํจํฉ๋๋ค.
๋ชจ๋ธ ์ ์
์์ค ํจ์ ์ ์
ํ๋ จ ๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ
ํ๋ จ ๋ฐ์ดํฐ์์ ์คํ, ๋ฐ์ดํฐ์ ์ต์ ํํ๊ธฐ ์ํด "์ตํฐ๋ง์ด์ (optimizer)"๋ฅผ ์ฌ์ฉํ ๋ณ์ ์กฐ์
์ด๋ฒ ํํ ๋ฆฌ์ผ์์๋ ์ ํ ๋ชจ๋ธ์ ๊ฐ๋จํ ์์ ๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. f(x) = x * W + b
, ๋ชจ๋ธ์ W
์ b
๋ ๋ณ์๋ฅผ ๊ฐ์ง๊ณ ์๋ ์ ํ๋ชจ๋ธ์ด๋ฉฐ, ์ ํ์ต๋ ๋ชจ๋ธ์ด W = 3.0
and b = 2.0
์ ๊ฐ์ ๊ฐ๋๋ก ํฉ์ฑ ๋ฐ์ดํฐ๋ฅผ ๋ง๋ค๊ฒ ์ต๋๋ค.
๋ชจ๋ธ ์ ์
๋ณ์์ ์ฐ์ฐ์ ์บก์ํํ๊ธฐ ์ํ ๊ฐ๋จํ ํด๋์ค๋ฅผ ์ ์ํด๋ด ์๋ค.
์์ค ํจ์ ์ ์
์์ค ํจ์๋ ์ฃผ์ด์ง ์ ๋ ฅ์ ๋ํ ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ์ํ๋ ์ถ๋ ฅ๊ณผ ์ผ๋ง๋ ์ ์ผ์นํ๋์ง๋ฅผ ์ธก์ ํฉ๋๋ค. ํ๊ท ์ ๊ณฑ ์ค์ฐจ(mean square error)๋ฅผ ์ ์ฉํ ์์ค ํจ์๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
ํ๋ จ ๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ
์ฝ๊ฐ์ ์ก์๊ณผ ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ํฉ์นฉ๋๋ค.
๋ชจ๋ธ์ ํ๋ จ์ํค๊ธฐ ์ ์, ๋ชจ๋ธ์ ํ์ฌ ์ํ๋ฅผ ์๊ฐํํฉ์๋ค. ๋ชจ๋ธ์ ์์ธก์ ๋นจ๊ฐ์์ผ๋ก, ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ํ๋์์ผ๋ก ๊ตฌ์ฑํฉ๋๋ค.
ํ๋ จ ๋ฃจํ ์ ์
์ด์ ๋คํธ์ํฌ์ ํ๋ จ ๋ฐ์ดํฐ๊ฐ ์ค๋น๋์์ต๋๋ค. ๋ชจ๋ธ์ ๋ณ์(W
์ b
)๋ฅผ ์
๋ฐ์ดํธํ๊ธฐ ์ํด ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ์์ผ ๋ณด์ฃ . ๊ทธ๋ฆฌ๊ณ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ(gradient descent)์ ์ฌ์ฉํ์ฌ ์์ค์ ๊ฐ์์ํต๋๋ค. ๊ฒฝ์ฌ ํ๊ฐ๋ฒ์๋ ์ฌ๋ฌ๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ผ๋ฉฐ, tf.train.Optimizer
์ ๊ตฌํ๋์ด์์ต๋๋ค. ์ด๋ฌํ ๊ตฌํ์ ์ฌ์ฉํ๋๊ฒ์ ๊ฐ๋ ฅํ ์ถ์ฒ๋๋ฆฝ๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฒ ํํ ๋ฆฌ์ผ์์๋ ๊ธฐ๋ณธ์ ์ธ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
๋ง์ง๋ง์ผ๋ก, ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ๋ฐ๋ณต์ ์ผ๋ก ์คํํ๊ณ , W
์ b
์ ๋ณํ ๊ณผ์ ์ ํ์ธํฉ๋๋ค.
๋ค์ ๋จ๊ณ
์ด๋ฒ ํํ ๋ฆฌ์ผ์์๋ ๋ณ์๋ฅผ ๋ค๋ฃจ์์ผ๋ฉฐ, ์ง๊ธ๊น์ง ๋ ผ์๋ ํ ์ํ๋ก์ ๊ธฐ๋ณธ ์์๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ๋จํ ์ ํ ๋ชจ๋ธ์ ๊ตฌ์ถํ๊ณ ํ๋ จ์์ผฐ์ต๋๋ค.
์ด๋ก ์ ์ผ๋ก, ํ
์ํ๋ก๋ฅผ ๋จธ์ ๋ฌ๋ ์ฐ๊ตฌ์ ์ฌ์ฉํ๊ธฐ ์ํด ์์์ผ ํ ๊ฒ์ด ๋งค์ฐ ๋ง์ต๋๋ค. ์ค์ ๋ก ์ ๊ฒฝ๋ง์ ์์ด tf.keras
์ ๊ฐ์ ๊ณ ์์ค API๋ ๊ณ ์์ค ๊ตฌ์ฑ ์์("์ธต"์ผ๋ก ๋ถ๋ฆฌ๋)๋ฅผ ์ ๊ณตํ๊ณ , ์ ์ฅ ๋ฐ ๋ณต์์ ์ํ ์ ํธ๋ฆฌํฐ, ์์ค ํจ์ ๋ชจ์, ์ต์ ํ ์ ๋ต ๋ชจ์ ๋ฑ์ ์ ๊ณตํ๊ธฐ ๋๋ฌธ์ ๋์ฑ ํธ๋ฆฌํฉ๋๋ค.