op ๋ง๋ค๊ธฐ
์ฐธ๊ณ : C++ ์ฌ์ฉ์ ์ ์ ops์ ABI๊ฐ TensorFlow์ ๊ณต์ pip ํจํค์ง์ ํธํ๋๋๋ก ํ๋ ค๋ฉด ์ฌ์ฉ์ ์ ์ op ๋ฆฌํฌ์งํ ๋ฆฌ์ ๊ฐ์ด๋๋ฅผ ๋ฐ๋ฅด์ธ์. ๊ฐ์ด๋์๋ ์๋ ํฌ ์๋ ์ฝ๋ ์์ ์ ์ฌ์ฉ์ ์ง์ ops๋ฅผ ์์ฑ ๋ฐ ๋ฐฐํฌํ๊ธฐ ์ํ Docker ์ด๋ฏธ์ง๊ฐ ์์ต๋๋ค.
๊ธฐ์กด TensorFlow ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํฌํจ๋์ง ์๋ op๋ฅผ ๋ง๋ค๋ ค๋ฉด ๋จผ์ ๊ธฐ์กด Python ops ๋๋ ํจ์์ ๊ตฌ์ฑ์ผ๋ก op๋ฅผ Python์ผ๋ก ์์ฑํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๊ฐ๋ฅํ์ง ์๋ค๋ฉด, ์ฌ์ฉ์ ์ ์ C++ op๋ฅผ ์์ฑํ ์ ์์ต๋๋ค. ์ฌ์ฉ์ ์ ์ C++ op๋ฅผ ์์ฑํ๋ ๋ช ๊ฐ์ง ์ด์ ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
๊ธฐ์กด ops์ ๊ตฌ์ฑ์ผ๋ก ์์ ์ ํํํ๋ ๊ฒ์ ์ฝ์ง ์๊ฑฐ๋ ๋ถ๊ฐ๋ฅํฉ๋๋ค.
๊ธฐ์กด ํ๋ฆฌ๋ฏธํฐ๋ธ์ ๊ตฌ์ฑ์ผ๋ก ์ฐ์ฐ์ ํํํ๋ ๊ฒ์ด ๋นํจ์จ์ ์ผ ๋
์ฌ์ฉ์๊ฐ ๋ฏธ๋์ ์ปดํ์ผ๋ฌ์์ ์ตํฉ์ด ์ด๋ ค์ด ํ๋ฆฌ๋ฏธํฐ๋ธ์ ๊ตฌ์ฑ์ ์๋์ผ๋ก ์ตํฉํ๋ ค ํ ๋
์๋ฅผ ๋ค์ด, "MaxPool" ์ฐ์ฐ์์ ๋น์ทํ "์ค์๊ฐ ํ๋ง"๊ณผ ๊ฐ์ ์ฐ์ฐ์ ๊ตฌํํ ๋ ์ต๋๊ฐ ๋์ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ์ ๋ํด ์ค์๊ฐ์ ๊ณ์ฐํ๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. ์ฐ์ฐ์ ๊ตฌ์ฑ์ ์ฌ์ฉํ์ฌ ์ด ์ฐ์ฐ์ ์ํํ ์ ์์ง๋ง(์: ExtractImagePatches ๋ฐ TopK ์ฌ์ฉ), ๋จ์ผ ์ตํฉ ์ฐ์ฐ์ผ๋ก ๋ ๋๋ํ ์ฐ์ฐ์ ์ํํ ์ ์๋ ๋ค์ดํฐ๋ธ ์ฐ์ฐ๋ณด๋ค๋ ์ฑ๋ฅ ๋๋ ๋ฉ๋ชจ๋ฆฌ์ ํจ์จ์ฑ์ด ๋จ์ด์ง ์ ์์ต๋๋ค. ํญ์ ๊ทธ๋ ๋ฏ์ด, ์ผ๋ฐ์ ์ผ๋ก ์ฐ์ฐ์ ๊ตฌ์ฑ์ ์ฌ์ฉํ์ฌ ์ํ๋ ๊ฒ์ ํํํด ๋ณผ ๋งํ๋ฐ, ๊ฐ์ฅ ์ด๋ ต๊ณ ๋นํจ์จ์ ์ผ ๊ฒฝ์ฐ์๋ง ์ ์ฐ์ฐ์ ์ถ๊ฐํ๋๋ก ์ ํํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์ฌ์ฉ์ ์ ์ op๋ฅผ ํตํฉํ๋ ค๋ฉด ๋ค์์ ์ํํด์ผ ํฉ๋๋ค.
C++ ํ์ผ์ ์ op๋ฅผ ๋ฑ๋กํฉ๋๋ค. Op ๋ฑ๋ก์์ op์ ๊ตฌํ๊ณผ๋ ๋ ๋ฆฝ์ ์ธ op ๊ธฐ๋ฅ์ ๋ํ ์ธํฐํ์ด์ค(์ฌ์)๋ฅผ ์ ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, op ๋ฑ๋ก์์ op์ ์ด๋ฆ๊ณผ op์ ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ์ ์ ์ํฉ๋๋ค. ๋ํ, ํ ์ ํ์ ์ ์ถ์ ์ฌ์ฉ๋๋ ํ์ ํจ์๋ฅผ ์ ์ํฉ๋๋ค.
C++๋ก op๋ฅผ ๊ตฌํํฉ๋๋ค. op์ ๊ตฌํ์ ์ปค๋์ด๋ผ๊ณ ํ๋ฉฐ 1๋จ๊ณ์์ ๋ฑ๋กํ ์ฌ์์ ๊ตฌ์ฒด์ ์ธ ๊ตฌํ์ ๋๋ค. ๋ค์ํ ์ ๋ ฅ/์ถ๋ ฅ ์ ํ ๋๋ ์ํคํ ์ฒ(์: CPU, GPU)๋ฅผ ์ํ ์ปค๋์ด ์ฌ๋ฌ ๊ฐ ์์ ์ ์์ต๋๋ค.
Python ๋ํผ๋ฅผ ๋ง๋ญ๋๋ค(์ ํ ์ฌํญ). ์ด ๋ํผ๋ Python์์ op๋ฅผ ๋ง๋๋ ๋ฐ ์ฌ์ฉ๋๋ ๊ณต๊ฐ API์ ๋๋ค. ๊ธฐ๋ณธ ๋ํผ๋ op ๋ฑ๋ก์์ ์์ฑ๋๋ฉฐ ์ง์ ์ฌ์ฉํ๊ฑฐ๋ ์ถ๊ฐํ ์ ์์ต๋๋ค.
op์ ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ๋ ํจ์๋ฅผ ์์ฑํฉ๋๋ค(์ ํ ์ฌํญ).
op๋ฅผ ํ ์คํธํฉ๋๋ค. ๋ณดํต ํธ์๋ฅผ ์ํด ํ์ด์ฌ์์ ์ด ์ฐ์ฐ์ ํ ์คํธํ์ง๋ง, C++์์ op๋ฅผ ํ ์คํธํ ์๋ ์์ต๋๋ค. ๊ทธ๋๋์ธํธ๋ฅผ ์ ์ํ๋ฉด ํ์ด์ฌ
tf.test.compute_gradient_error์ ์ฌ์ฉํ์ฌ ํ์ธํ ์ ์์ต๋๋ค. Relu ๊ฐ์ ์ฐ์ฐ์์ ์ ๋ฌ ํจ์์ ๊ทธ๋๋์ธํธ๋ฅผ ํ ์คํธํ๋ ์์ ๋relu_op_test.py๋ฅผ ์ฐธ์กฐํ์ธ์.
์ ์ ์กฐ๊ฑด
C++์ ์ด๋ ์ ๋ ์ต์ํด์ผ ํฉ๋๋ค.
TensorFlow ๋ฐ์ด๋๋ฆฌ๋ฅผ ์ค์นํ๊ฑฐ๋ TensorFlow ์์ค๋ฅผ ๋ค์ด๋ก๋ํ์ฌ ๋น๋ํ ์ ์์ด์ผ ํฉ๋๋ค.
op ์ธํฐํ์ด์ค ์ ์ํ๊ธฐ
op๋ฅผ TensorFlow ์์คํ ์ ๋ฑ๋กํ์ฌ op์ ์ธํฐํ์ด์ค๋ฅผ ์ ์ํฉ๋๋ค. ๋ฑ๋ก ์ op์ ์ด๋ฆ, ํด๋น ์ ๋ ฅ(์ ํ ๋ฐ ์ด๋ฆ) ๋ฐ ์ถ๋ ฅ(์ ํ ๋ฐ ์ด๋ฆ), ๊ทธ๋ฆฌ๊ณ docstrings ๋ฐ op์ ํ์ํ attrs๋ฅผ ์ง์ ํฉ๋๋ค.
์๋ ์๋ฆฌ๋ฅผ ์์๋ณด๊ธฐ ์ํด int32์ ํ
์๋ฅผ ๊ฐ์ ธ์์ ์ฒซ ๋ฒ์งธ ์์๋ฅผ ์ ์ธํ ๋ชจ๋ ์์๋ฅผ โโ0์ผ๋ก ์ค์ ํ์ฌ ํ
์์ ๋ณต์ฌ๋ณธ์ ์ถ๋ ฅํ๋ op๋ฅผ ๋ง๋ ๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. ๊ทธ๋ ๊ฒ ํ๋ ค๋ฉด, zero_out.cc์ด๋ผ๋ ํ์ผ์ ์์ฑํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์, op์ ์ธํฐํ์ด์ค๋ฅผ ์ ์ํ๋ REGISTER_OP ๋งคํฌ๋ก์ ๋ํ ํธ์ถ์ ์ถ๊ฐํฉ๋๋ค.
์ด ZeroOut op๋ 32-bit ์ ์์ ํ
์ to_zero ํ๋๋ฅผ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ณ 32-bit ์ ์์ ํ
์ zeroed๋ฅผ ์ถ๋ ฅํฉ๋๋ค. ๋ํ, op๋ ํ์ ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ ฅ ํ
์๊ฐ ์
๋ ฅ ํ
์์ ๊ฐ์ ํ์์ด ๋๋๋ก ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์
๋ ฅ์ด ํ์[10, 20]์ ํ
์์ธ ๊ฒฝ์ฐ, ์ด ํ์ ํจ์๋ ์ถ๋ ฅ ํ์๋ [10, 20]๋ก ์ง์ ํฉ๋๋ค.
์ฐธ๊ณ : op ์ด๋ฆ์ CamelCase์ฌ์ผ ํ๋ฉฐ ๋ฐ์ด๋๋ฆฌ์ ๋ฑ๋ก๋ ๋ค๋ฅธ ๋ชจ๋ op ์ค์์ ๊ณ ์ ํด์ผ ํฉ๋๋ค.
op์ ์ปค๋ ๊ตฌํํ๊ธฐ
์ธํฐํ์ด์ค๋ฅผ ์ ์ํ ํ, ํ๋ ์ด์์ op ๊ตฌํ์ ์ ๊ณตํฉ๋๋ค. ์ด๋ค ์ปค๋ ์ค ํ๋๋ฅผ ์์ฑํ๋ ค๋ฉด, OpKernel์ ํ์ฅํ์ฌ Compute ๋ฉ์๋๋ฅผ ๋์ฒดํ๋ ํด๋์ค๋ฅผ ์์ฑํฉ๋๋ค. Compute ๋ฉ์๋๋ ์ ํ OpKernelContext*์ context ์ธ์๋ฅผ ํ๋ ์ ๊ณตํ๋ฉฐ, ์ด ์ธ์์์ ์
๋ ฅ ๋ฐ ์ถ๋ ฅ ํ
์์ ๊ฐ์ ์ ์ฉํ ํญ๋ชฉ์ ์ก์ธ์คํ ์ ์์ต๋๋ค.
์์์ ๋ง๋ ํ์ผ์ ์ปค๋์ ์ถ๊ฐํฉ๋๋ค. ์ปค๋์ ๋ค์๊ณผ ๊ฐ์ ์ ์์ต๋๋ค.
์ปค๋์ ๊ตฌํํ ํ์๋ TensorFlow ์์คํ ์ ์ปค๋์ ๋ฑ๋กํฉ๋๋ค. ๋ฑ๋ก ์ ์ด ์ปค๋์ด ์คํ๋ ๋ค๋ฅธ ์ ์ฝ ์กฐ๊ฑด์ ์ง์ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, CPU์ฉ ์ปค๋ ํ๋์ GPU์ฉ ์ปค๋ ํ๋๊ฐ ์์ ์ ์์ต๋๋ค.
ZeroOut op์ฉ ์ปค๋์ ๊ตฌํํ๋ ค๋ฉด, zero_out.cc์ ๋ค์์ ์ถ๊ฐํฉ๋๋ค.
์ค์: OpKernel ์ธ์คํด์ค์ ๋์์ ์ก์ธ์คํ ์ ์์ต๋๋ค.
Compute๋ฉ์๋๋ ์ค๋ ๋๋ก๋ถํฐ ์์ ํด์ผ ํฉ๋๋ค. ๋ฎคํ ์ค๋ฅผ ์ฌ์ฉํ์ฌ ํด๋์ค ๋ฉค๋ฒ์ ๋ํ ์ก์ธ์ค๋ฅผ ๋ณดํธํ์ธ์. ๋๋ ๋ ๋์ ๋ฐฉ๋ฒ์ผ๋ก, ํด๋์ค ๋ฉค๋ฒ๋ฅผ ํตํด ์ํ๋ฅผ ๊ณต์ ํ์ง ๋ง์ธ์! op ์ํ๋ฅผ ์ถ์ ํ๊ธฐ ์ํดResourceMgr๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๋ค์ค ์ค๋ ๋ CPU ์ปค๋
๋ค์ค ์ค๋ ๋ CPU ์ปค๋์ ์์ฑํ๊ธฐ ์ํด work_sharder.h์ Shard ํจ์๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด ํจ์๋ intra-op ์ค๋ ๋ฉ์ ์ฌ์ฉ๋๋๋ก ๊ตฌ์ฑ๋ ์ค๋ ๋ ๊ฐ์ ๊ณ์ฐ ํจ์๋ฅผ ๋ถํ ํฉ๋๋ค(config.proto์ intra_op_parallelism_threads ์ฐธ์กฐ).
GPU ์ปค๋
GPU ์ปค๋์ OpKernel ๋ฐ CUDA ์ปค๋๊ณผ ์์ ์ฝ๋์ ๋ ๋ถ๋ถ์ผ๋ก ๊ตฌํ๋ฉ๋๋ค.
์ ๋ ฅ ๊ฒ์ฌ ๋ฐ ์ถ๋ ฅ ํ ๋น๊ณผ ๊ฐ์ด CPU์ GPU ์ปค๋ ๊ฐ์ OpKernel ๊ตฌํ์ด ๊ณตํต์ ์ผ๋ก ์ฌ์ฉ๋๋ ๊ฒฝ์ฐ๊ฐ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ, ์ ์ ๊ตฌํ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Device ํ ํ๋ฆฟ ํ์์ OpKernel๊ณผ ํ ์์ ๊ธฐ๋ณธ ์ ํ์ ์ ์ํฉ๋๋ค.
์ถ๋ ฅ์ ์ค์ ๊ณ์ฐ์ ์ํํ๊ธฐ ์ํด Compute ํจ์์์ ํ ํ๋ฆฟ ํ์์ functor ๊ตฌ์กฐ์ฒด๋ฅผ ํธ์ถํฉ๋๋ค.
CPUDevice์ ๋ํ ํด๋น functor์ ์ ๋ฌธํ๋ ๊ฐ์ ํ์ผ์ ์ ์๋์ด ์์ง๋ง, GPUDevice์ ๋ํ ์ ๋ฌธํ๋ CUDA ์ปดํ์ผ๋ฌ๋ก ์ปดํ์ผ๋๋ฏ๋ก .cu.cc ํ์ผ์ ์ ์๋์ด ์์ต๋๋ค.
๋ค์์ ๊ตฌํ ์์ ๋๋ค.
op ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋น๋ํ๊ธฐ
์์คํ ์ปดํ์ผ๋ฌ๋ฅผ ์ฌ์ฉํ์ฌ op ์ปดํ์ผํ๊ธฐ(TensorFlow ๋ฐ์ด๋๋ฆฌ ์ค์น)
์์คํ
์์ ์ฌ์ฉ ๊ฐ๋ฅํ g++ ๋๋ clang๊ณผ ๊ฐ์ C++ ์ปดํ์ผ๋ฌ๋ก zero_out.cc๋ฅผ ์ปดํ์ผํ ์ ์์ต๋๋ค. ์ด์ง PIP ํจํค์ง๋ ์์คํ
์ ํน์ ์์น์ op๋ฅผ ์ปดํ์ผํ๋ ๋ฐ ํ์ํ ํค๋ ํ์ผ๊ณผ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํฉ๋๋ค. ํ์ง๋ง, TensorFlow Python ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ํค๋ ๋๋ ํ ๋ฆฌ๋ฅผ ๊ฐ์ ธ์ค๋ get_include ํจ์๋ฅผ ์ ๊ณตํ๋ฉฐ, get_lib ๋๋ ํ ๋ฆฌ์๋ ๋งํฌํ ๊ณต์ ๊ฐ์ฒด๊ฐ ์์ต๋๋ค. Ubuntu ๋จธ์ ์์ ์ด๋ค ํจ์์ ์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
g++๋ฅผ ์ค์นํ๋ค๊ณ ๊ฐ์ ํ๋ฉด, ๋ค์์ op๋ฅผ ๋์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ์ปดํ์ผํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ ๋ช
๋ น ์ํ์ค์
๋๋ค.
macOS์์๋ .so ํ์ผ์ ๋น๋ํ ๋ ์ถ๊ฐ ํ๋๊ทธ "-undefined dynamic_lookup"์ด ํ์ํฉ๋๋ค.
gcc๋ฒ์ >=5์ ๋ํ ์ฐธ๊ณ ์ฌํญ: gcc๋ ๋ฒ์ 5๋ถํฐ ์๋ก์ด C++ ABI๋ฅผ ์ฌ์ฉํฉ๋๋ค. TensorFlow 2.8๊ณผ ์ด์ ๋ฒ์ ์ ๊ธฐ์กด ABI๋ฅผ ์ฌ์ฉํ๋gcc4๋ก ๋น๋๋์ต๋๋ค. ์ด๋ฌํ ๋ฒ์ ์ TensorFlow๋ฅผ ์ฌ์ฉ ์ค์ด๊ณgcc>=5๋ก ์ฐ์ฐ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ปดํ์ผํ๋ ค๋ ๊ฒฝ์ฐ, ๋ช ๋ น์ค์-D_GLIBCXX_USE_CXX11_ABI=0์ ์ถ๊ฐํ์ฌ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๊ธฐ์กด ABI์ ํธํ๋๋๋ก ํฉ๋๋ค. TensorFlow 2.9+ ํจํค์ง๋ ๊ธฐ๋ณธ์ ์ผ๋ก ์ต์ ABI์ ํธํ๋ฉ๋๋ค.
bazel(TensorFlow ์์ค ์ค์น)์ ์ฌ์ฉํ์ฌ op ์ปดํ์ผํ๊ธฐ
TensorFlow ์์ค๊ฐ ์ค์น๋์ด ์์ผ๋ฉด, TensorFlow์ ๋น๋ ์์คํ
์ ์ฌ์ฉํ์ฌ op๋ฅผ ์ปดํ์ผํ ์ ์์ต๋๋ค. tensorflow/core/user_ops ๋๋ ํ ๋ฆฌ์ ๋ค์ Bazel ๋น๋ ๊ท์น์ ๊ฐ์ง BUILD ํ์ผ์ ์ ์ฅํฉ๋๋ค.
๋ค์ ๋ช
๋ น์ ์คํํ์ฌ zero_out.so๋ฅผ ๋น๋ํฉ๋๋ค.
CUDA ์ปค๋์ ์ฌ์ฉํ์ฌ Example ์ฐ์ฐ์ ์ปดํ์ผํ๋ ค๋ฉด tf_custom_op_library์ gpu_srcs ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ๋ค์ Bazel ๋น๋ ๊ท์น์ด ์๋ BUILD ํ์ผ์ tensorflow/core/user_ops ๋๋ ํฐ๋ฆฌ(์: "example_gpu") ๋ด์ ์ ํด๋์ ๋ฐฐ์นํฉ๋๋ค.
๋ค์ ๋ช
๋ น์ ์คํํ์ฌ kernel_example.so๋ฅผ ๋น๋ํฉ๋๋ค.
์ฐธ๊ณ : ์์์ ์ค๋ช
ํ ๋๋ก gcc>=5๋ก ์ปดํ์ผํ๋ ๊ฒฝ์ฐ, Bazel ๋ช
๋ น์ค ์ธ์์ --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"์ ์ถ๊ฐํฉ๋๋ค.
์ฐธ๊ณ : ํ์ค
cc_library๊ท์น์ ์ฌ์ฉํ์ฌ ๊ณต์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ(.soํ์ผ)๋ฅผ ๋ง๋ค ์ ์์ง๋ง,tf_custom_op_library๋งคํฌ๋ก๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ด ๋งคํฌ๋ก๋ ํ์ ์ข ์์ฑ์ ์ถ๊ฐํ๊ณ ๊ณต์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ TensorFlow์ ํ๋ฌ๊ทธ์ธ ๋ก๋ฉ ๋ฉ์ปค๋์ฆ๊ณผ ํธํ๋๋์ง ํ์ธํฉ๋๋ค.
Python์์ op ์ฌ์ฉํ๊ธฐ
TensorFlow Python API๋ tf.load_op_library ํจ์๋ฅผ ์ ๊ณตํ์ฌ ๋์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ก๋ํ๊ณ TensorFlow ํ๋ ์์ํฌ์ op๋ฅผ ๋ฑ๋กํฉ๋๋ค. load_op_library๋ op ๋ฐ ์ปค๋์ ๋ํ Python ๋ํผ๊ฐ ํฌํจ๋ Python ๋ชจ๋์ ๋ฐํํฉ๋๋ค. ๋ฐ๋ผ์, ์ผ๋จ op๋ฅผ ๋น๋ํ๋ฉด ๋ค์์ ์ํํ์ฌ Python์์ ์คํํ ์ ์์ต๋๋ค.
์์ฑ๋ ํจ์์๋ snake_case ์ด๋ฆ์ด ์ง์ ๋ฉ๋๋ค(PEP8 ์ค์). ๋ฐ๋ผ์, C++ ํ์ผ์์ op์ ์ด๋ฆ์ด ZeroOut์ธ ๊ฒฝ์ฐ, Python ํจ์์ ์ด๋ฆ์ zero_out์
๋๋ค.
Python ๋ชจ๋์์ op๋ฅผ ์ ๊ท ํจ์๋ก import ๊ฐ๋ฅํ๊ฒ ํ๋ ค๋ฉด, ๋ค์๊ณผ ๊ฐ์ด Python ์์ค ํ์ผ์ load_op_library ํธ์ถ์ ํฌํจํ๋ ๊ฒ์ด ์ ์ฉํ ์ ์์ต๋๋ค.
op๊ฐ ์๋ํ๋์ง ํ์ธํ๊ธฐ
op๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ๊ตฌํํ๋์ง ํ์ธํ๋ ์ข์ ๋ฐฉ๋ฒ์ ํ
์คํธ๋ฅผ ์์ฑํ๋ ๊ฒ์
๋๋ค. ๋ค์ ๋ด์ฉ์ผ๋ก zero_out_op_test.py ํ์ผ์ ์์ฑํฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ ํ ์คํธ๋ฅผ ์คํํฉ๋๋ค(tensorflow๊ฐ ์ค์น๋์๋ค๊ณ ๊ฐ์ ).
op์ ๊ณ ๊ธ ํน์ฑ ๋น๋ํ๊ธฐ
๊ธฐ๋ณธ (๊ทธ๋ฆฌ๊ณ , ๋ค์ ์ ํ์ ์ธ) op ๋ฐ ๊ตฌํ์ ๋น๋ํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด์์ผ๋ฏ๋ก ์ผ๋ฐ์ ์ผ๋ก op์ ๋น๋ํ๋ ๋ฐ ํ์ํ ์กฐ๊ธ ๋ ๋ณต์กํ ํญ๋ชฉ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ์ฌ๊ธฐ์๋ ๋ค์์ด ํฌํจ๋ฉ๋๋ค.
์กฐ๊ฑด๋ถ ๊ฒ์ฌ ๋ฐ ํ์ธ
์์ ์์ ์์๋ op๊ฐ ๋ชจ๋ ํ์์ ํ ์์ ์ ์ฉ๋์๋ค๊ณ ๊ฐ์ ํ์ต๋๋ค. ๋ฒกํฐ์๋ง ์ ์ฉ๋ ๊ฒฝ์ฐ๋ ์ด๋ป๊ฒ ํด์ผ ํ ๊น์? ์์ OpKernel ๊ตฌํ์ ๊ฒ์ฌ๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
์
๋ ฅ์ด ๋ฒกํฐ์์ ์ธ์ฆํ๋ ๋ด์ฉ์ด๋ฉฐ, ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ, InvalidArgument ์ํ๋ฅผ ์ค์ ํ์ฌ ๋ฐํํฉ๋๋ค. OP_REQUIRES ๋งคํฌ๋ก๋ ์ธ ๊ฐ์ง ์ธ์๋ฅผ ์ฌ์ฉํฉ๋๋ค.
context๋SetStatus()๋ฉ์๋์ ๋ํOpKernelContext๋๋OpKernelConstructionํฌ์ธํฐ(tensorflow/core/framework/op_kernel.h์ฐธ์กฐ)์ผ ์ ์์ต๋๋ค.์กฐ๊ฑด. ์๋ฅผ ๋ค์ด,
tensorflow/core/framework/tensor_shape.h์ ํ ์์ ํ์์ ํ์ธํ๋ ํจ์๊ฐ ์์ต๋๋ค.Status๊ฐ์ฒด๋ก ํ์๋๋ ์ค๋ฅ ์์ฒด๋tensorflow/core/platform/status.h๋ฅผ ์ฐธ์กฐํ์ธ์.Status์๋ ์ ํ(์ข ์ขInvalidArgument์ด์ง๋ง, ์ ํ์ ๋ชฉ๋ก ์ฐธ์กฐ)๊ณผ ๋ฉ์์ง๊ฐ ์์ต๋๋ค. ์ค๋ฅ ์์ฑ ํจ์๋tensorflow/core/lib/core/errors.h์์ ์ฐพ์ ์ ์์ต๋๋ค.
์ผ๋ถ ํจ์์์ ๋ฐํ๋ Status ๊ฐ์ฒด๊ฐ ์ค๋ฅ์ธ์ง ํ
์คํธํ๋ ค๋ ๊ฒฝ์ฐ, OP_REQUIRES_OK๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด ๋ ๋งคํฌ๋ก๋ ๋ชจ๋ ์ค๋ฅ ์ ํจ์๋ก๋ถํฐ ๋ฐํํฉ๋๋ค.
op ๋ฑ๋ก
Attrs
Ops๋ attr์ ๊ฐ์ง ์ ์์ผ๋ฉฐ, op๊ฐ ๊ทธ๋ํ์ ์ถ๊ฐ๋ ๋ ๊ฐ์ด ์ค์ ๋ฉ๋๋ค. ์ด๋ค ๊ฐ์ op๋ฅผ ๊ตฌ์ฑํ๋ ๋ฐ ์ฌ์ฉ๋๋ฉฐ ์ปค๋ ๊ตฌํ ๋ด์์, ๊ทธ๋ฆฌ๊ณ op ๋ฑ๋ก์์ ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ ์ ํ์ผ๋ก ํด๋น ๊ฐ์ ์ก์ธ์คํ ์ ์์ต๋๋ค. ์ ๋ ฅ์ด ๋ ์ ์ฐํ๊ธฐ ๋๋ฌธ์ ๊ฐ๋ฅํ๋ฉด attr ๋์ ์ ๋ ฅ์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. attrs๋ ์์์ด๊ณ ๊ทธ๋ํ ์์ฑ ์ ์ ์ํด์ผ ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๋ฐ๋ฉด์, ์ ๋ ฅ์ ๊ฐ์ด ๋์ ์ผ ์ ์๋ ํ ์์ ๋๋ค. ์ฆ, ์ ๋ ฅ์ ๋จ๊ณ๋ง๋ค ๋ณํ ์ ์๊ณ ํผ๋๋ฅผ ์ฌ์ฉํ์ฌ ์ค์ ํ ์ ์์ต๋๋ค. Attrs์ ์๋ช (์ ๋ ฅ ๋๋ ์ถ๋ ฅ์ ์ ๋๋ ์ ํ)์ ์ํฅ์ ๋ฏธ์น๊ฑฐ๋ ๋จ๊ณ๋ณ๋ก ๋ณ๊ฒฝํ ์ ์๋ ๊ตฌ์ฑ๊ณผ ๊ฐ์ด ์ ๋ ฅ์ผ๋ก ๊ตฌ์ฑํ ์ ์๋ ์ฐ์ฐ์ ์ฌ์ฉ๋ฉ๋๋ค.
op๋ฅผ ๋ฑ๋กํ ๋ Attr ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ op์ ์ด๋ฆ๊ณผ ์ ํ์ ์ง์ ํจ์ผ๋ก์จ attr๋ฅผ ์ ์ํฉ๋๋ค. ๋ค์ ํ์์ ์ฌ์์ด ํ์ํฉ๋๋ค.
<name>์ ๋ฌธ์๋ก ์์ํ๊ณ ์์ซ์์ ๋ฐ์ค๋ก ๊ตฌ์ฑ๋ ์ ์์ผ๋ฉฐ, <attr-type-expr>์ ์๋ ์ค๋ช
๋ ํ์์ ์ ํ ํํ์์
๋๋ค.
์๋ฅผ ๋ค์ด, ZeroOut op๊ฐ 0๋ฒ์งธ ์์๋ง์ด ์๋ ์ฌ์ฉ์ ์ง์ ์ธ๋ฑ์ค๋ฅผ ์ ์งํ๋๋ก ํ๋ ค๋ฉด op๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ๋ฑ๋กํ ์ ์์ต๋๋ค.
(์์ฑ ์ ํ์ ์งํฉ์ ์
๋ ฅ ๋ฐ ์ถ๋ ฅ์ ์ฌ์ฉ๋๋ tf.DType๊ณผ๋ ๋ค๋ฆ
๋๋ค.)
์ปค๋์ context ๋งค๊ฐ๋ณ์๋ฅผ ํตํด ์์ฑ์์์ ์ด attr์ ์ก์ธ์คํ ์ ์์ต๋๋ค.
๊ทธ๋ฐ ๋ค์ Compute ๋ฉ์๋์์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
Attr ์ ํ
๋ค์ ์ ํ์ด attr์์ ์ง์๋ฉ๋๋ค.
string: ๋ฐ์ดํธ ์ํ์ค(UTF8์ผ ํ์๋ ์์)int: ๋ถํธ ์๋ ์ ์float: ๋ถ๋ ์์์ ์ซ์bool: ์ฐธ ๋๋ ๊ฑฐ์งtype:DataType์ (๋น์ฐธ์กฐ) ๊ฐ ์ค ํ๋shape:TensorShapeProtolist(<type>):<type>์ ๋ชฉ๋ก,<type>์ ์์ ์ ํ ์ค ํ๋์ ๋๋ค.list(list(<type>))๋ ์ ํจํ์ง ์์ต๋๋ค.
๋ช
ํํ ๋ชฉ๋ก์ op_def_builder.cc:FinalizeAttr์ ์ฐธ์กฐํ์ธ์.
๊ธฐ๋ณธ๊ฐ ๋ฐ ์ ์ฝ ์กฐ๊ฑด
Attrs๋ ๊ธฐ๋ณธ๊ฐ์ ๊ฐ์ง ์ ์์ผ๋ฉฐ, attrs์ ์ผ๋ถ ์ ํ์๋ ์ ์ฝ ์กฐ๊ฑด์ด ์์ ์ ์์ต๋๋ค. ์ ์ฝ ์กฐ๊ฑด์ด ์๋ attr์ ์ ์ํ๋ ค๋ฉด, ๋ค์ <attr-type-expr>์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
{'<string1>', '<string2>'}: ๊ฐ์ <string1> ๋๋ <string2> ๊ฐ์ ๊ฐ์ง ๋ฌธ์์ด์ด์ด์ผ ํฉ๋๋ค. ์ด ๊ตฌ๋ฌธ์ ์ฌ์ฉํ๋ฉด ์ ํ string์ ์ด๋ฆ์ด ํฌํจ๋ฉ๋๋ค. ์ด๊ฑฐํ์ ์๋ฎฌ๋ ์ดํธํฉ๋๋ค.
{<type1>, <type2>}: ๊ฐ์ ์ ํ type์ด๋ฉฐ, <type1> ๋๋ <type2> ์ค ํ๋์ฌ์ผ ํฉ๋๋ค. <type1> ๋ฐ <type2>๋ tf.DType์ ์ง์ํฉ๋๋ค. attr์ ์ ํ์ด type์์ ์ง์ ํ์ง ์์์ต๋๋ค. {...}์ ์ ํ์ ๋ชฉ๋ก์ด ์์ ๋ ์์๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์ด ๊ฒฝ์ฐ attr t์ ์ ํ์ int32, float ๋๋ bool์ด์ด์ผ ํฉ๋๋ค.
๋ค์์ ์ผ๋ฐ์ ์ธ ์ ํ ์ ์ฝ ์กฐ๊ฑด์ ๋ํ ๋ฐ๋ก ๊ฐ๊ธฐ์ ๋๋ค.
numbertype: ์ ํtype์ ์ซ์(๋ฌธ์์ด๋ ๋ถ์ธ๋ ์๋) ์ ํ์ผ๋ก ์ ํ๋ฉ๋๋ค.realnumbertype: ๋ณต์กํ ์ ํ์ด ์๋numbertype๊ณผ ์ ์ฌํฉ๋๋ค.quantizedtype:numbertype๊ณผ ์ ์ฌํ์ง๋ง, ์์ํ๋ ์ซ์ ์ ํ๊ณผ ๊ฐ์ต๋๋ค.
์ด๋ค ์ ์ฝ ์กฐ๊ฑด์์ ํ์ฉ๋๋ ์ ํ์ ํน์ ๋ชฉ๋ก์tensorflow/core/framework/types.h์์ ํจ์(์: NumberTypes())๋ก ์ ์๋ฉ๋๋ค. ์ด ์์ ์์ attr t๋ ์ซ์ ์ ํ ์ค ํ๋์ฌ์ผ ํฉ๋๋ค.
๋ค์ op์ ๊ฒฝ์ฐ:
๋ชฉ๋ก์ ๋ค๋ฅธ ๋ชฉ๋ก ๋ฐ ๋จ์ผ ์ ํ๊ณผ ๊ฒฐํฉ๋ ์ ์์ต๋๋ค. ๋ค์ op์์๋ attr t๊ฐ ์ซ์ ์ ํ์ด๊ฑฐ๋ ๋ถ์ธ ์ ํ์ผ ์ ์์ต๋๋ค.
๋ค์ op์ ๊ฒฝ์ฐ:
int >= <n>: ๊ฐ์ <n>๋ณด๋ค ํฌ๊ฑฐ๋ ๊ฐ์ ์ ์์ฌ์ผ ํฉ๋๋ค. <n>๋ ์์ฐ์์
๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์ op ๋ฑ๋ก์์ attr a์ ๊ฐ์ 2 ์ด์์ด์ด์ผ ํจ์ ์ง์ ํฉ๋๋ค.
list(<type>) >= <n>: ๊ธธ์ด๊ฐ <n> ์ด์์ธ ์ ํ <type>์ ๋ชฉ๋ก์
๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์ op ๋ฑ๋ก์์ attr a์ ์ ํ (int32 ๋๋ float)์ ๋ชฉ๋ก์ด๋ฉฐ, ์ ์ด๋ 3๊ฐ ์ด์ ์์ด์ผ ํจ์ ์ง์ ํฉ๋๋ค.
attr์ ๊ธฐ๋ณธ๊ฐ์ ์ค์ ํ๋ ค๋ฉด(์์ฑ๋ ์ฝ๋์์ ์ ํ ์ฌํญ), ๋ค์๊ณผ ๊ฐ์ด ๋์ = <default>๋ฅผ ์ถ๊ฐํฉ๋๋ค.
๋ํ, ์ ์ฝ ์กฐ๊ฑด๊ณผ ๊ธฐ๋ณธ๊ฐ์ ๋ชจ๋ ์ง์ ํ ์ ์์ต๋๋ค.
์ง์๋๋ ๊ธฐ๋ณธ๊ฐ ๊ตฌ๋ฌธ์ ์ต์ข GraphDef ์ ์์ ํ๋กํ ํ์ ํํ์ ์ฌ์ฉ๋๋ ๊ตฌ๋ฌธ์ ๋๋ค.
๋ค์์ ๋ชจ๋ ์ ํ์ ๊ธฐ๋ณธ๊ฐ์ ์ง์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์์ ์ ๋๋ค.
ํนํ, ์ ํ type์ ๊ฐ์ tf.DType์ ์ฌ์ฉํฉ๋๋ค.
๋คํ์ฑ
์ ํ ๋คํ์ฑ
๋ค๋ฅธ ์ ํ์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ฑฐ๋ ๋ค๋ฅธ ์ถ๋ ฅ ์ ํ์ ์์ฑํ ์ ์๋ op์ ๊ฒฝ์ฐ, op ๋ฑ๋ก์์ ์
๋ ฅ ๋๋ ์ถ๋ ฅ ์ ํ์ attr์ ์ง์ ํ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก, ์ง์๋๋ ๊ฐ ์ ํ์ ๋ํด OpKernel์ ๋ฑ๋กํฉ๋๋ค.
์๋ฅผ ๋ค์ด, int32 ์ด์ธ์ float์ ๋ํด ZeroOut op๊ฐ ์๋ํ๊ฒ ํ๋ ค๋ฉด op ๋ฑ๋ก์ ๋ค์๊ณผ ๊ฐ์ ์ ์์ต๋๋ค.
op ๋ฑ๋ก์์ ์ด์ ์
๋ ฅ์ ์ ํ์ด float ๋๋ int32์ฌ์ผ ํจ์ ์ง์ ํฉ๋๋ค. ์
๋ ฅ๊ณผ ์ถ๋ ฅ ์ ํ์ด ๋ชจ๋ T์ด๋ฏ๋ก ์ถ๋ ฅ์ ์ ํ๋ ๊ฐ์ต๋๋ค.
๋ช ๋ช
์ ๋ ฅ, ์ถ๋ ฅ ๋ฐ attrs์๋ ์ผ๋ฐ์ ์ผ๋ก snake_case ์ด๋ฆ์ด ์ง์ ๋์ด์ผ ํฉ๋๋ค. ํ ๊ฐ์ง ์์ธ๋ ์ ๋ ฅ์ ์ ํ ๋๋ ์ถ๋ ฅ์ ์ ํ์ผ๋ก ์ฌ์ฉ๋๋ attrs์ ๋๋ค. ์ด๋ฌํ attrs๋ op๊ฐ ๊ทธ๋ํ์ ์ถ๊ฐ๋ ๋ ์ ์ถ๋ ์ ์์ผ๋ฏ๋ก op์ ํจ์์๋ ๋ํ๋์ง ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์ด ZeroOut์ ์ต์ข ์ ์๋ ๋ค์๊ณผ ๊ฐ์ Python ํจ์๋ฅผ ์์ฑํฉ๋๋ค.
to_zero์ int32 ํ
์๊ฐ ์ ๋ฌ๋๋ฉด, T๋ ์๋์ผ๋ก int32๋ก ์ค์ ๋ฉ๋๋ค(์ค์ ๋ก DT_INT32). ์ ์ถ๋ attrs์๋ ๋๋ฌธ์ ๋๋ CamelCase ์ด๋ฆ์ด ์ง์ ๋ฉ๋๋ค.
์ ์ถ๋ attrs๋ฅผ ์ถ๋ ฅ ์ ํ์ ๊ฒฐ์ ํ๋ ์ ํ attr์ด ์๋ op์ ๋น๊ตํฉ๋๋ค.
์ด ๊ฒฝ์ฐ, ์ฌ์ฉ์๋ ์์ฑ๋ Python์์์ ๊ฐ์ด ์ถ๋ ฅ ์ ํ์ ์ง์ ํด์ผ ํฉ๋๋ค.
์ ํ ๋คํ์ฑ ์์
์ด์ ๋ฒ์ ๊ณผ์ ํธํ์ฑ์ ์ ์งํ๋ ค๋ฉด, ๊ธฐ์กด op์ attr์ ์ถ๊ฐํ ๋ ๊ธฐ๋ณธ๊ฐ์ ์ง์ ํด์ผ ํฉ๋๋ค.
๋ ๋ง์ ์ ํ์ ์ถ๊ฐํ๊ณ ์ถ๋ค๊ณ ๊ฐ์ ํด ๋ด
์๋ค. ์: double
์์ ๊ฐ์ด ์ค๋ณต ์ฝ๋๋ก ๋ ๋ค๋ฅธ OpKernel์ ์์ฑํ๋ ๋์ , ์ข
์ข
C++ ํ
ํ๋ฆฟ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ค๋ฒ๋ก๋๋น ์ฌ์ ํ ํ๋์ ์ปค๋ ๋ฑ๋ก(REGISTER_KERNEL_BUILDER ํธ์ถ)์ด ์์ต๋๋ค.
์ค๋ฒ๋ก๋๊ฐ ๋ ๊ฐ ์ด์์ธ ๊ฒฝ์ฐ, ๋ฑ๋ก์ ๋งคํฌ๋ก์ ๋ฃ์ ์ ์์ต๋๋ค.
์ปค๋์ ๋ฑ๋กํ๋ ค๋ ์ ํ์ ๋ชฉ๋ก์ ๋ฐ๋ผ tensorflow/core/framework/register_types.h์์ ์ ๊ณต๋๋ ๋งคํฌ๋ก๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ ๋ ฅ ๋ฐ ์ถ๋ ฅ ๋ชฉ๋ก
๋ค์ํ ์ ํ์ ํ์ฉํ๊ฑฐ๋ ์์ฑํ ์ ์์ ๋ฟ๋ง ์๋๋ผ ops๋ ๋ค์ํ ๊ฐ์์ ํ ์๋ฅผ ์๋นํ๊ฑฐ๋ ์์ฑํ ์ ์์ต๋๋ค.
๋ค์ ์์ ์์, attr T๋ ์ ํ์ list๋ฅผ ๋ณด์ ํ๊ณ , ์๊ธฐ ์
๋ ฅ in๊ณผ ์ถ๋ ฅ out์ผ๋ก ์ฌ์ฉ๋ฉ๋๋ค. ์
๋ ฅ ๋ฐ ์ถ๋ ฅ์ ํด๋น ์ ํ์ ํ
์ ๋ชฉ๋ก์
๋๋ค(์ถ๋ ฅ์ ํ
์ ์์ ์ ํ์ ์
๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ ํ์ด ๋ชจ๋ T์ด๋ฏ๋ก ์
๋ ฅ๊ณผ ๊ฐ์ต๋๋ค).
๋ชฉ๋ก์์ ์ง์ ํ ์ ์๋ ์ ํ์ ์ ํ์ ๋ ์๋ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ, ์
๋ ฅ์ float ๋ฐ double ํ
์์ ๋ชฉ๋ก์
๋๋ค. op๋ ์๋ฅผ ๋ค์ด, ์
๋ ฅ ์ ํ (float, double, float)์ ํ์ฉํ๋ฉฐ, ์ด ๊ฒฝ์ฐ ์ถ๋ ฅ ์ ํ๋ (float, double, float)์
๋๋ค.
๋ชฉ๋ก์ ๋ชจ๋ ํ ์๊ฐ ๊ฐ์ ์ ํ์ด ๋๋๋ก ํ๋ ค๋ฉด, ๋ค์๊ณผ ๊ฐ์ด ํ ์ ์์ต๋๋ค.
int32 ํ
์์ ๋ชฉ๋ก์ ํ์ฉํ๊ณ int attr N์ ์ฌ์ฉํ์ฌ ๋ชฉ๋ก์ ๊ธธ์ด๋ฅผ ์ง์ ํฉ๋๋ค.
๋คํ ์ ํ์ผ๋ก ๋ง๋ค ์๋ ์์ต๋๋ค. ๋ค์ ์์ ์์, ์
๋ ฅ์ ์ ํ ("T")์ด ๊ฐ์ (ํ์ง๋ง ์ง์ ๋์ง๋ ์์) ํ
์(๊ธธ์ด "N")์ ๋ชฉ๋ก์ด๋ฉฐ, ์ถ๋ ฅ์ ์ผ์นํ๋ ์ ํ์ ๋จ์ผ ํ
์์
๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก, ํ
์ ๋ชฉ๋ก์ ์ต์ ๊ธธ์ด๋ 1์
๋๋ค. ํด๋น attr์ ๋ํ ">=" ์ ์ฝ ์กฐ๊ฑด์ ์ฌ์ฉํ์ฌ ํด๋น ๊ธฐ๋ณธ๊ฐ์ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค. ๋ค์ ์์ ์์ ์
๋ ฅ์ int32 ํ
์๊ฐ 2๊ฐ ์ด์์ธ ๋ชฉ๋ก์
๋๋ค.
๊ฐ์ ๊ตฌ๋ฌธ์ด "list(type)" attrs์์ ์๋ํฉ๋๋ค.
์ ๋ ฅ ๋ฐ ์ถ๋ ฅ
์์ ๋ด์ฉ์ ์์ฝํ๋ฉด, op ๋ฑ๋ก์๋ ์ฌ๋ฌ ๊ฐ์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด ์์ ์ ์์ต๋๋ค.
๊ฐ ์ ๋ ฅ ๋๋ ์ถ๋ ฅ ์ฌ์์ ํ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
<name>์ ๋ฌธ์๋ก ์์ํ๋ฉฐ ์์ซ์์ ๋ฐ์ค๋ก ๊ตฌ์ฑ๋ ์ ์์ต๋๋ค. <io-type-expr>์ ๋ค์ ์ ํ ํํ์ ์ค์ ํ๋์
๋๋ค.
<type>,<type>์ ์ง์๋๋ ์ ๋ ฅ ์ ํ์ ๋๋ค(์:float,int32,string). ํน์ ์ ํ์ ๋จ์ผ ํ ์๋ฅผ ์ง์ ํฉ๋๋ค.tf.DType์ ์ฐธ์กฐํ์ธ์.<attr-type>,<attr-type>์ ์ ํ์ดtype๋๋list(type)(๊ฐ๋ฅํ ์ ํ ์ ํ์ด ์๋)์ธ Attr์ ์ด๋ฆ์ ๋๋ค. ์ด ๊ตฌ๋ฌธ์ ๋คํ ops๋ฅผ ํ์ฉํฉ๋๋ค.์ ํ์ด
list(type)์ธ attr์ ์ฐธ์กฐํ๋ฉด ํ ์ ์ํ์ค๋ฅผ ๋ฐ์๋ค์ผ ์ ์์ต๋๋ค.์ถ๋ ฅ
out์์ ํ ์์ ์ ๋ฐ ์ ํ์ ์ ๋ ฅin์์์ ๊ฐ์๋ฐ, ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ ํ์ด ๋ชจ๋T์ด๊ธฐ ๋๋ฌธ์ ๋๋ค.์ ํ์ด ๊ฐ์ ํ ์ ์ํ์ค์ ๊ฒฝ์ฐ:
<number>*<type>์์<number>๋ ์ ํ์ดint์ธ Attr์ ์ด๋ฆ์ ๋๋ค.<type>์tf.DType์ด๊ฑฐ๋ ์ ํ์ดtype์ธ attr์ ์ด๋ฆ์ ๋๋ค. ์ฒซ ๋ฒ์งธ์ ์๋ก, ์ด op๋int32ํ ์์ ๋ชฉ๋ก์ ํ์ฉํฉ๋๋ค.์ด op๋ ๋ชจ๋ ์ ํ์ ํ ์ ๋ชฉ๋ก์ ํ์ฉํ๋๋ฐ, ์ด๋ ํ ์์ ์ ํ์ ๋ชจ๋ ๊ฐ์ต๋๋ค.
ํ ์์ ๋ํ ์ฐธ์กฐ:
Ref(<type>),<type>์ ์ด์ ์ ํ ์ค์ ํ๋์ ๋๋ค.
์
๋ ฅ์ ์ ํ์ ์ฌ์ฉ๋ ๋ชจ๋ attr๊ฐ ์ ์ถ๋ฉ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก, ์ ์ถ๋ attr์ (T ๋๋ N๊ณผ ๊ฐ์) ๋๋ฌธ์ ์ด๋ฆ์ ์ฌ์ฉํฉ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด, ์
๋ ฅ, ์ถ๋ ฅ ๋ฐ attr์ ์ด๋ฆ์ ํจ์ ๋งค๊ฐ๋ณ์(์: num_outputs)์ ๊ฐ์ต๋๋ค. ์์ธํ ๋ด์ฉ์ ๋ช
๋ช
์ ๊ดํ ์ด์ ์น์
์ ์ฐธ์กฐํ์ธ์.
์์ธํ ๋ด์ฉ์ tensorflow/core/framework/op_def_builder.h๋ฅผ ์ฐธ์กฐํ์ธ์.
์ด์ ๋ฒ์ ๊ณผ์ ํธํ์ฑ
๋ฉ์ง ์ฌ์ฉ์ ์ง์ op๋ฅผ ์์ฑํ๊ณ ๋ค๋ฅธ ์ฌ์ฉ์์ ๊ณต์ ํ๋ค๊ณ ๊ฐ์ ํ์ฌ ์ฐ์ฐ์ ์ฌ์ฉํ๋ ํ๋ณตํ ๊ณ ๊ฐ์ด ์์ต๋๋ค. ๊ทธ๋ฌ๋ op๋ฅผ ๋ณ๊ฒฝํ๊ณ ์ถ์ต๋๋ค.
์ผ๋ฐ์ ์ผ๋ก, ๊ธฐ์กด์ ํ์ธ๋(checked-in) ์ฌ์์ ๋ํ ๋ณ๊ฒฝ ์ฌํญ์ ์ด์ ๋ฒ์ ๊ณผ ํธํ๋์ด์ผ ํฉ๋๋ค. op์ ์ฌ์์ ๋ณ๊ฒฝํ ํ ์ด์ ์ฌ์์์ ์์ฑ๋ ์ด์ ์ ์ง๋ ฌํ๋ GraphDef ํ๋กํ ์ฝ ๋ฒํผ๊ฐ ์์๋๋ฉด ์ ๋ฉ๋๋ค. GraphDef ํธํ์ฑ์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ฌ๊ธฐ์ ์ค๋ช
๋์ด ์์ต๋๋ค.
์ด์ ๋ฒ์ ๊ณผ์ ํธํ์ฑ์ ์ ์งํ๋ ๋ช ๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
์ฐ์ฐ์ ์ถ๊ฐ๋ ์ attrs์๋ ๊ธฐ๋ณธ๊ฐ์ด ์ ์๋์ด ์์ด์ผ ํ๋ฉฐ, ํด๋น ๊ธฐ๋ณธ๊ฐ์ ๊ฐ์ง op๋ ์๋ ๋์์ด ์์ด์ผ ํฉ๋๋ค. ๋คํ์ด ์๋ ์ฐ์ฐ์์ ๋คํ ์ฐ์ฐ์ผ๋ก ๋ณ๊ฒฝํ๋ ค๋ฉด, ๊ธฐ๋ณธ์ ์ผ๋ก ์๋ ์๋ช ์ ์ ์งํ๊ธฐ ์ํด ์ ์ ํ attr์ ๊ธฐ๋ณธ๊ฐ์ ์ง์ ํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์ฐ์ฐ์ด ๋ค์๊ณผ ๊ฐ์ ๊ฒฝ์ฐ,
๋ค์์ ์ฌ์ฉํ์ฌ ์ด์ ๋ฒ์ ๊ณผ ํธํ๋๋ ๋คํ ์ฐ์ฐ์ผ๋ก ๋ง๋ค ์ ์์ต๋๋ค.
attr์ ๋ํ ์ ์ฝ ์กฐ๊ฑด์ ๋ ์ ํ์ ์ผ๋ก ์์ ํ๊ฒ ๋ง๋ค ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด,
{int32, int64}์์{int32, int64, float}๋๋type๋ก ๋ณ๊ฒฝํ ์ ์์ต๋๋ค. ๋๋{"apple", "orange"}์์{"apple", "banana", "orange"}๋๋string๋ก ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.๋ชฉ๋ก ์ ํ์ ๊ธฐ๋ณธ๊ฐ์ด ์ด์ ์๋ช ๊ณผ ์ผ์นํ๋ ํ ๋จ์ผ ์ ๋ ฅ/์ถ๋ ฅ์ ๋ชฉ๋ก ์ ๋ ฅ/์ถ๋ ฅ์ผ๋ก ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.
๊ธฐ๋ณธ๊ฐ์ด ๋น์ด ์์ผ๋ฉด ์ ๋ชฉ๋ก ์ ๋ ฅ/์ถ๋ ฅ์ ์ถ๊ฐํ ์ ์์ต๋๋ค.
op ์ด๋ฆ ์์ ํ๋ก์ ํธ ๊ณ ์ ์ ์ด๋ฆ์ ๋ถ์ฌ์ ์์ฑํ๋ ๋ชจ๋ ์๋ก์ด ops์ ๋ค์์คํ์ด์ค๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ์ดํ ๋ฒ์ ์ TensorFlow์ ํฌํจ๋ ์ ์๋ ops์ ํด๋น op๊ฐ ์ถฉ๋ํ์ง ์์ต๋๋ค.
๋ฏธ๋ฆฌ ๊ณํํ์ธ์! op์ ํฅํ ์ฉ๋๋ฅผ ์์ํฉ๋๋ค. ์๋ช ์ ์ผ๋ถ ๋ณ๊ฒฝํ๋ ๊ฒ์ ํธํ ๊ฐ๋ฅํ ๋ฐฉ์์ผ๋ก ์ํํ ์ ์์ต๋๋ค(์: ๊ฐ์ ์ ํ์ ๋ชฉ๋ก์ ๋ค์ํ ์ ํ์ ๋ชฉ๋ก์ผ๋ก ๋ง๋ค๊ธฐ).
์์ ํ๊ฑฐ๋ ์์ ํ์ง ์์ ๋ณ๊ฒฝ ์ฌํญ์ ์ ์ฒด ๋ชฉ๋ก์ tensorflow/core/framework/op_compatibility_test.cc ์์ ์ฐพ์ ์ ์์ต๋๋ค. ์ด์ ๋ฒ์ ๊ณผ ํธํ๋๋๋ก ์ฐ์ฐ์ ๋ณ๊ฒฝํ ์ ์๋ ๊ฒฝ์ฐ, ์ ์๋ฏธ ์ฒด๊ณ๋ฅผ ์ฌ์ฉํ์ฌ ์ ์ด๋ฆ์ผ๋ก ์ ์ฐ์ฐ์ ๋ง๋ญ๋๋ค.
๋ํ, ์ด๋ฌํ ๋ณ๊ฒฝ ์ฌํญ์ GraphDef ํธํ์ฑ์ ์ ์งํ ์ ์์ง๋ง, ์์ฑ๋ Python ์ฝ๋๋ ์ด์ ํธ์ถ์์ ํธํ๋์ง ์๋ ๋ฐฉ์์ผ๋ก ๋ณ๊ฒฝ๋ ์ ์์ต๋๋ค. Python API๋ ์๋ก์ด ์ ํ์ ์ธ์๋ฅผ ๋์ ์ถ๊ฐํ๋ ๊ฒ์ ์ ์ธํ๊ณ ์ด์ ์๋ช
์ ์ ์งํจ์ผ๋ก์จ ์์ผ๋ก ์์ฑํ Python ๋ํผ๋ฅผ ์ ์คํ๊ฒ ๋ณ๊ฒฝํ์ฌ ํธํ์ฑ์ ์ ์งํ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก, ํธํ๋์ง ์๋ ๋ณ๊ฒฝ ์ฌํญ์ TensorFlow์ ์ฃผ์ ๋ฒ์ ์ด ๋ณ๊ฒฝ๋ ๋๋ง ์ํ๋ ์ ์์ผ๋ฉฐ GraphDef๋ฒ์ ์๋ฏธ ์ฒด๊ณ๋ฅผ ์ค์ํด์ผ ํฉ๋๋ค.
GPU ์ง์
์๋ก ๋ค๋ฅธ ์ ํ์ ์ปค๋์ ๋ฑ๋กํ๋ ๊ฒ์ฒ๋ผ ์๋ก ๋ค๋ฅธ OpKernel์ ๊ตฌํํ๊ณ CPU ๋ฐ GPU์ฉ ์ปค๋์ ๊ฐ๊ฐ ๋ฑ๋กํ ์ ์์ต๋๋ค. tensorflow/core/kernels/์ GPU๋ฅผ ์ง์ํ๋ ์ปค๋์ ๋ช ๊ฐ์ง ์๊ฐ ์์ต๋๋ค. ์ผ๋ถ ์ปค๋์๋ .cc ํ์ผ์ CPU ๋ฒ์ , _gpu.cu.cc๋ก ๋๋๋ ํ์ผ์ GPU ๋ฒ์ ๋ฐ .h ํ์ผ์์ ๊ณตํต์ผ๋ก ๊ณต์ ๋๋ ์ฝ๋๊ฐ ์์ต๋๋ค.
์๋ฅผ ๋ค์ด, tf.pad๋ tensorflow/core/kernels/pad_op.cc์ GPU ์ปค๋์ ์ ์ธํ ๋ชจ๋ ๊ฒ์ด ์์ต๋๋ค. GPU ์ปค๋์ tensorflow/core/kernels/pad_op_gpu.cu.cc์ ์์ผ๋ฉฐ, ๊ณต์ ์ฝ๋๋ tensorflow/core/kernels/pad_op.h์ ์ ์๋ ํ
ํ๋ฆฟ ํ์์ ํด๋์ค์
๋๋ค. ์ฝ๋๋ฅผ ์ด ๋ฐฉ์์ผ๋ก ๊ตฌ์ฑํ๋ ๋ฐ๋ ๋ ๊ฐ์ง ์ด์ ๊ฐ ์์ต๋๋ค. CPU์ GPU ๊ตฌํ ๊ฐ์ ๊ณตํต ์ฝ๋๋ฅผ ๊ณต์ ํ ์ ์์ผ๋ฉฐ GPU ๊ตฌํ์ ๋ณ๋์ ํ์ผ์ ๋ฃ์ด GPU ์ปดํ์ผ๋ฌ๋ก๋ง ์ปดํ์ผํ ์ ์์ต๋๋ค.
pad์ GPU ์ปค๋ ๋ฒ์ ์ ์ฌ์ฉํ๋๋ผ๋ CPU ๋ฉ๋ชจ๋ฆฌ์ ์ฌ์ ํ "paddings" ์
๋ ฅ์ด ํ์ํฉ๋๋ค. ์
๋ ฅ ๋๋ ์ถ๋ ฅ์ด CPU์์ ์ ์ง๋๋ค๋ ๊ฒ์ ํ์ํ๋ ค๋ฉด, ์ปค๋ ๋ฑ๋ก์ HostMemory() ํธ์ถ์ ์ถ๊ฐํฉ๋๋ค. ์๋ฅผ ๋ค๋ฉด, ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
GPU ๊ธฐ๊ธฐ์ฉ ์ปค๋ ์ปดํ์ผํ๊ธฐ
CUDA ์ปค๋์ ์ฌ์ฉํ์ฌ op๋ฅผ ๊ตฌํํ๋ ์๋ cuda_op_kernel.cu.cc๋ฅผ ์ฐธ์กฐํ์ธ์. tf_custom_op_library์ CUDA ์ปค๋(*.cu.cc ํ์ผ)์ ํฌํจํ๋ ์์ค ํ์ผ์ ๋ชฉ๋ก์ ์ง์ ํ ์์๋ gpu_srcs ์ธ์๋ฅผ ํ์ฉํฉ๋๋ค. TensorFlow์ ๋ฐ์ด๋๋ฆฌ ์ค์น์์ ์ฌ์ฉํ๋ ค๋ฉด, CUDA ์ปค๋์ NVIDIA์ nvcc ์ปดํ์ผ๋ฌ๋ก ์ปดํ์ผํด์ผ ํฉ๋๋ค. ๋ค์์ cuda_op_kernel.cu.cc ๋ฐ cuda_op_kernel.cc๋ฅผ ๋์ ์ผ๋ก ๋ก๋ ๊ฐ๋ฅํ ๋จ์ผ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ์ปดํ์ผํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ ๋ช
๋ น ์ํ์ค์
๋๋ค.
์์์ ์์ฑ๋ cuda_op_kernel.so๋ tf.load_op_library ํจ์๋ฅผ ์ฌ์ฉํ์ฌ Python์์ ํ์์ ๊ฐ์ด ๋ก๋ํ ์ ์์ต๋๋ค.
CUDA ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ /usr/local/lib64์ ์ค์น๋์ง ์์ ๊ฒฝ์ฐ, ์์ ๋ ๋ฒ์งธ(g++) ๋ช
๋ น์์ ๊ฒฝ๋ก๋ฅผ ๋ช
์์ ์ผ๋ก ์ง์ ํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, CUDA๊ฐ /usr/local/cuda-8.0์ ์ค์น๋์ด ์๋ ๊ฒฝ์ฐ, -L /usr/local/cuda-8.0/lib64/๋ฅผ ์ถ๊ฐํฉ๋๋ค.
์ฐธ๊ณ : ์ผ๋ถ Linux ์ค์ ์์๋ nvcc ์ปดํ์ผ ๋จ๊ณ์ ๋ํ ์ถ๊ฐ ์ต์
์ด ํ์ํฉ๋๋ค. -D_MWAITXINTRIN_H_INCLUDED๋ฅผ nvcc ๋ช
๋ น์ค์ ์ถ๊ฐํ์ฌ mwaitxintrin.h์ ์ค๋ฅ๋ฅผ ๋ฐฉ์งํฉ๋๋ค.
Python์์ ๊ทธ๋๋์ธํธ ๊ตฌํํ๊ธฐ
ops์ ๊ทธ๋ํ์์ TensorFlow๋ ์๋ ๋ฏธ๋ถ(์ญ์ ํ)์ ์ฌ์ฉํ์ฌ ๊ธฐ์กด op์ ๋ํ ๊ทธ๋๋์ธํธ๋ฅผ ๋ํ๋ด๋ ์ ops๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์๋ก์ด ops์ ๋ํด ์๋ ๋ฏธ๋ถ์ ์ํํ๋ ค๋ฉด, ops์ ์ถ๋ ฅ์ ๋ํ ๊ทธ๋๋์ธํธ๊ฐ ์ง์ ๋ ops์ ์ ๋ ฅ์ ๋ํ ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ๋ ๊ทธ๋๋์ธํธ ํจ์๋ฅผ ๋ฑ๋กํด์ผ ํฉ๋๋ค.
์ํ์ ์ผ๋ก, op๊ฐ (y = f(x))๋ฅผ ๊ณ์ฐํ๋ ๊ฒฝ์ฐ, ๋ฑ๋ก๋ ๊ทธ๋๋์ธํธ op๋ (y)์ ๋ํ ์์ค (L)์ ๊ทธ๋๋์ธํธ (\partial L/ \partial y)๋ฅผ ์ฐ์ ๊ท์น์ ํตํด (x)์ ๋ํ ๊ทธ๋๋์ธํธ (\partial L/ \ partial x)๋ก ๋ณํํฉ๋๋ค.
ZeroOut์ ๊ฒฝ์ฐ, ์
๋ ฅ์ ํ ํญ๋ชฉ๋ง ์ถ๋ ฅ์ ์ํฅ์ ๋ฏธ์น๋ฏ๋ก ์
๋ ฅ์ ๋ํ ๊ทธ๋๋์ธํธ๋ "์-ํซ" ํฌ์ ํ
์์
๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ํํ๋ฉ๋๋ค.
tf.RegisterGradient๋ก ๊ทธ๋๋์ธํธ ํจ์๋ฅผ ๋ฑ๋กํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ธ๋ถ ์ฌํญ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
์ถ๋ ฅ์ด ํ๋์ธ op์ ๊ฒฝ์ฐ, ๊ทธ๋๋์ธํธ ํจ์๋
tf.Operation,op๋ฐtf.Tensorgrad๋ฅผ ์ฌ์ฉํ๊ณ ํ ์op.inputs[i],op.outputs[i]๋ฐgrad์์ ์ ops๋ฅผ ๋น๋ํฉ๋๋ค. ๋ชจ๋ attrs์ ๋ํ ์ ๋ณด๋tf.Operation.get_attr์ ํตํด ์ฐพ์ ์ ์์ต๋๋ค.์ถ๋ ฅ์ด ์ฌ๋ฌ ๊ฐ์ธ op์ธ ๊ฒฝ์ฐ, ๊ทธ๋๋์ธํธ ํจ์๋
op๋ฐgrads๋ฅผ ์ฌ์ฉํ๊ณ , ์ด๋grads๋ ๊ฐ ์ถ๋ ฅ์ ๋ํ ๊ทธ๋๋์ธํธ์ ๋ชฉ๋ก์ ๋๋ค. ๊ทธ๋๋์ธํธ ํจ์์ ๊ฒฐ๊ณผ๋ ๊ฐ ์ ๋ ฅ์ ๋ํ ๊ทธ๋๋์ธํธ๋ฅผ ๋ํ๋ด๋Tensor๊ฐ์ฒด์ ๋ชฉ๋ก์ด์ด์ผ ํฉ๋๋ค.์ธ๋ฑ์ค๋ก ์ฌ์ฉ๋๋ ์ ์ ์ ๋ ฅ๊ณผ ๊ฐ์ด ์ผ๋ถ ์ ๋ ฅ์ ๋ํด ์ ์ ์๋ ๊ทธ๋๋์ธํธ๊ฐ ์๋ ๊ฒฝ์ฐ, ๋ฐํ๋๋ ํด๋น ๊ทธ๋๋์ธํธ๋
None์ด์ด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ถ๋ ์์์ ํ ์x๋ฐ ์ ์ ์ธ๋ฑ์คi๋ฅผ ์ฌ์ฉํ๋ op์ ๊ฒฝ์ฐ, ๊ทธ๋๋์ธํธ ํจ์๋[x_grad, None]๋ฅผ ๋ฐํํฉ๋๋ค.op์ ์๋ฏธ ์๋ ๊ทธ๋๋์ธํธ๊ฐ ์๋ ๊ฒฝ์ฐ, ๊ทธ๋๋์ธํธ๋ฅผ ๋ฑ๋กํ ํ์๊ฐ ์์ผ๋ฉฐ, op์ ๊ทธ๋๋์ธํธ๊ฐ ํ์ํ์ง ์์ ํ ๋ฌธ์ ์์ต๋๋ค. ๊ฒฝ์ฐ์ ๋ฐ๋ผ op์ ์ ์ ์๋ ๊ทธ๋๋์ธํธ๊ฐ ์์ด๋ ๊ทธ๋๋์ธํธ ๊ณ์ฐ์ ๊ด์ฌํ ์ ์์ต๋๋ค. ์ด๋
ops.NotDifferentiable์ ์ฌ์ฉํ์ฌ ์๋์ผ๋ก 0์ ๋ค๋ก ์ ํํ ์ ์์ต๋๋ค.
๊ทธ๋๋์ธํธ ํจ์๊ฐ ํธ์ถ๋ ๋ ํ ์ ๋ฐ์ดํฐ ์์ฒด๊ฐ ์๋๋ผ ops์ ๋ฐ์ดํฐ ํ๋ฆ ๊ทธ๋ํ๋ง ์ฌ์ฉํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์, ๋ชจ๋ ๊ณ์ฐ์ ๊ทธ๋ํ ์คํ ์๊ฐ์ ์คํ๋๋๋ก ๋ค๋ฅธ tensorflow ops๋ฅผ ์ฌ์ฉํ์ฌ ์ํํด์ผ ํฉ๋๋ค.
op ์ ํ์ ๋ํ ์ฌ์ฉ์ ์ ์ ๊ทธ๋๋์ธํธ๋ฅผ ๋ฑ๋กํ ๋ ์ ํ ํํธ๋ฅผ ์ถ๊ฐํ๋ฉด ๋ฐ์ดํฐ ์ ํจ์ฑ ๊ฒ์ฌ๋ฅผ ํตํด ์ฝ๋์ ๊ฐ๋
์ฑ, ๋๋ฒ๊น
๊ฐ๋ฅ์ฑ, ์ ์ง ๊ด๋ฆฌ ์ฉ์ด์ฑ ๋ฐ ๊ฒฌ๊ณ ์ฑ์ ๋์ผ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ํจ์์์ op๋ฅผ ๋งค๊ฐ๋ณ์๋ก ์ฌ์ฉํ ๋ ๊ทธ๋๋์ธํธ ํจ์๊ฐ tf.Operation์ ๋งค๊ฐ๋ณ์ ์ ํ์ผ๋ก ์ฌ์ฉํ๋๋ก ์ง์ ํฉ๋๋ค.
C++์ ํ์ ํจ์
TensorFlow API์ "๋ํ ์ ์ถ"๋ผ๋ ํน์ฑ์ด ์์ด ๊ทธ๋ํ๋ฅผ ์คํํ์ง ์๊ณ ๋ ํ
์ ๋ํ์ ๋ํ ์ ๋ณด๋ฅผ ์ ๊ณตํฉ๋๋ค. ๋ํ ์ ์ถ๋ C++ REGISTER_OP ์ ์ธ์์ ๊ฐ op ์ ํ์ ๋ฑ๋ก๋ "๋ํ ํจ์"์ ์ํด ์ง์๋๋ฉฐ ๋ ๊ฐ์ง ์ญํ ์ ์ํํฉ๋๋ค. ์
๋ ฅ์ ๋ํ์ด ๊ทธ๋ํ ์์ฑ ์ค์ ํธํ๋๋์ง ํ์ธํ๊ณ ์ถ๋ ฅ์ ๋ํ์ ์ง์ ํฉ๋๋ค.
ํ์ ํจ์๋ shape_inference::InferenceContext ํด๋์ค์ ๋ํ ์ฐ์ฐ์ผ๋ก ์ ์๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ZeroOut์ ํ์ ํจ์์์
c->set_output (0, c->input (0));์ ์ฒซ ๋ฒ์งธ ์ถ๋ ฅ์ ํ์์ด ์ฒซ ๋ฒ์งธ ์
๋ ฅ์ ํ์์ผ๋ก ์ค์ ๋์ด์ผ ํจ์ ์ ์ธํฉ๋๋ค. ์์ ์์ ์์์ ๊ฐ์ด ์ธ๋ฑ์ค์ ์ํด ์ถ๋ ฅ์ด ์ ํ๋ ๊ฒฝ์ฐ, set_output์ ๋ ๋ฒ์งธ ๋งค๊ฐ๋ณ์๋ ShapeHandle ๊ฐ์ฒด์ฌ์ผ ํฉ๋๋ค. ๊ธฐ๋ณธ ์์ฑ์๋ก ๋น ShapeHandle ๊ฐ์ฒด๋ฅผ ๋ง๋ค ์ ์์ต๋๋ค. ์ธ๋ฑ์ค idx๋ฅผ ๊ฐ์ง ์
๋ ฅ์ ๋ํ ShapeHandle ๊ฐ์ฒด๋ c->input(idx)๋ก ๊ตฌํ ์ ์์ต๋๋ค.
shape_inference::UnchangedShape์ ๊ฐ์ด ๋ง์ ops์ ์ ์ฉ๋๋ ๊ณตํต ํ์ ํจ์๊ฐ ์ฌ๋ฌ ๊ฐ ์์ผ๋ฉฐ, common_shape_fns.h์์ ์ฐพ์ ์ ์๊ณ , ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉ๋ฉ๋๋ค.
ํ์ ํจ์๋ ์
๋ ฅ์ ํ์์ ์ ํํ ์๋ ์์ต๋๋ค. ๋ฒกํฐ ํ์ ์ ์ฝ ์กฐ๊ฑด์ด์๋ ZeroOut์ ๊ฒฝ์ฐ, ํ์ ํจ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
WithRank ํธ์ถ์ ์
๋ ฅ ํ์ c->input(0) ์ด ์ ํํ 1์ฐจ์์ ํ์์ธ์ง ํ์ธํฉ๋๋ค(๋๋ ์
๋ ฅ ํ์์ ์ ์ ์๋ ๊ฒฝ์ฐ, ์ถ๋ ฅ ํ์์ ์ ์ ์๋ 1์ฐจ์์ ๋ฒกํฐ๊ฐ ๋จ).
์
๋ ฅ์ด ์ฌ๋ฌ ๊ฐ์ธ ๋คํ op์ธ ๊ฒฝ์ฐ, InferenceContext์ ๋ฉค๋ฒ๋ฅผ ์ฌ์ฉํ์ฌ ๊ฒ์ฌํ ํ์์ ์๋ฅผ ๊ฒฐ์ ํ๊ณ Merge์ ๋ฉค๋ฒ๋ฅผ ์ฌ์ฉํ์ฌ ํ์์ด ๋ชจ๋ ํธํ๋๋์ง ํ์ธํฉ๋๋ค(๋๋ ๊ธธ์ด๋ฅผ ๋ํ๋ด๋ ์ก์ธ์ค ์์ฑ๊ณผ op์ ์์ฑ์ ๋ํ ์ก์ธ์ค๋ฅผ ์ ๊ณตํ๋ InferenceContext::GetAttr).
ํ์ ์ ์ถ๋ ์ ํ์ ์ธ ํน์ฑ์ด๋ฉฐ ํ
์์ ํ์์ ๋์ ์ผ๋ก ๋ณํ ์ ์์ผ๋ฏ๋ก ํ์ ํจ์๋ ๋ชจ๋ ์
๋ ฅ์ ๋ถ์์ ํ ํ์ ์ ๋ณด์ ๋ํด ๊ฒฌ๊ณ ํด์ผ ํฉ๋๋ค. InferenceContext์ Merge ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋ฉด ๋ ๊ฐ์ง ํ์ ์ค ํ๋ ๋๋ ๋ ๋ค์ ์์ ํ ์ ๋ณด๊ฐ ์๋ ๊ฒฝ์ฐ์๋ ํธ์ถ์๊ฐ ๋ ํ์์ด ๊ฐ์์ ํ์ธํ ์ ์์ต๋๋ค. ํ์ ํจ์๋ ๋ชจ๋ ํต์ฌ TensorFlow ops์ ๋ํด ์ ์๋๋ฉฐ ๋ค์ํ ์ฌ์ฉ ์๋ฅผ ์ ๊ณตํฉ๋๋ค.
InferenceContext ํด๋์ค์๋ ํ์ ํจ์ ์กฐ์์ ์ ์ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ ๋ง์ ํจ์๊ฐ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ํน์ ์ฐจ์์ InferenceContext::Dim ๋ฐInferenceContext::WithValue๋ฅผ ์ฌ์ฉํ๋ ๋งค์ฐ ํน์ ํ ๊ฐ์ด ์๋์ง ํ์ธํ๊ณ , ์ถ๋ ฅ ์ฐจ์์ด InferenceContext::Add ๋ฐ InferenceContext::Multiply๋ฅผ ์ฌ์ฉํ๋ ๋ ์
๋ ฅ ์ฐจ์์ ํฉ/๊ณฑ์์ ์ง์ ํ ์ ์์ต๋๋ค. ์ง์ ํ ์ ์๋ ๋ค์ํ ํ์ ์กฐ์์ ๋ํด์๋ InferenceContext ํด๋์ค๋ฅผ ์ฐธ์กฐํ์ธ์. ๋ค์ ์์ ๋ ์ฒซ ๋ฒ์งธ ์ถ๋ ฅ์ ํ์์ (n, 3)์ผ๋ก ์ค์ ํฉ๋๋ค. ์ฌ๊ธฐ์์ ์ฒซ ๋ฒ์งธ ์
๋ ฅ์ ํ์์ (n, ...)์
๋๋ค.
๋ณต์กํ ํ์ ํจ์๊ฐ ์๋ ๊ฒฝ์ฐ, ๋ค์ํ ์
๋ ฅ ํ์์ ์กฐํฉ์ด ์์๋๋ ์ถ๋ ฅํ์์ ์กฐํฉ์ ์์ฑํ๋์ง ํ์ธํ๋ ํ
์คํธ๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ผ๋ถ ํต์ฌ ops ํ
์คํธ์์ ํ
์คํธ๋ฅผ ์์ฑํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์์ ๋ฅผ ๋ณผ ์ ์์ต๋๋ค. (INFER_OK ๋ฐ INFER_ERROR์ ๊ตฌ๋ฌธ์ด ์ฝ๊ฐ ๊น๋ค๋กญ์ง๋ง, ํ
์คํธ์์ ์
๋ ฅ ๋ฐ ์ถ๋ ฅ ํ์ ์ฌ์์ ๊ฐ๊ฒฐํ๊ฒ ํํํ์ธ์. ์ง๊ธ์ ํด๋น ํ
์คํธ์ ์ฃผ๋ณ ์ฃผ์์ ์ฐธ์กฐํ์ฌ ํ์ ๋ฌธ์์ด ์ฌ์์ ์ดํดํ์ธ์.)
์ฌ์ฉ์ ์ ์ op์ฉ pip ํจํค์ง ๋น๋ํ๊ธฐ
op์ ๋ํ pip ํจํค์ง๋ฅผ ๋น๋ํ๋ ค๋ฉด, tensorflow/custom-op ์์ ๋ฅผ ์ฐธ์กฐํ์ธ์. ์ด ๊ฐ์ด๋๋ ์์ค์์ TensorFlow๋ฅผ ๋น๋ํ๋ ๋์ TensorFlow pip ํจํค์ง์์ ์ฌ์ฉ์ ์ ์ op๋ฅผ ๋น๋ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.