[NLP-tensorflow] Training an AI to create poetry (NLP Zero to Hero - Part 6)

kk_eezz 2022. 6. 8. 16:40

https://www.youtube.com/watch?v=ZMudJXhsUpY&list=PLQY2H8rRoyvzDbLUZkbudP-MFQZwNmU4S&index=6 

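import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Bidirectional
from tensorflow.keras.optimizers import Adam

tokenizer = Tokenizer()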

data = "In the town of Athy one Jeremy Lanigan \n Battered away ... ..."
corpus = data.lower().split("\n")

tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1
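
As a rough illustration of what `fit_on_texts` produces, the sketch below builds a frequency-ranked word index in plain Python. It is a simplified stand-in, not the Keras implementation: the helper name `build_word_index` and the toy corpus are hypothetical, and punctuation filtering is left out. Indices start at 1, so index 0 stays free for padding, which is why `total_words` adds 1.

```python
from collections import Counter

def build_word_index(lines):
    """Hypothetical helper: map each word to a 1-based rank by frequency."""
    counts = Counter()
    for line in lines:
        counts.update(line.lower().split())
    # most_common() sorts by count, descending; ties keep first-seen order
    ranked = counts.most_common()
    return {word: rank for rank, (word, _) in enumerate(ranked, start=1)}

toy_corpus = ["in the town of athy", "the town was merry"]  # assumed toy data
word_index = build_word_index(toy_corpus)
total_words = len(word_index) + 1  # +1 because index 0 is reserved for padding
print(word_index)
print(total_words)
```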

Unlike the earlier classification task, the goal here is to generate text, so the data is not split into train and test sets.

input_sequences = []  # empty list of input sequences
for line in corpus:
    # for each line in the corpus, create the list of tokens
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        # take the first i+1 tokens as one training sequence
        n_gram_sequence = token_list[:i+1]
        input_sequences.append(n_gram_sequence)

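To predict the next word, each line is expanded into a series of sequences, each one word longer than the last.

This lets the model learn, for example, that the tokens 4, 2 are followed by 66.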
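
The expansion described above can be sketched in plain Python. The helper name `ngram_prefixes` and the token ids are assumptions made for illustration; in the real pipeline the ids come from the tokenizer.

```python
def ngram_prefixes(token_list):
    """Return every prefix of length >= 2 of a tokenized line."""
    return [token_list[:i + 1] for i in range(1, len(token_list))]

# Assumed toy token ids for one line; real ids come from the tokenizer.
tokens = [4, 2, 66, 8]
sequences = ngram_prefixes(tokens)
for seq in sequences:
    # The last token is the word to predict; everything before it is the input.
    print(seq[:-1], "->", seq[-1])
```

Each prefix becomes one training example, so a line with n tokens yields n - 1 examples.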

 

max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
# padding the sequences

xs = input_sequences[:,:-1]
labels = input_sequences[:,-1]
# splitting X and label

ys = tf.keras.utils.to_categorical(labels, num_classes=total_words)
# one hot encoded
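
To make the padding, splitting, and one-hot steps concrete, here is a minimal plain-Python sketch on toy data. The helpers `pad_pre` and `one_hot` are hypothetical stand-ins for `pad_sequences(..., padding='pre')` and `to_categorical`, and the vocabulary size of 70 is an assumed value.

```python
def pad_pre(seq, maxlen):
    """Left-pad a sequence with zeros up to maxlen (like padding='pre')."""
    return [0] * (maxlen - len(seq)) + seq

def one_hot(index, num_classes):
    """Return a list with a 1 at `index` and 0 elsewhere."""
    return [1 if i == index else 0 for i in range(num_classes)]

sequences = [[4, 2], [4, 2, 66], [4, 2, 66, 8]]  # assumed toy n-gram sequences
max_len = max(len(s) for s in sequences)
padded = [pad_pre(s, max_len) for s in sequences]

xs = [row[:-1] for row in padded]     # all tokens except the last are the input
labels = [row[-1] for row in padded]  # the last token is the target
ys = [one_hot(label, 70) for label in labels]  # 70 is an assumed vocabulary size

print(xs)
print(labels)
```

Pre-padding keeps the real tokens at the right edge of every row, so the label is always the final column regardless of the original sequence length.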

model = Sequential()
# input length is one less than the padded length because the last token is the label
model.add(Embedding(total_words, 240, input_length=max_sequence_len-1))
model.add(Bidirectional(LSTM(150)))
model.add(Dense(total_words, activation='softmax'))
adam = Adam(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
history = model.fit(xs, ys, epochs=100, verbose=1)
seed_text = "I made a poetry machine"
next_words = 20

for _ in range(next_words):
    token_list = tokenizer.texts_to_sequences([seed_text])[0]
    token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
    # predict_classes was removed in TensorFlow 2.6; take the argmax of the softmax output instead
    predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)[0]
    output_word = ""
    # reverse lookup: find the word whose index matches the prediction
    for word, index in tokenizer.word_index.items():
        if index == predicted:
            output_word = word
            break
    seed_text += " " + output_word
print(seed_text)

Output: