mirror of
https://github.com/velocitatem/cvfs.git
synced 2026-05-31 08:43:37 +00:00
Initial commit
This commit is contained in:
36
ml/inference.py
Normal file
36
ml/inference.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
# TODO: Import model when ready
|
||||
from models import * # TODO: SPECIFY
|
||||
|
||||
class InputData(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
weights_path = os.getenv("ML_LATEST_WEIGHTS_PATH")
|
||||
if weights_path is None:
|
||||
raise RuntimeError("ML_LATEST_WEIGHTS_PATH not set")
|
||||
|
||||
|
||||
# FastAPI app
|
||||
app = FastAPI(title="ML Inference API", version="1.0.0")
|
||||
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
return {"status": "healthy", "service": "ml-inference"}
|
||||
|
||||
@app.post("/predict")
|
||||
def predict(data: InputData):
|
||||
|
||||
#TODO: x = torch.tensor([data.features], dtype=torch.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
#TODO: y = model(x)
|
||||
|
||||
y=torch.tensor(0)
|
||||
return {"prediction": y.tolist()}
|
||||
Reference in New Issue
Block a user