import os
from collections.abc import Mapping

import torch
from tirex import load_model, ForecastModel

# Set TIREX_NO_CUDA for this process at import time, i.e. before any model is
# loaded. Presumably this makes TiRex skip its custom CUDA kernels (see the
# TiRex README) -- TODO confirm against the installed version.
# NOTE(review): this assignment runs after `tirex` has already been imported;
# if tirex reads the variable at import time rather than at model-load time,
# it has to move above that import -- verify.
os.environ['TIREX_NO_CUDA'] = '1'


class EndpointModel:
    """Callable wrapper around a TiRex forecasting model.

    The model is loaded once in ``__init__`` and reused by every call to the
    instance.
    """

    def __init__(self):
        """
        This class is used by Hugging Face Inference Endpoints
        to initialize the model once at startup.
        """
        self.model: ForecastModel = load_model("NX-AI/TiRex")

    def __call__(self, inputs: dict) -> dict:
        """
        This method is called for every inference request.
        Inputs must be JSON-serializable.

        Example request:
        {
            "data": [[0.1, 0.2, 0.3, ...], [0.5, 0.6, ...]],  # 2D array: batch_size x context_length
            "prediction_length": 64
        }

        Args:
            inputs: Request payload. ``"data"`` is required and is converted
                to a ``float32`` tensor; ``"prediction_length"`` is optional
                and defaults to 64.

        Returns:
            A dict with two JSON-serializable entries:
            ``"quantiles"`` -- if the model returns a mapping, a dict with the
            same keys whose values are nested lists; otherwise the nested-list
            form of the returned tensor.
            ``"mean"`` -- the nested-list form of the returned mean.

        Raises:
            KeyError: If ``inputs`` has no ``"data"`` entry.
        """
        data = torch.tensor(inputs["data"], dtype=torch.float32)
        prediction_length = inputs.get("prediction_length", 64)

        quantiles, mean = self.model.forecast(
            context=data,
            prediction_length=prediction_length,
        )

        # NOTE(review): the TiRex README shows `forecast` returning plain
        # tensors, on which `.items()` would raise AttributeError. Confirm
        # which form the deployed version yields; both are handled here.
        if isinstance(quantiles, Mapping):
            quantiles_payload = {
                level: values.tolist() for level, values in quantiles.items()
            }
        else:
            quantiles_payload = quantiles.tolist()

        return {
            "quantiles": quantiles_payload,
            "mean": mean.tolist(),
        }