"""Fine-tune multilingual BERT to flag "important" subtitle lines.

Lines are weakly labeled with a pretrained Russian sentiment model:
anything clearly positive or negative (or longer than six words) is
treated as important.
"""

import torch
import torch.nn.functional as F
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

from subtitles.parser import parse_srt_to_df

# Pretrained Russian sentiment classifier, used only for weak labeling.
sentiment_model_name = "cointegrated/rubert-tiny-sentiment-balanced"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name)
sentiment_model.eval()  # inference only; disables dropout


def is_important(text):
    """Weak label: 1 if the line carries sentiment or is long, else 0."""
    try:
        inputs = sentiment_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        probs = F.softmax(logits, dim=1)
        label_id = torch.argmax(probs, dim=1).item()
        # Order mirrors the sentiment model's class order
        # (negative / neutral / positive).
        labels = ["NEGATIVE", "NEUTRAL", "POSITIVE"]
        label = labels[label_id]
    except Exception:
        # Empty or unparseable input: treat as unimportant.
        return 0
    if label in ("NEGATIVE", "POSITIVE"):
        return 1
    if len(text.split()) > 6:
        return 1
    return 0


def main():
    # Parse the .srt file into a DataFrame with a "text" column,
    # then weak-label every line.
    df = parse_srt_to_df("Breaking_Bad_RUS_2008_20210402033853.srt")
    df["label"] = df["text"].astype(str).apply(is_important)

    # 80/20 train/test split of the labeled DataFrame.
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

    model_name = "bert-base-multilingual-cased"
    tokenizer = BertTokenizer.from_pretrained(model_name)

    def tokenize_function(example):
        return tokenizer(example["text"], padding="max_length",
                         truncation=True, max_length=64)

    train_dataset = Dataset.from_pandas(train_df[["text", "label"]])
    test_dataset = Dataset.from_pandas(test_df[["text", "label"]])

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)

    # Trainer expects the target column to be named "labels".
    train_dataset = train_dataset.rename_column("label", "labels")
    test_dataset = test_dataset.rename_column("label", "labels")

    train_dataset.set_format(
        "torch", columns=["input_ids", "attention_mask", "labels"])
    test_dataset.set_format(
        "torch", columns=["input_ids", "attention_mask", "labels"])

    # Binary classification head on top of multilingual BERT.
    model = BertForSequenceClassification.from_pretrained(
        model_name, num_labels=2)

    training_args = TrainingArguments(
        output_dir="./bert_trained",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        save_strategy="no",        # skip intermediate checkpoints
        logging_dir="./logs",
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )

    trainer.train()
    trainer.save_model("./bert_trained")
    tokenizer.save_pretrained("./bert_trained")


if __name__ == "__main__":
    main()
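
# Optional follow-up (a minimal sketch, not part of the training run above):
# how the saved checkpoint in ./bert_trained could be loaded back and used
# to score a single subtitle line. The example sentence is an illustrative
# placeholder, not a line from the dataset.
#
# from transformers import BertForSequenceClassification, BertTokenizer
# import torch
#
# tok = BertTokenizer.from_pretrained("./bert_trained")
# clf = BertForSequenceClassification.from_pretrained("./bert_trained")
# clf.eval()
# inputs = tok("Это пример реплики из субтитров.", return_tensors="pt",
#              truncation=True, max_length=64)
# with torch.no_grad():
#     pred = clf(**inputs).logits.argmax(dim=-1).item()
# print("important" if pred == 1 else "not important")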