recap_gen/train.py

import torch
import torch.nn.functional as F
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

from data_loader.dir_loader import load_subtitles_from_dir
from keywords import KEYWORDS
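
# The pretrained sentiment model below is used only as a weak-labeling
# heuristic: emotionally charged lines are treated as likely "important".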
sentiment_model_name = "cointegrated/rubert-tiny-sentiment-balanced"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name
)


def is_important(text):
    """Weak label: 1 if a subtitle line looks important, 0 otherwise."""
    low = text.lower()
    # Lines of fewer than two words carry no usable signal; skip the model call.
    if len(low.split()) < 2:
        return 0
    try:
        inputs = sentiment_tokenizer(
            low, return_tensors="pt", truncation=True, padding=True
        )
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        probs = F.softmax(logits, dim=1)
        label_id = torch.argmax(probs, dim=1).item()
        labels = ["NEGATIVE", "NEUTRAL", "POSITIVE"]
        label = labels[label_id]
    except Exception:
        print("Sentiment computation failed")
        return 0
    # Emotionally charged lines or keyword hits count as important.
    if label in ("NEGATIVE", "POSITIVE"):
        return 1
    if any(kw in low for kw in KEYWORDS):
        return 1
    return 0


def main():
    df = load_subtitles_from_dir("./data/subtitles/train")
    # Generate a weak label for every subtitle line.
    df["label"] = df["text"].astype(str).apply(is_important)
    # train_test_split accepts a DataFrame; both halves keep all columns.
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
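    # If the weak labels turn out heavily imbalanced, passing
    # stratify=df["label"] to train_test_split keeps the label ratio
    # comparable across both splits.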

    model_name = "bert-base-multilingual-cased"
    tokenizer = BertTokenizer.from_pretrained(model_name)

    def tokenize_function(example):
        return tokenizer(
            example["text"],
            padding="max_length",
            truncation=True,
            max_length=64,
        )

    train_dataset = Dataset.from_pandas(train_df[["text", "label"]])
    test_dataset = Dataset.from_pandas(test_df[["text", "label"]])
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)
    # Trainer expects the target column to be named "labels".
    train_dataset = train_dataset.rename_column("label", "labels")
    test_dataset = test_dataset.rename_column("label", "labels")
    train_dataset.set_format(
        "torch", columns=["input_ids", "attention_mask", "labels"]
    )
    test_dataset.set_format(
        "torch", columns=["input_ids", "attention_mask", "labels"]
    )

    model = BertForSequenceClassification.from_pretrained(
        model_name, num_labels=2
    )
    training_args = TrainingArguments(
        output_dir="./bert_trained",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        save_strategy="no",
        logging_dir="./logs",
        logging_steps=10,
    )
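    # save_strategy="no" skips intermediate checkpoints; the final model is
    # saved explicitly once training finishes.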
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()
    trainer.save_model("./bert_trained")
    tokenizer.save_pretrained("./bert_trained")
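    # The saved model can later be reloaded for inference with, e.g.:
    #   model = BertForSequenceClassification.from_pretrained("./bert_trained")
    #   tokenizer = BertTokenizer.from_pretrained("./bert_trained")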


if __name__ == "__main__":
    main()