Fine-Tuning with Trainer API
When you want to fine-tune a Transformer model on your own data, the Hugging Face Trainer API provides a powerful, user-friendly interface to manage the training process. The Trainer API abstracts away much of the boilerplate code required for training, evaluation, and logging, making it easier to focus on your model and data. You configure training using a set of training arguments, which control hyperparameters such as learning rate, number of epochs, batch size, and evaluation strategy. The evaluation strategy determines how and when the Trainer evaluates your model on a validation set (for instance, after each epoch or every set number of steps). Callbacks can also be attached to the Trainer to enable features like early stopping, dynamic learning rate scheduling, or custom logging.
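As a concrete sketch of attaching a callback, the snippet below wires transformers' built-in `EarlyStoppingCallback` into a `Trainer`. It is illustrative only: it assumes `model`, `train_dataset`, and `val_dataset` have already been prepared, and the hyperparameter values are placeholders.

```python
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

# Early stopping requires periodic evaluation plus a metric to track,
# so evaluation/save strategy and metric_for_best_model must be set.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,              # an upper bound; early stopping may end sooner
    evaluation_strategy="epoch",
    save_strategy="epoch",            # must match evaluation_strategy
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,          # lower validation loss is better
)

trainer = Trainer(
    model=model,                      # assumed to be defined elsewhere
    args=training_args,
    train_dataset=train_dataset,      # assumed to be defined elsewhere
    eval_dataset=val_dataset,         # assumed to be defined elsewhere
    # Stop if eval_loss fails to improve for 2 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
```

The same `callbacks` list also accepts custom subclasses of `TrainerCallback` for things like custom logging.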
Use small batch sizes when working with limited memory to avoid out-of-memory errors during fine-tuning.
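One way to keep memory usage low without shrinking the effective batch is gradient accumulation: run several small forward/backward passes and update the weights once. The configuration below is a minimal sketch with illustrative values.

```python
from transformers import TrainingArguments

# Small per-device micro-batches to avoid out-of-memory errors,
# accumulated over several steps before each optimizer update.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,    # small micro-batch fits in limited memory
    gradient_accumulation_steps=8,    # gradients accumulated over 8 micro-batches
)

# Effective batch size per device = micro-batch size * accumulation steps
# Here: 2 * 8 = 16
```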
```python
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# Load a small subset of IMDb-style data
data = {
    "text": [
        "This movie was fantastic!",
        "Terrible film.",
        "I loved the acting.",
        "Worst plot ever.",
        "Great direction and story.",
        "Not worth the time."
    ],
    "label": [1, 0, 1, 0, 1, 0]
}
df = pd.DataFrame(data)
train_df, val_df = train_test_split(df, test_size=0.33, random_state=42)

# Tokenize the raw text into model inputs
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
train_encodings = tokenizer(list(train_df["text"]), truncation=True, padding=True, max_length=64)
val_encodings = tokenizer(list(val_df["text"]), truncation=True, padding=True, max_length=64)

# Wrap the encodings and labels in a PyTorch Dataset the Trainer can consume
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = IMDbDataset(train_encodings, list(train_df["label"]))
val_dataset = IMDbDataset(val_encodings, list(val_df["label"]))

model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",            # must match evaluation_strategy for load_best_model_at_end
    logging_dir="./logs",
    logging_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    return {"accuracy": accuracy_score(labels, preds)}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

trainer.train()
results = trainer.evaluate()
print("Validation accuracy:", results["eval_accuracy"])
```
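The `compute_metrics` contract can be checked in isolation: the Trainer passes it a pair of raw logits and true labels, and it must return a dict of metric names to values. The snippet below exercises it on hypothetical arrays, without running any training.

```python
import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    return {"accuracy": accuracy_score(labels, preds)}

# Hypothetical logits for 3 examples, 2 classes; argmax gives preds [1, 0, 1]
logits = np.array([[0.1, 0.9], [2.0, -1.0], [0.3, 0.7]])
labels = np.array([1, 0, 0])

metrics = compute_metrics((logits, labels))
print(metrics)  # accuracy = 2/3, since two of three predictions match
```

The Trainer prefixes returned metric names with `eval_`, which is why the evaluation results expose `eval_accuracy` while `metric_for_best_model` can use the bare name `accuracy`.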
Monitor validation loss during training to detect and prevent overfitting.
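After training, the Trainer keeps every logged metric in `trainer.state.log_history`, a list of dicts, which you can scan to compare training and validation loss. The sample history below is a hypothetical stand-in with the same structure as a real run.

```python
# Stand-in for trainer.state.log_history after a real trainer.train() run
log_history = [
    {"epoch": 1.0, "loss": 0.69},
    {"epoch": 1.0, "eval_loss": 0.66},
    {"epoch": 2.0, "loss": 0.48},
    {"epoch": 2.0, "eval_loss": 0.61},
    {"epoch": 3.0, "loss": 0.30},
    {"epoch": 3.0, "eval_loss": 0.72},  # eval loss rising again
]

eval_losses = [(e["epoch"], e["eval_loss"]) for e in log_history if "eval_loss" in e]
for epoch, loss in eval_losses:
    print(f"epoch {epoch:.0f}: eval_loss = {loss:.2f}")

# Validation loss rising while training loss keeps falling suggests overfitting.
is_overfitting = eval_losses[-1][1] > min(loss for _, loss in eval_losses)
print("possible overfitting:", is_overfitting)
```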