Skip to content

Inference script on a single file #17

@ellkrauze

Description

@ellkrauze

Hello! Is it possible to get an OBJ file from a CC file using a pretrained model? If so, how do I do it properly?

I've put together the script below, but it builds a TestOnlyDataset, which I suspect is not the proper way to run inference with the model.

import os
import torch
import pickle
import numpy as np
from pathlib import Path

import yaml
import hydra
from torch_geometric.loader import DataLoader
from network import PolyGNN
from dataset import TestOnlyDataset
from utils import Sampler
from abspy import VertexGroup, CellComplex
from omegaconf import DictConfig

from utils import init_device, Sampler, set_seed

def load_model(cfg: DictConfig, device: str='cpu'):
    """Build a PolyGNN from ``cfg`` and load weights from ``cfg.checkpoint_path``.

    Args:
        cfg: config object passed to ``PolyGNN``; must provide ``checkpoint_path``.
        device: ``map_location`` used by ``torch.load`` when deserializing the
            checkpoint tensors. This function does not move the model itself;
            callers apply ``.to(device)`` afterwards.

    Returns:
        The model, switched to eval mode.
    """
    model = PolyGNN(cfg)
    state = torch.load(cfg.checkpoint_path, map_location=device)
    # The original code expected weights nested under 'state_dict'; fall back
    # to treating the loaded object itself as a bare state dict.
    if isinstance(state, dict) and 'state_dict' in state:
        state_dict = state['state_dict']
    else:
        state_dict = state
    # strict=False tolerates key mismatches. Without this report, a mismatch
    # (e.g. differing key prefixes) would leave layers at their initial
    # weights silently.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} checkpoint keys not used by the model, "
              f"e.g. {result.unexpected_keys[:5]}")
    model.eval()
    return model

def preprocess_data(file_path, config):
    """Wrap ``file_path`` in a TestOnlyDataset and return a DataLoader over it.

    Args:
        file_path: path handed to ``TestOnlyDataset`` as its ``root``.
        config: config providing ``seed`` and the ``sample`` section
            (strategy, length, ratio, resolutions, duplicate, transform).

    Returns:
        A ``DataLoader`` with ``batch_size=1`` and no shuffling.
    """
    sample_cfg = config.sample
    sampler = Sampler(
        strategy=sample_cfg.strategy,
        length=sample_cfg.length,
        ratio=sample_cfg.ratio,
        resolutions=sample_cfg.resolutions,
        duplicate=sample_cfg.duplicate,
        seed=config.seed,
    )
    # Apply the sampler as an on-the-fly transform only when enabled in config.
    transform = None
    if sample_cfg.transform:
        transform = sampler.sample
    # NOTE(review): a single .cc file path is passed as the dataset root;
    # confirm TestOnlyDataset accepts a file rather than a directory.
    dataset = TestOnlyDataset(
        root=file_path,
        split=None,
        transform=transform,
        pre_transform=None,
    )
    return DataLoader(dataset, batch_size=1, shuffle=False)

def run_inference(model, dataloader, device, output_dir):
    """Run ``model`` over ``dataloader`` and save argmax predictions per sample.

    Each sample's predictions are written to ``<output_dir>/<name>.npy``.

    Args:
        model: callable taking a batch and returning class scores; the argmax
            is taken over dim 1.
        dataloader: iterable of batches exposing ``.to(device)`` and ``.name``
            (one name per graph in the batch).
        device: device each batch is moved to before the forward pass.
        output_dir: directory for the ``.npy`` files; created if missing.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    for batch in dataloader:
        batch = batch.to(device)
        with torch.no_grad():
            outputs = model(batch)
            predictions = outputs.argmax(dim=1)

        # PyG-style batches carry a ``batch`` vector mapping each node to its
        # graph. When predictions are per node (one row per entry of that
        # vector), keep every node of graph i. Plain ``predictions[i]`` would
        # save only a single node's label.
        # NOTE(review): assumes PolyGNN emits per-cell (per-node) logits;
        # otherwise the per-graph fallback below is used.
        graph_index = getattr(batch, 'batch', None)
        per_node = (graph_index is not None
                    and graph_index.shape[0] == predictions.shape[0])

        for i, name in enumerate(batch.name):
            pred = predictions[graph_index == i] if per_node else predictions[i]
            np.save(out_dir / f"{name}.npy", pred.cpu().numpy())

def convert_data(src_path, dst_path=None):
    """Build a cell complex from ``src_path`` with abspy and save it.

    Args:
        src_path: input file handed to ``abspy.VertexGroup``.
        dst_path: output path; defaults to ``src_path`` with its extension
            replaced by ``.cc``.

    Returns:
        The path the cell complex was written to.

    Raises:
        ValueError: if the destination resolves to ``src_path`` itself.
    """
    if dst_path is None:
        base, _ext = os.path.splitext(src_path)
        dst_path = base + ".cc"
    # With the default naming, an input already ending in ".cc" maps onto
    # itself; refuse rather than overwrite the source with the result.
    src_abs = os.path.normcase(os.path.abspath(src_path))
    dst_abs = os.path.normcase(os.path.abspath(dst_path))
    if src_abs == dst_abs:
        raise ValueError(f"Destination {dst_path} would overwrite the source file {src_path}")
    vertex_group = VertexGroup(src_path, quiet=True)
    cell_complex = CellComplex(vertex_group.planes, vertex_group.aabbs,
                               vertex_group.points_grouped, build_graph=True, quiet=True)
    cell_complex.prioritise_planes(prioritise_verticals=True)
    cell_complex.construct()
    cell_complex.save(dst_path)
    print(f"New file written to {dst_path}")
    return dst_path

@hydra.main(config_path='./conf', config_name='infer_config', version_base='1.2')
def main(cfg: DictConfig):
    """Convert ``cfg.data_file`` to a cell complex and run PolyGNN inference on it.

    Steps: seed RNGs, pick CUDA when available, load the checkpointed model,
    build the ``.cc`` file, wrap it in a DataLoader, then write predictions
    to ``cfg.output_dir``.
    """
    set_seed(cfg.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Weights are deserialized on CPU (load_model's default), then moved.
    net = load_model(cfg)
    net = net.to(device)

    complex_path = convert_data(cfg.data_file)
    print(f"Converted {cfg.data_file} to cc")

    loader = preprocess_data(complex_path, cfg)
    print("Created dataloader")

    run_inference(net, loader, device, cfg.output_dir)
    print(f"Inference completed. Results saved to {cfg.output_dir}")


if __name__ == "__main__":
    # The @hydra.main decorator on main() supplies cfg from ./conf/infer_config,
    # so main is invoked here without arguments.
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions