mirror of
https://github.com/velocitatem/cvfs.git
synced 2026-05-31 08:43:37 +00:00
Initial commit
This commit is contained in:
15
ml/models/arch.py
Normal file
15
ml/models/arch.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int) -> None:
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
98
ml/models/train.py
Normal file
98
ml/models/train.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import yaml
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from alveslib import get_logger
|
||||
from ml.models.arch import Model
|
||||
|
||||
logger = get_logger("ml-train")
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
train_loader: DataLoader,
|
||||
learning_rate: float,
|
||||
log_dir: str,
|
||||
log_every_n_steps: int,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.train_loader = train_loader
|
||||
self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||
self.criterion = nn.CrossEntropyLoss()
|
||||
self.writer = SummaryWriter(log_dir)
|
||||
self.step = 0
|
||||
self.log_every_n_steps = log_every_n_steps
|
||||
|
||||
def train_epoch(self) -> float:
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
for batch_idx, (features, target) in enumerate(self.train_loader):
|
||||
self.optimizer.zero_grad()
|
||||
output = self.model(features)
|
||||
loss = self.criterion(output, target)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
if batch_idx % self.log_every_n_steps == 0:
|
||||
self.writer.add_scalar("Loss/TrainStep", loss.item(), self.step)
|
||||
self.step += 1
|
||||
|
||||
return total_loss / max(len(self.train_loader), 1)
|
||||
|
||||
def train(self, epochs: int) -> None:
|
||||
for epoch in range(epochs):
|
||||
avg_loss = self.train_epoch()
|
||||
self.writer.add_scalar("Loss/TrainEpoch", avg_loss, epoch)
|
||||
logger.info(f"epoch={epoch + 1}/{epochs} avg_loss={avg_loss:.5f}")
|
||||
self.writer.close()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Train a baseline model")
|
||||
parser.add_argument("--config", default="ml/configs/train/default.yaml")
|
||||
parser.add_argument("--dataset", default="ml/data/processed/dataset.pt")
|
||||
parser.add_argument("--weights", default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
torch.manual_seed(int(cfg["seed"]))
|
||||
|
||||
dataset_blob = torch.load(args.dataset, map_location="cpu")
|
||||
dataset = TensorDataset(dataset_blob["features"], dataset_blob["labels"])
|
||||
train_loader = DataLoader(dataset, batch_size=int(cfg["batch_size"]), shuffle=True)
|
||||
|
||||
model = Model(
|
||||
input_dim=int(cfg["input_dim"]),
|
||||
hidden_dim=int(cfg["hidden_dim"]),
|
||||
num_classes=int(cfg["num_classes"]),
|
||||
)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
log_dir=str(cfg["tensorboard_dir"]),
|
||||
log_every_n_steps=int(cfg["log_every_n_steps"]),
|
||||
)
|
||||
trainer.train(epochs=int(cfg["epochs"]))
|
||||
|
||||
weights_target = args.weights or cfg.get(
|
||||
"weights_output", "ml/models/weights/model.pt"
|
||||
)
|
||||
weights_path = Path(weights_target)
|
||||
weights_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(model.state_dict(), weights_path)
|
||||
logger.info(f"saved_weights={weights_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user