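"""Weakly label subtitle lines as important/unimportant using a small Russian
sentiment model plus a keyword list, then fine-tune multilingual BERT on the
resulting binary labels."""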
import torch
import torch.nn.functional as F
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

from keywords import KEYWORDS
from data_loader.dir_loader import load_subtitles_from_dir

sentiment_model_name = "cointegrated/rubert-tiny-sentiment-balanced"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name)
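# The tiny RuBERT checkpoint above only produces weak labels for training
# data; the multilingual BERT fine-tuned below is the model that gets saved.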


def is_important(text):
    """Weakly label a subtitle line: 1 = important, 0 = not."""
    low = text.lower()
    # Single-word lines are never important; skip the model call for them.
    if len(low.split()) < 2:
        return 0
    try:
        inputs = sentiment_tokenizer(
            low, return_tensors="pt", truncation=True, padding=True
        )
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        probs = F.softmax(logits, dim=1)
        label_id = torch.argmax(probs, dim=1).item()
        # Assumed id order for this checkpoint (cf. model.config.id2label).
        labels = ["NEGATIVE", "NEUTRAL", "POSITIVE"]
        label = labels[label_id]
    except Exception:
        print("Sentiment computation failed")
        return 0
    # Strongly positive or negative lines count as important...
    if label in ("NEGATIVE", "POSITIVE"):
        return 1
    # ...and so do lines containing any of the hand-picked keywords.
    if any(kw in low for kw in KEYWORDS):
        return 1
    return 0
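

# Hypothetical speed-up (not used by main() below): batch the sentiment pass
# instead of one forward call per row. Assumes the same id order as above
# (0=NEGATIVE, 1=NEUTRAL, 2=POSITIVE).
# Usage sketch: df["label"] = label_batch(df["text"].tolist())
def label_batch(texts, batch_size=64):
    labels_out = []
    for i in range(0, len(texts), batch_size):
        chunk = [str(t).lower() for t in texts[i:i + batch_size]]
        inputs = sentiment_tokenizer(
            chunk, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        ids = logits.argmax(dim=1).tolist()
        for text, label_id in zip(chunk, ids):
            if len(text.split()) < 2:
                labels_out.append(0)
            elif label_id in (0, 2) or any(kw in text for kw in KEYWORDS):
                labels_out.append(1)
            else:
                labels_out.append(0)
    return labels_out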


def main():
    df = load_subtitles_from_dir("./data/subtitles/train")
    df["label"] = df["text"].astype(str).apply(is_important)

    # Hold out 20% of the lines for evaluation; both halves are DataFrames
    # with "text" and "label" columns.
    train_df, test_df = train_test_split(
        df, test_size=0.2, random_state=42)
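    # Optional: stratify=df["label"] would keep the class ratio equal in
    # both splits.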

    model_name = "bert-base-multilingual-cased"
    tokenizer = BertTokenizer.from_pretrained(model_name)

    def tokenize_function(example):
        return tokenizer(
            example["text"],
            padding="max_length",
            truncation=True,
            max_length=64,
        )

    # preserve_index=False stops the pandas index from leaking in as an
    # extra "__index_level_0__" column.
    train_dataset = Dataset.from_pandas(
        train_df[["text", "label"]], preserve_index=False)
    test_dataset = Dataset.from_pandas(
        test_df[["text", "label"]], preserve_index=False)

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)

    # Trainer expects the target column to be named "labels".
    train_dataset = train_dataset.rename_column("label", "labels")
    test_dataset = test_dataset.rename_column("label", "labels")

    train_dataset.set_format(
        "torch", columns=["input_ids", "attention_mask", "labels"])
    test_dataset.set_format(
        "torch", columns=["input_ids", "attention_mask", "labels"])

    model = BertForSequenceClassification.from_pretrained(
        model_name, num_labels=2)

    training_args = TrainingArguments(
        output_dir="./bert_trained",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        save_strategy="no",
        logging_dir="./logs",
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
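    # Note: with the default evaluation strategy ("no"), eval_dataset is only
    # used when trainer.evaluate() is called explicitly.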

    trainer.train()
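    # Optional sketch: metrics = trainer.evaluate() reports eval_loss on the
    # held-out split (add a compute_metrics function for accuracy/F1).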
    trainer.save_model("./bert_trained")
    tokenizer.save_pretrained("./bert_trained")


if __name__ == "__main__":
    main()
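

# Sketch of downstream use (assumed, not part of this training script):
#
#   tokenizer = BertTokenizer.from_pretrained("./bert_trained")
#   model = BertForSequenceClassification.from_pretrained("./bert_trained")
#   inputs = tokenizer("some subtitle line", return_tensors="pt",
#                      truncation=True, max_length=64)
#   with torch.no_grad():
#       label = model(**inputs).logits.argmax(dim=1).item()  # 1 = important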