Notes on Transformer & BERT
BERT
- BERT is pretrained with two objectives: masked language modeling and next-sentence prediction
- [CLS] and [SEP] are special tokens: [CLS] is prepended to every input (the intro) and [SEP] marks the end of each sentence (the coda)
- the final hidden state of the [CLS] token is used as the classification input to predict the label
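A quick way to see where these special tokens land is to tokenize a sentence pair and inspect the result (a small illustration; the bert-base-uncased checkpoint is an assumption, any BERT checkpoint works):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer("How are you?", "I am fine.")
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
# ['[CLS]', 'how', 'are', 'you', '?', '[SEP]', 'i', 'am', 'fine', '.', '[SEP]']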
The input embeddings are passed through multiple encoder layers to output some final hidden states.
- To use the pretrained model for text classification, add a sequence classification head on top of the base BERT model. The sequence classification head is a linear layer that accepts the final hidden states and performs a linear transformation to convert them into logits. The cross-entropy loss is then calculated between the logits and the target to find the most likely label.
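Conceptually the head is just a linear layer followed by cross-entropy; a minimal PyTorch sketch of the idea (the hidden size of 768 and the 2 labels are assumptions matching BERT-base and a binary task, not the exact Transformers implementation):
import torch
import torch.nn as nn

hidden_size, num_labels = 768, 2               # assumed: BERT-base hidden size, binary classification
classifier = nn.Linear(hidden_size, num_labels)

cls_hidden = torch.randn(4, hidden_size)       # final hidden state of [CLS] for a batch of 4
logits = classifier(cls_hidden)                # shape (4, 2)
targets = torch.tensor([0, 1, 1, 0])           # gold labels
loss = nn.CrossEntropyLoss()(logits, targets)  # cross-entropy between logits and targets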
Text classification for BERT
https://huggingface.co/docs/transformers/tasks/sequence_classification
Use IMDb first:
from datasets import load_dataset
imdb = load_dataset("imdb")
imdb["test"][0] # {"label" : 0, "text" : ....}
tokenizer(examples["text"], truncation=True)
- truncation=True: truncates sequences to be no longer than DistilBERT's maximum input length
To apply the preprocessing function over the entire dataset, use map()
- batched=True: processes multiple elements of the dataset at once, which speeds up map() (see the sketch below)
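Putting the preprocessing together, a sketch following the task guide (the distilbert-base-uncased checkpoint and DataCollatorWithPadding are taken from it); this produces the tokenized_imdb, tokenizer, and data_collator objects the Trainer below expects:
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    # tokenize and truncate to the model's maximum input length
    return tokenizer(examples["text"], truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)

# pad each batch dynamically to its longest sequence instead of padding everything to max length
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)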
Fine Tuning
Import modules:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)
id2label & label2id
: maps for translating between ids and labels, e.g.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
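The Trainer below also needs a compute_metrics function; a minimal sketch using accuracy from the evaluate library (the metric choice is an assumption, any metric works):
import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # the Trainer passes (logits, labels); argmax over the logits gives the predicted class ids
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)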
Define hyperparameters & trainer:
training_args = TrainingArguments(
output_dir="my_awesome_model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_imdb["train"],
eval_dataset=tokenized_imdb["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
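After training, the fine-tuned model can be loaded back for inference, e.g. with a pipeline (the model path below is just the output_dir / Hub repo name used above):
from transformers import pipeline

classifier = pipeline("sentiment-analysis", model="my_awesome_model")
classifier("This was a masterpiece. Loved every minute of it.")
# [{'label': 'POSITIVE', 'score': ...}]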