mirror of
https://github.com/velocitatem/cvfs.git
synced 2026-05-31 08:43:37 +00:00
99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
import argparse
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import yaml
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from alveslib import get_logger
|
|
from ml.models.arch import Model
|
|
|
|
# Logger shared by everything in this script; obtained from alveslib.get_logger
# with the argument "ml-train".
logger = get_logger("ml-train")
|
|
|
|
|
|
class Trainer:
    """Minimal supervised training loop with TensorBoard scalar logging.

    Optimises ``model`` with Adam against a cross-entropy loss, iterating over
    ``train_loader`` and writing per-step and per-epoch losses to a
    ``SummaryWriter`` created from ``log_dir``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        learning_rate: float,
        log_dir: str,
        log_every_n_steps: int,
    ) -> None:
        """Create the optimizer, loss function, and TensorBoard writer.

        Args:
            model: Network to train; its parameters are handed to Adam.
            train_loader: Iterable yielding ``(features, target)`` batches.
            learning_rate: Learning rate passed to ``torch.optim.Adam``.
            log_dir: Directory argument passed to ``SummaryWriter``.
            log_every_n_steps: The step loss is written whenever the batch
                index within the current epoch is a multiple of this value.
                Zero or a negative value disables per-step logging.
        """
        self.model = model
        self.train_loader = train_loader
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()
        self.writer = SummaryWriter(log_dir)
        # Global step counter; keeps increasing across epochs.
        self.step = 0
        self.log_every_n_steps = log_every_n_steps

    def train_epoch(self) -> float:
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Returns:
            Sum of the per-batch losses divided by the number of batches
            processed, or ``0.0`` when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        # A non-positive interval means "no step logging"; checking it up
        # front also avoids a modulo-by-zero inside the loop.
        log_interval = self.log_every_n_steps

        for batch_idx, (features, target) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            output = self.model(features)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # Read the scalar once and reuse it for both the sum and the log.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            if log_interval > 0 and batch_idx % log_interval == 0:
                self.writer.add_scalar("Loss/TrainStep", batch_loss, self.step)
            self.step += 1

        # Count the batches actually seen so loaders without __len__ work too;
        # max(..., 1) keeps an empty loader from dividing by zero.
        return total_loss / max(num_batches, 1)

    def train(self, epochs: int) -> None:
        """Train for ``epochs`` epochs, logging the mean loss of each one.

        The writer is closed when training finishes, including when an epoch
        raises, so buffered events are not left unflushed.

        Args:
            epochs: Number of passes over ``train_loader``.
        """
        try:
            for epoch in range(epochs):
                avg_loss = self.train_epoch()
                self.writer.add_scalar("Loss/TrainEpoch", avg_loss, epoch)
                logger.info(f"epoch={epoch + 1}/{epochs} avg_loss={avg_loss:.5f}")
        finally:
            # Always release the writer, even if an epoch failed part-way.
            self.writer.close()
|
|
|
|
|
|
def main() -> None:
    """Train the baseline model from a YAML config and save its weights.

    Command-line flags:
        --config:  YAML file holding the training settings read below
            (``seed``, ``batch_size``, ``input_dim``, ``hidden_dim``,
            ``num_classes``, ``learning_rate``, ``tensorboard_dir``,
            ``log_every_n_steps``, ``epochs`` and optionally
            ``weights_output``).
        --dataset: File readable by ``torch.load`` that yields a mapping with
            ``features`` and ``labels`` entries.
        --weights: Destination for the saved state dict. When omitted, the
            config's ``weights_output`` is used, falling back to
            ``ml/models/weights/model.pt``.

    Raises:
        ValueError: If the config file does not contain a YAML mapping.
        KeyError: If a required config key is missing.
    """
    parser = argparse.ArgumentParser(description="Train a baseline model")
    parser.add_argument("--config", default="ml/configs/train/default.yaml")
    parser.add_argument("--dataset", default="ml/data/processed/dataset.pt")
    parser.add_argument("--weights", default=None)
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file and a list/scalar for non-mapping
    # documents; fail with a clear message instead of an opaque TypeError below.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"config {args.config!r} must contain a YAML mapping, "
            f"got {type(cfg).__name__}"
        )

    torch.manual_seed(int(cfg["seed"]))

    # NOTE(review): torch.load unpickles the file, so --dataset must only point
    # at trusted files. Consider passing weights_only=True if the installed
    # torch version supports it and the blob holds nothing but tensors.
    dataset_blob = torch.load(args.dataset, map_location="cpu")
    dataset = TensorDataset(dataset_blob["features"], dataset_blob["labels"])
    train_loader = DataLoader(dataset, batch_size=int(cfg["batch_size"]), shuffle=True)

    model = Model(
        input_dim=int(cfg["input_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        num_classes=int(cfg["num_classes"]),
    )
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        learning_rate=float(cfg["learning_rate"]),
        log_dir=str(cfg["tensorboard_dir"]),
        log_every_n_steps=int(cfg["log_every_n_steps"]),
    )
    trainer.train(epochs=int(cfg["epochs"]))

    # The command-line flag wins over the config value, which wins over the default.
    weights_target = args.weights or cfg.get(
        "weights_output", "ml/models/weights/model.pt"
    )
    weights_path = Path(weights_target)
    # Create the output directory on demand so torch.save does not fail on a fresh checkout.
    weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), weights_path)
    logger.info(f"saved_weights={weights_path}")
|
|
|
|
|
|
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|