import torch
import torch.nn as nn


class Model(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    The layers are held in ``self.net`` (an ``nn.Sequential``) in this order:

    0. ``nn.Linear(input_dim, hidden_dim)``
    1. ``nn.ReLU()``
    2. ``nn.Linear(hidden_dim, num_classes)``

    No softmax or other normalization is applied after the last layer, so
    the output holds raw per-class scores.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int) -> None:
        """Build the layer stack.

        Args:
            input_dim: Number of input features for the first linear layer.
            hidden_dim: Output width of the first linear layer and input
                width of the second.
            num_classes: Number of output features of the final linear layer.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result.

        Args:
            x: Input tensor whose last dimension must equal ``input_dim``
                (a requirement of ``nn.Linear``).

        Returns:
            The output of ``self.net(x)``; its last dimension is
            ``num_classes``.
        """
        return self.net(x)