RAG(Retrieval-Augmented Generation)는 정보 검색(Retrieval)과 텍스트 생성(Generation)을 결합한 AI 모델 아키텍처입니다. 외부 지식 소스에서 정보를 검색한 후, 이 정보를 바탕으로 답변을 생성하는 방식으로 작동합니다.
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import torch

# Minimal RAG demo: retrieve supporting passages for a question, then generate
# an answer conditioned on them, using one pretrained RAG checkpoint for the
# tokenizer, the retriever, and the generator.
MODEL_NAME = "facebook/rag-sequence-nq"

# Load the tokenizer, the retriever, and the generation model.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): index_name="exact" without use_dummy_dataset=True presumably
# makes the retriever fetch the full passage index, which is a very large
# download. For a quick smoke test consider use_dummy_dataset=True (at the cost
# of much poorer retrieval) -- confirm which behavior is intended.
retriever = RagRetriever.from_pretrained(MODEL_NAME, index_name="exact")
# The retriever is attached to the model so generate() can fetch documents.
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

# The question to answer ("What is RAG?").
# NOTE(review): this checkpoint appears to be trained on English Natural
# Questions; a Korean query will likely retrieve and answer poorly -- verify,
# and consider asking in English.
input_text = "RAG가 무엇인가요?"

# Tokenize the question into PyTorch tensors for the model.
inputs = tokenizer(input_text, return_tensors="pt")

# Generate the answer. Forward the attention mask explicitly alongside the
# token ids rather than relying on the model to infer it.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)

# Decode the first generated sequence to text, dropping special tokens.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))