This project implements an improved version of GCNFold for RNA secondary structure prediction, using RNA-FM embeddings and ViennaRNA base pair probabilities.
- Create and activate a virtual environment (a HuggingFace token is used to load the bpRNA dataset):
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies:
pip install -r requirements.txt
Before training the model, you need to download and prepare the dataset. The project includes two scripts for this purpose:
The download_dataset.py script downloads the bpRNA dataset from HuggingFace and saves it as CSV files.
python src/data/download_dataset.py [options]
Options:
- --dataset_name: Name of the dataset on HuggingFace (default: "multimolecule/bprna")
- --output_dir: Directory to save the dataset (default: "data/bprna")
Example with a custom output directory:
python src/data/download_dataset.py --output_dir "data/custom_path"
The clean_bprna_dataset.py script processes the downloaded dataset to prepare it for training.
python src/data/clean_bprna_dataset.py [options]
Options:
- --input_file: Path to the input CSV file (default: "data/bprna/train.csv")
- --output_file: Path to save the cleaned dataset (default: "data/bprna/train_clean.csv")
- --min_length: Minimum sequence length to keep (default: 20)
- --max_length: Maximum sequence length to keep (default: 150)
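The length filter controlled by --min_length and --max_length can be sketched with the standard csv module. This is a minimal illustration, not the script's actual code: the "sequence" column name and the inclusive bounds are assumptions, and clean_bprna_dataset.py may perform additional cleaning steps.

```python
import csv
import io


def filter_by_length(rows, min_length=20, max_length=150, seq_column="sequence"):
    """Keep rows whose sequence length lies in [min_length, max_length].

    seq_column is a hypothetical column name; adjust it to the real CSV header.
    Treating both bounds as inclusive is also an assumption.
    """
    return [r for r in rows if min_length <= len(r[seq_column]) <= max_length]


def clean_csv(text, **kwargs):
    """Parse CSV text, apply the length filter, and return the cleaned CSV text."""
    reader = csv.DictReader(io.StringIO(text))
    kept = filter_by_length(list(reader), **kwargs)
    out = io.StringIO()
    writer = csv.DictWriter(out, fieldnames=reader.fieldnames)
    writer.writeheader()
    writer.writerows(kept)
    return out.getvalue()
```

With real files, the same functions work on csv.DictReader(open(input_file, newline="")) and a matching writer for output_file.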
Example with custom length limits:
python src/data/clean_bprna_dataset.py --min_length 30 --max_length 200
The improved_gcnfold_training.py script trains the GCNFold model on the bpRNA dataset.
python improved_gcnfold_training.py [options]
Options:
- --epochs: Number of epochs to train (default: 20)
- --max_length: Maximum sequence length (default: 150)
- --min_length: Minimum sequence length (default: 20)
- --pos_weight: Positive-class weight for the loss function (default: 10.0)
- --lr: Initial learning rate (default: 0.0001)
- --seed: Random seed (default: 42)
- --max_samples: Maximum number of samples to use; the selected samples are then split into training, validation, and test sets (default: None)
- --patience: Patience for early stopping, in epochs (default: 3)
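Paired positions make up a small fraction of an L x L contact map, which is presumably why --pos_weight exists. Assuming the script follows the semantics of PyTorch's BCEWithLogitsLoss(pos_weight=...) (an assumption this README does not confirm), the weight multiplies only the loss on positive (paired) entries. The arithmetic is sketched in plain Python; weighted_bce_with_logits is a hypothetical name:

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def weighted_bce_with_logits(logits, targets, pos_weight=10.0):
    """Mean binary cross-entropy with positive (paired) entries up-weighted.

    Per element: -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))],
    the formula documented for torch.nn.BCEWithLogitsLoss with pos_weight.
    """
    total = 0.0
    for x, y in zip(logits, targets):
        log_sig = -softplus(-x)       # log(sigmoid(x))
        log_one_minus = -softplus(x)  # log(1 - sigmoid(x))
        total += -(pos_weight * y * log_sig + (1 - y) * log_one_minus)
    return total / len(logits)
```

At a logit of 0, a missed pair costs pos_weight times as much as a false pair, so the default of 10.0 pushes the model to favour recall on the sparse positive class.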
Example with custom hyperparameters:
python improved_gcnfold_training.py --epochs 30 --max_length 200 --lr 0.0005
The script creates a timestamped directory in output/ containing:
- Trained model (gcnfold_model.pt)
- Training and validation loss curves
- ROC-AUC curves
- Precision-recall curves
- Confusion matrices
- Evaluation metrics
- Hyperparameter logs
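The --patience option can be pictured as a counter of epochs without improvement. A minimal sketch, assuming validation loss is the monitored quantity (the README does not state which metric drives early stopping); the EarlyStopping class and its min_delta parameter are hypothetical:

```python
class EarlyStopping:
    """Signal a stop once validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Call step() once per epoch after validation and break out of the loop when it returns True; saving a checkpoint whenever best_loss improves is a common companion step.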
The visualize_rna_structures.py script generates visualizations of RNA structures and their predictions.
python visualize_rna_structures.py [options]
Options:
- --model_path: Path to the trained model (default: latest model in output/)
- --data_path: Path to the RNA sequences (default: data/bprna/test_clean.csv)
- --output_dir: Directory to save visualizations (default: output/visualizations)
- --num_samples: Number of samples to visualize (default: 10)
- --max_length: Maximum sequence length (default: 150)
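The default --model_path ("latest model in output/") could be resolved along these lines. The run-directory naming is inferred from the example path gcnfold_improved_20250506_015201 used in this README; find_latest_model is a hypothetical helper, and the real script may instead rely on file modification times.

```python
from pathlib import Path


def find_latest_model(output_dir="output", filename="gcnfold_model.pt"):
    """Return the newest checkpoint under output_dir, or None if there is none.

    Assumes every run directory shares a prefix and ends in YYYYMMDD_HHMMSS,
    so sorting by directory name also sorts by creation time.
    """
    candidates = sorted(Path(output_dir).glob(f"*/{filename}"), key=lambda p: p.parent.name)
    return candidates[-1] if candidates else None
```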
Example with an explicit model path:
python visualize_rna_structures.py --model_path output/gcnfold_improved_20250506_015201/gcnfold_model.pt
The script generates visualizations for five sequences of varying sizes in the specified visualization output directory:
- Structure heatmaps
- Base pair probability matrices
- Dot-bracket notation comparisons
- Performance metrics
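Dot-bracket notation marks each paired base with a matching "(" and ")" and each unpaired base with ".". A minimal converter between 0-based pair lists and dot-bracket strings, handling nested structures only (pseudoknots would need extra bracket types); both function names are hypothetical and not taken from the project:

```python
def pairs_to_dot_bracket(pairs, length):
    """Render 0-based (i, j) base pairs as dot-bracket notation (parentheses only)."""
    chars = ["."] * length
    for i, j in pairs:
        i, j = min(i, j), max(i, j)
        chars[i], chars[j] = "(", ")"
    return "".join(chars)


def dot_bracket_to_pairs(structure):
    """Parse parentheses-only dot-bracket notation into a sorted list of 0-based pairs."""
    stack, pairs = [], []
    for k, ch in enumerate(structure):
        if ch == "(":
            stack.append(k)
        elif ch == ")":
            if not stack:
                raise ValueError(f"unmatched ')' at position {k}")
            pairs.append((stack.pop(), k))
    if stack:
        raise ValueError(f"unmatched '(' at position {stack[-1]}")
    return sorted(pairs)
```

Printing the reference and predicted strings one above the other gives a quick position-by-position comparison.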
The improved GCNFold model includes:
- RNA-FM embeddings (640-dimensional)
- 4 Graph Convolutional Network layers
- Hidden dimension of 256
- Minimum base pair distance of 3
- Structural constraints and stacking energy features
- ViennaRNA base pair probabilities as priors
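ViennaRNA base pair probabilities lie in [0, 1], and the implementation notes state they are converted to log-odds form, log(p / (1 - p)), before use as priors. A sketch of that transform, assuming probabilities are clipped with a small epsilon (the clipping and its value are assumptions) so that p = 0 and p = 1 stay finite:

```python
import math


def to_log_odds(bpp, eps=1e-6):
    """Convert a base-pair probability matrix (list of rows) to log-odds.

    Each p is clipped to [eps, 1 - eps] before computing log(p / (1 - p)).
    """
    out = []
    for row in bpp:
        new_row = []
        for p in row:
            p = min(max(p, eps), 1.0 - eps)
            new_row.append(math.log(p / (1.0 - p)))
        out.append(new_row)
    return out
```

In log-odds space p = 0.5 maps to 0, and probabilities p and 1 - p map to values of opposite sign, which keeps the prior centred for the network.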
The model is evaluated using:
- Accuracy
- Precision
- Recall
- F1 Score
- ROC-AUC
- Precision-Recall AUC
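The README does not say whether the threshold metrics are computed over base pairs or over contact-matrix entries. A sketch of the common pair-level precision, recall, and F1, assuming predicted and reference structures are given as 0-based (i, j) pairs; pair_metrics is a hypothetical name, and score-based metrics such as ROC-AUC and PR-AUC additionally need the raw predicted probabilities, so they are not covered here:

```python
def pair_metrics(predicted, reference):
    """Precision, recall and F1 over sets of 0-based (i, j) base pairs.

    Pairs are normalized so that (j, i) and (i, j) count as the same pair.
    """
    pred = {tuple(sorted(p)) for p in predicted}
    ref = {tuple(sorted(p)) for p in reference}
    tp = len(pred & ref)  # correctly predicted pairs
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(ref) if ref else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}
```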
See requirements.txt for a complete list of dependencies. Key requirements include:
- PyTorch >= 2.0.0
- ViennaRNA >= 2.5.0
- RNA-FM embeddings
- DGL (Deep Graph Library)
- Other standard ML and scientific computing packages
Implementation notes:
- The model uses curriculum learning for the first 3 epochs
- Early stopping is implemented with a patience of 3 epochs
- Base pair probabilities are normalized to log-odds form
- Structural constraints enforce RNA base pairing rules
- Stacking energies are applied to enhance prediction stability
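The structural constraints and the minimum base pair distance can be expressed as a mask over the contact map. A minimal sketch, assuming the allowed pairs are Watson-Crick (A-U, G-C) plus G-U wobble, and reading "minimum base pair distance of 3" as j - i > 3, so every hairpin loop encloses at least three unpaired bases; the project may use a different convention, and pairing_mask is a hypothetical name:

```python
ALLOWED_PAIRS = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}


def pairing_mask(sequence, min_distance=3):
    """Return an n x n 0/1 matrix marking position pairs allowed to pair.

    (i, j) is allowed when the bases form a canonical or G-U wobble pair and
    j - i > min_distance (assumed interpretation of the minimum distance).
    """
    seq = sequence.upper().replace("T", "U")  # treat DNA-style T as U
    n = len(seq)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + min_distance + 1, n):
            if (seq[i], seq[j]) in ALLOWED_PAIRS:
                mask[i][j] = mask[j][i] = 1  # keep the mask symmetric
    return mask
```

Multiplying predicted pair probabilities elementwise by such a mask zeroes out pairs that violate these rules.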