"""
TiRex time-series forecasting — Hugging Face / AWS inference endpoint (inference.py).

Source: TiRex-aws-inference repository, commit 155c2a4 ("inference and reqs", author: Nikita).
"""
import os

# TiRex decides whether to build/use its CUDA kernels when the package is
# loaded, so the opt-out flag must be in the environment *before* the
# `tirex` import below — setting it after the import (as the original did)
# may be too late.
# NOTE(review): confirm against the tirex package at which point
# TIREX_NO_CUDA is read (import time vs. load_model time).
# `setdefault` (not plain assignment) so an operator who explicitly enables
# CUDA via the environment is not silently overridden — this matches the
# original comment's "unless explicitly enabled" intent.
os.environ.setdefault("TIREX_NO_CUDA", "1")

import torch
from tirex import load_model, ForecastModel
class EndpointModel:
    """Hugging Face Inference Endpoints handler around the TiRex forecaster.

    The endpoint runtime instantiates this class once at startup and then
    invokes the instance for every request, so the (expensive) model load
    happens a single time.
    """

    # Forecast horizon used when the request omits "prediction_length".
    DEFAULT_PREDICTION_LENGTH = 64

    def __init__(self) -> None:
        """Load the TiRex checkpoint from the Hugging Face Hub (NX-AI/TiRex)."""
        self.model: ForecastModel = load_model("NX-AI/TiRex")

    def __call__(self, inputs: dict) -> dict:
        """Run one forecast request.

        Parameters
        ----------
        inputs:
            JSON-decoded request body, e.g.::

                {
                    "data": [[0.1, 0.2, 0.3], [0.5, 0.6, 0.7]],  # batch x context_length
                    "prediction_length": 64                      # optional
                }

            ``data`` may also be a single 1-D series; it is promoted to a
            batch of one.

        Returns
        -------
        dict
            JSON-serializable mapping with keys ``"quantiles"`` and ``"mean"``.

        Raises
        ------
        ValueError
            If ``data`` is missing or empty.
        """
        raw = inputs.get("data")
        if raw is None or len(raw) == 0:
            # Explicit, descriptive failure instead of a bare KeyError.
            raise ValueError("Request must contain a non-empty 'data' field")

        context = torch.tensor(raw, dtype=torch.float32)
        if context.ndim == 1:
            # Backward-compatible convenience: a single series becomes a batch of one.
            context = context.unsqueeze(0)

        prediction_length = int(
            inputs.get("prediction_length", self.DEFAULT_PREDICTION_LENGTH)
        )

        # inference_mode disables autograd bookkeeping — free speed/memory win
        # on a serve-only path; the forecast result is unchanged.
        with torch.inference_mode():
            quantiles, mean = self.model.forecast(
                context=context,
                prediction_length=prediction_length,
            )

        # NOTE(review): the original assumed `quantiles` is a mapping; some
        # tirex versions return a plain tensor (batch, horizon, n_quantiles).
        # Support both shapes so the endpoint survives a library upgrade —
        # confirm against the installed tirex release.
        if hasattr(quantiles, "items"):
            quantiles_out = {k: v.tolist() for k, v in quantiles.items()}
        else:
            quantiles_out = quantiles.tolist()

        return {
            "quantiles": quantiles_out,
            "mean": mean.tolist(),
        }