This is the default final project for the Stanford CS 224N class. Please refer to the project handout on the course website for detailed instructions and an overview of the codebase.
This project comprises two parts. In the first part, you will implement some important components of the GPT-2 model to better understand its architecture. In the second part, you will use the token embeddings produced by your GPT-2 model on two downstream tasks: paraphrase detection and sonnet generation. You will implement extensions to improve your model's performance on these tasks.
In broad strokes, Part 1 of this project targets:
- modules/attention.py: Missing code blocks.
- modules/gpt2_layer.py: Missing code blocks.
- models/gpt2.py: Missing code blocks.
- classifier.py: Missing code blocks.
- optimizer.py: Missing code blocks.
To test Part 1, you will run:
optimizer_test.py
: To test your implementation ofoptimizer.py
.sanity_check.py
: To test your implementation of GPT models.classifier.py
: To perform sentiment classification using your models.
In Part 2 of this project, you will use GPT2 (via cloze-style classification) detect if one sentence is a paraphrase of another as well as generate sonnets via autoregressive language modeling.
To test Part 2, you will run:
paraphrase_detection.py
: To perform paraphrase detection.sonnet_generation.py
: To perform sonnet generation.
Important: Adjust training hyperparameters, particularly batch size, according to your GPU's specifications to optimize performance and prevent out-of-memory errors.
While there are missing code blocks that you need to implement in both of these files, the main focus of this second part are the extensions: how you modify your GPT2 model to improve its ability to determine if one sentence is a paraphrase of another as well as its ability to generate sonnets.
Follow setup.sh
to properly setup a conda environment and install dependencies.
This project is adapted from a prior year's CS 224N project Implement BERT .
Parts of the code are from the transformers
library (Apache License 2.0).