# Mirror of https://github.com/velocitatem/cvfs.git
# Synced 2026-05-31 08:43:37 +00:00
# 16 lines, 416 B, Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Model(nn.Module):
    """Feed-forward network with a single hidden layer.

    The network is ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, num_classes)``. No activation is applied after the
    final linear layer, so ``forward`` returns its raw outputs.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int) -> None:
        """Build the layer stack.

        Args:
            input_dim: Number of input features of the first linear layer.
            hidden_dim: Number of output features of the first linear layer
                (and input features of the second).
            num_classes: Number of output features of the second linear layer.
        """
        super().__init__()
        # Kept under the attribute name ``net`` and in this exact order so
        # parameter names stay ``net.0.*`` and ``net.2.*``.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the layer stack and return the result.

        Args:
            x: Input tensor; presumably its last dimension is ``input_dim``
                (required by the first linear layer) — confirm with callers.

        Returns:
            The output of the final linear layer.
        """
        return self.net(x)
|