Path: blob/main/transformers_doc/ko/pytorch/summarization.ipynb
8328 views
์์ฝ[[summarization]]
์์ฝ์ ๋ฌธ์๋ ๊ธฐ์ฌ์์ ์ค์ํ ์ ๋ณด๋ฅผ ๋ชจ๋ ํฌํจํ๋ ์งง๊ฒ ๋ง๋๋ ์ผ์ ๋๋ค. ๋ฒ์ญ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก, ์ํ์ค-ํฌ-์ํ์ค ๋ฌธ์ ๋ก ๊ตฌ์ฑํ ์ ์๋ ๋ํ์ ์ธ ์์ ์ค ํ๋์ ๋๋ค. ์์ฝ์๋ ์๋์ ๊ฐ์ด ์ ํ์ด ์์ต๋๋ค:
์ถ์ถ(Extractive) ์์ฝ: ๋ฌธ์์์ ๊ฐ์ฅ ๊ด๋ จ์ฑ ๋์ ์ ๋ณด๋ฅผ ์ถ์ถํฉ๋๋ค.
์์ฑ(Abstractive) ์์ฝ: ๊ฐ์ฅ ๊ด๋ จ์ฑ ๋์ ์ ๋ณด๋ฅผ ํฌ์ฐฉํด๋ด๋ ์๋ก์ด ํ ์คํธ๋ฅผ ์์ฑํฉ๋๋ค.
์ด ๊ฐ์ด๋์์ ์๊ฐํ ๋ด์ฉ์ ์๋์ ๊ฐ์ต๋๋ค:
์์ฑ ์์ฝ์ ์ํ BillSum ๋ฐ์ดํฐ์ ์ค ์บ๋ฆฌํฌ๋์ ์ฃผ ๋ฒ์ ํ์ ์งํฉ์ผ๋ก T5๋ฅผ ํ์ธํ๋ํฉ๋๋ค.
ํ์ธํ๋๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ถ๋ก ํฉ๋๋ค.
์ด ์์ ๊ณผ ํธํ๋๋ ๋ชจ๋ ์ํคํ ์ฒ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณด๋ ค๋ฉด ์์ ํ์ด์ง๋ฅผ ํ์ธํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๋ชจ๋ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ๋ฉด ๋ชจ๋ธ์ ์ ๋ก๋ํ๊ณ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ ์ ์์ต๋๋ค. ํ ํฐ์ ์ ๋ ฅํ์ฌ ๋ก๊ทธ์ธํ์ธ์.
BillSum ๋ฐ์ดํฐ์ ๊ฐ์ ธ์ค๊ธฐ[[load-billsum-dataset]]
๐ค Datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ BillSum ๋ฐ์ดํฐ์ ์ ์์ ๋ฒ์ ์ธ ์บ๋ฆฌํฌ๋์ ์ฃผ ๋ฒ์ ํ์ ์งํฉ์ ๊ฐ์ ธ์ค์ธ์:
train_test_split ๋ฉ์๋๋ก ๋ฐ์ดํฐ์
์ ํ์ต์ฉ์ ํ
์คํธ์ฉ์ผ๋ก ๋๋์ธ์:
๊ทธ๋ฐ ๋ค์ ์์๋ฅผ ํ๋ ์ดํด๋ณด์ธ์:
์ฌ๊ธฐ์ ๋ค์ ๋ ๊ฐ์ ํ๋๋ฅผ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค:
text: ๋ชจ๋ธ์ ์ ๋ ฅ์ด ๋ ๋ฒ์ ํ ์คํธ์ ๋๋ค.summary:text์ ๊ฐ๋ตํ ๋ฒ์ ์ผ๋ก ๋ชจ๋ธ์ ํ๊ฒ์ด ๋ฉ๋๋ค.
์ ์ฒ๋ฆฌ[[preprocess]]
๋ค์์ผ๋ก text์ summary๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํ T5 ํ ํฌ๋์ด์ ๋ฅผ ๊ฐ์ ธ์ต๋๋ค:
์์ฑํ๋ ค๋ ์ ์ฒ๋ฆฌ ํจ์๋ ์๋ ์กฐ๊ฑด์ ๋ง์กฑํด์ผ ํฉ๋๋ค:
์ ๋ ฅ ์์ ํ๋กฌํํธ๋ฅผ ๋ถ์ฌ T5๊ฐ ์์ฝ ์์ ์์ ์ธ์ํ ์ ์๋๋ก ํฉ๋๋ค. ์ฌ๋ฌ NLP ์์ ์ ์ํํ ์ ์๋ ์ผ๋ถ ๋ชจ๋ธ์ ํน์ ์์ ์ ๋ํ ํ๋กฌํํธ๊ฐ ํ์ํฉ๋๋ค.
๋ ์ด๋ธ์ ํ ํฐํํ ๋
text_target์ธ์๋ฅผ ์ฌ์ฉํฉ๋๋ค.max_length๋งค๊ฐ๋ณ์๋ก ์ค์ ๋ ์ต๋ ๊ธธ์ด๋ฅผ ๋์ง ์๋๋ก ๊ธด ์ํ์ค๋ฅผ ์๋ผ๋ ๋๋ค.
์ ์ฒด ๋ฐ์ดํฐ์
์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ์ ์ฉํ๋ ค๋ฉด ๐ค Datasets์ map ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ธ์. batched=True๋ก ์ค์ ํ์ฌ ๋ฐ์ดํฐ์
์ ์ฌ๋ฌ ์์๋ฅผ ํ ๋ฒ์ ์ฒ๋ฆฌํ๋ฉด map ํจ์์ ์๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค.
์ด์ DataCollatorForSeq2Seq๋ฅผ ์ฌ์ฉํ์ฌ ์์ ๋ฐฐ์น๋ฅผ ๋ง๋์ธ์. ์ ์ฒด ๋ฐ์ดํฐ์ ์ ์ต๋ ๊ธธ์ด๋ก ํจ๋ฉํ๋ ๊ฒ๋ณด๋ค ๋ฐฐ์น๋ง๋ค ๊ฐ์ฅ ๊ธด ๋ฌธ์ฅ ๊ธธ์ด์ ๋ง์ถฐ ๋์ ํจ๋ฉํ๋ ๊ฒ์ด ๋ ํจ์จ์ ์ ๋๋ค.
ํ๊ฐ[[evaluate]]
ํ์ต ์ค์ ํ๊ฐ ์งํ๋ฅผ ํฌํจํ๋ฉด ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ๊ฐํ๋ ๋ฐ ๋์์ด ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ๐ค Evaluate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉด ํ๊ฐ ๋ฐฉ๋ฒ์ ๋น ๋ฅด๊ฒ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ์ด ์์ ์์๋ ROUGE ํ๊ฐ ์งํ๋ฅผ ๊ฐ์ ธ์ต๋๋ค. (ํ๊ฐ ์งํ๋ฅผ ๋ถ๋ฌ์ค๊ณ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ๐ค Evaluate ๋๋ฌ๋ณด๊ธฐ๋ฅผ ์ฐธ์กฐํ์ธ์.)
๊ทธ๋ฐ ๋ค์ ์์ธก๊ฐ๊ณผ ๋ ์ด๋ธ์ compute์ ์ ๋ฌํ์ฌ ROUGE ์งํ๋ฅผ ๊ณ์ฐํ๋ ํจ์๋ฅผ ๋ง๋ญ๋๋ค:
์ด์ compute_metrics ํจ์๋ฅผ ์ฌ์ฉํ ์ค๋น๊ฐ ๋์์ผ๋ฉฐ, ํ์ต์ ์ค์ ํ ๋ ์ด ํจ์๋ก ๋๋์์ฌ ๊ฒ์
๋๋ค.
ํ์ต[[train]]
๋ชจ๋ธ์ Trainer๋ก ํ์ธํ๋ ํ๋ ๊ฒ์ด ์ต์ํ์ง ์๋ค๋ฉด, ์ฌ๊ธฐ์์ ๊ธฐ๋ณธ ํํ ๋ฆฌ์ผ์ ํ์ธํด๋ณด์ธ์!
์ด์ ๋ชจ๋ธ ํ์ต์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค! AutoModelForSeq2SeqLM๋ก T5๋ฅผ ๊ฐ์ ธ์ค์ธ์:
์ด์ ์ธ ๋จ๊ณ๋ง ๋จ์์ต๋๋ค:
Seq2SeqTrainingArguments์์ ํ์ต ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ํ์ธ์. ์ ์ผํ ํ์ ๋งค๊ฐ๋ณ์๋ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ์ง์ ํ๋
output_dir์ ๋๋ค.push_to_hub=True๋ฅผ ์ค์ ํ์ฌ ์ด ๋ชจ๋ธ์ Hub์ ํธ์ํ ์ ์์ต๋๋ค(๋ชจ๋ธ์ ์ ๋ก๋ํ๋ ค๋ฉด Hugging Face์ ๋ก๊ทธ์ธํด์ผ ํฉ๋๋ค.) Trainer๋ ๊ฐ ์ํญ์ด ๋๋ ๋๋ง๋ค ROUGE ์งํ๋ฅผ ํ๊ฐํ๊ณ ํ์ต ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค.๋ชจ๋ธ, ๋ฐ์ดํฐ์ , ํ ํฌ๋์ด์ , ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ ๋ฐ
compute_metricsํจ์์ ํจ๊ป ํ์ต ์ธ์๋ฅผ Seq2SeqTrainer์ ์ ๋ฌํ์ธ์.train()์ ํธ์ถํ์ฌ ๋ชจ๋ธ์ ํ์ธํ๋ํ์ธ์.
ํ์ต์ด ์๋ฃ๋๋ฉด, ๋๊ตฌ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์๋๋ก push_to_hub() ๋ฉ์๋๋ก Hub์ ๊ณต์ ํฉ๋๋ค:
์์ฝ์ ์ํด ๋ชจ๋ธ์ ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ ์์ธํ ์์ ๋ฅผ ๋ณด๋ ค๋ฉด PyTorch notebook ๋๋ TensorFlow notebook์ ์ฐธ๊ณ ํ์ธ์.
์ถ๋ก [[inference]]
์ข์์, ์ด์ ๋ชจ๋ธ์ ํ์ธํ๋ํ์ผ๋ ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค!
์์ฝํ ํ ์คํธ๋ฅผ ์์ฑํด๋ณด์ธ์. T5์ ๊ฒฝ์ฐ ์์ ์ ๋ฐ๋ผ ์ ๋ ฅ ์์ ์ ๋์ฌ๋ฅผ ๋ถ์ฌ์ผ ํฉ๋๋ค. ์์ฝ์ ๊ฒฝ์ฐ, ์๋์ ๊ฐ์ ์ ๋์ฌ๋ฅผ ์ ๋ ฅ ์์ ๋ถ์ฌ์ผ ํฉ๋๋ค:
์ถ๋ก ์ ์ํด ํ์ธํ๋ํ ๋ชจ๋ธ์ ์ํํด ๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ pipeline()์์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค. ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์์ฝ์ ์ํํ pipeline()์ ์ธ์คํด์คํํ๊ณ ํ
์คํธ๋ฅผ ์ ๋ฌํ์ธ์:
์ํ๋ค๋ฉด ์๋์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ์์
์ ์ํํ์ฌ pipeline()์ ๊ฒฐ๊ณผ์ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ต๋๋ค:
ํ
์คํธ๋ฅผ ํ ํฌ๋์ด์ฆํ๊ณ input_ids๋ฅผ PyTorch ํ
์๋ก ๋ฐํํฉ๋๋ค:
์์ฝ๋ฌธ์ ์์ฑํ๋ ค๋ฉด generate() ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ธ์. ํ ์คํธ ์์ฑ์ ๋ํ ๋ค์ํ ์ ๋ต๊ณผ ์์ฑ์ ์ ์ดํ๊ธฐ ์ํ ๋งค๊ฐ๋ณ์์ ๋ํ ์์ธํ ๋ด์ฉ์ ํ ์คํธ ์์ฑ API๋ฅผ ์ฐธ์กฐํ์ธ์.
์์ฑ๋ ํ ํฐ ID๋ฅผ ํ ์คํธ๋ก ๋์ฝ๋ฉํฉ๋๋ค: