Forecasting: consume locally served model (microsoft#2319)
romanlutz authored Sep 9, 2023
1 parent a307b8d commit 615d68e
Showing 9 changed files with 186 additions and 20 deletions.
@@ -198,9 +198,9 @@ def forecast_quantiles(self, X, quantiles=None):
        """
        if quantiles is None:
            quantiles = [0.025, 0.975]
-        if (type(quantiles) is not list or
+        if (not isinstance(quantiles, list) or
                len(quantiles) == 0 or
-                any([type(q) is not float or
+                any([not isinstance(q, float) or
                     q <= 0 or
                     q >= 1 for q in quantiles])):
            raise ValueError(
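For reference, the tightened check above accepts only a non-empty list of floats strictly between 0 and 1. A hedged sketch of calls it permits and rejects (the model and X_test names are placeholders, not part of this diff):

model.forecast_quantiles(X_test)              # defaults to [0.025, 0.975]
model.forecast_quantiles(X_test, [0.1, 0.9])  # accepted: floats in (0, 1)
model.forecast_quantiles(X_test, [0, 1])      # rejected: ints, and at the bounds
model.forecast_quantiles(X_test, [])          # rejected: empty list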
56 changes: 56 additions & 0 deletions responsibleai/responsibleai/_internal/_served_model_wrapper.py
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import json

import requests

from responsibleai.serialization_utilities import serialize_json_safe


class ServedModelWrapper:
    """Wrapper for locally served model.

    The purpose of ServedModelWrapper is to provide an abstraction
    for locally served models. This allows us to use the same code in
    RAIInsights for loaded models that can run in the same environment and
    also for locally served models.
    Locally served in this case means on localhost via HTTP.
    This could be in a separate conda environment, or even in a Docker
    container.

    :param port: The port on which the model is served.
    :type port: int
    """
    def __init__(self, port):
        self.port = port

    def forecast(self, X):
        """Get forecasts from the model.

        :param X: The input data.
        :type X: pandas.DataFrame
        :return: The model's forecasts based on the input data.
        :rtype: List[float]
        """
        # request formatting according to mlflow docs
        # https://mlflow.org/docs/latest/cli.html#mlflow-models-serve
        # JSON safe serialization takes care of datetime columns
        response = requests.post(
            url=f"http://localhost:{self.port}/invocations",
            headers={"Content-Type": "application/json"},
            data=json.dumps(
                {"dataframe_split": X.to_dict(orient='split')},
                default=serialize_json_safe))
        try:
            response.raise_for_status()
        except Exception:
            raise RuntimeError(
                "Could not retrieve predictions. "
                f"Model server returned status code {response.status_code} "
                f"and the following response: {response.content}")

        # json.loads decodes byte string response.
        # Response is a dictionary with a single entry 'predictions'
        return json.loads(response.content)['predictions']
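For context, a minimal sketch of how the wrapper might be exercised directly; the serve command, model path, and port are assumptions rather than part of this change:

# Assumed setup: the model is served locally first, e.g. via MLflow:
#   mlflow models serve --model-uri ./forecasting_model --port 5000
from responsibleai._internal._served_model_wrapper import ServedModelWrapper

wrapper = ServedModelWrapper(port=5000)
predictions = wrapper.forecast(X_test)  # X_test is a pandas DataFrame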
5 changes: 5 additions & 0 deletions responsibleai/responsibleai/_internal/constants.py
@@ -136,3 +136,8 @@ class FileFormats:
    JSON = '.json'
    PKL = '.pkl'
    TXT = '.txt'


class ModelServingConstants:
    """Constants relevant for model serving."""
    RAI_MODEL_SERVING_PORT_ENV_VAR = "RAI_MODEL_SERVING_PORT"
14 changes: 14 additions & 0 deletions responsibleai/responsibleai/rai_insights/rai_base_insights.py
@@ -4,6 +4,7 @@
"""Defines the RAIBaseInsights class."""

import json
import os
import pickle
import warnings
from abc import ABC, abstractmethod
@@ -14,7 +15,9 @@

import responsibleai
from raiutils.exceptions import UserConfigValidationException
from responsibleai._internal._served_model_wrapper import ServedModelWrapper
from responsibleai._internal.constants import (FileFormats, Metadata,
                                               ModelServingConstants,
                                               SerializationAttributes)

_DTYPES = 'dtypes'
@@ -263,6 +266,17 @@ def _load_model(inst, path):
        :param path: The directory path to model location.
        :type path: str
        """
        # Communicate with locally served model
        # if the environment variable RAI_MODEL_SERVING_PORT is set.
        model_serving_port = os.getenv(
            ModelServingConstants.RAI_MODEL_SERVING_PORT_ENV_VAR)
        if model_serving_port is not None:
            inst.__dict__['_' + _SERIALIZER] = None
            inst.__dict__[Metadata.MODEL] = \
                ServedModelWrapper(port=model_serving_port)
            return

        # Otherwise use the conventional paths with local artifacts
        top_dir = Path(path)
        serializer_path = top_dir / _SERIALIZER
        model_load_err = ('ERROR-LOADING-USER-MODEL: '
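Together with the constant defined above, this enables an opt-in flow along these lines (directory name and port are hypothetical; the saved insights and served model are assumed to exist):

import os
from responsibleai import RAIInsights

os.environ["RAI_MODEL_SERVING_PORT"] = "5000"
rai_insights = RAIInsights.load("rai_insights_dir")
# rai_insights.model is now a ServedModelWrapper rather than a deserialized model.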
3 changes: 2 additions & 1 deletion responsibleai/responsibleai/serialization_utilities.py
@@ -7,6 +7,7 @@
from typing import Any

import numpy as np
import pandas as pd


def serialize_json_safe(o: Any):
@@ -31,7 +32,7 @@ def serialize_json_safe(o: Any):
        if isinstance(o, str):
            return json.dumps(o)[1:-1]
        return o
-    elif isinstance(o, datetime.datetime):
+    elif isinstance(o, datetime.datetime) or isinstance(o, pd.Timestamp):
        return o.__str__()
    elif isinstance(o, dict):
        return {k: serialize_json_safe(v) for k, v in o.items()}
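As a quick illustration of the new branch, both types now stringify the same way:

import datetime
import pandas as pd
from responsibleai.serialization_utilities import serialize_json_safe

serialize_json_safe(datetime.datetime(2020, 10, 10))  # '2020-10-10 00:00:00'
serialize_json_safe(pd.Timestamp("2020-10-10"))       # '2020-10-10 00:00:00'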
14 changes: 14 additions & 0 deletions responsibleai/tests/common_utils.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

import json
import random

import numpy as np
import pandas as pd
@@ -64,6 +65,19 @@ def create_tiny_forecasting_dataset():
    return X_train, X_test, y_train, y_test


class RandomForecastingModel():
    def forecast(self, X):
        return np.array([random.random() for _ in range(len(X))])


class RandomForecastingModelWithQuantiles(RandomForecastingModel):
    def forecast_quantiles(self, X, quantiles):
        return [
            [random.random() for _ in range(len(X))],
            [random.random() for _ in range(len(X))]
        ]


class FetchDiceAdultCensusIncomeDataset(object):
def __init__(self):
pass
22 changes: 5 additions & 17 deletions responsibleai/tests/rai_insights/test_rai_insights_validations.py
@@ -9,7 +9,9 @@
import pandas as pd
import pytest
from lightgbm import LGBMClassifier
-from tests.common_utils import (create_iris_data,
+from tests.common_utils import (RandomForecastingModel,
+                                RandomForecastingModelWithQuantiles,
+                                create_iris_data,
                                 create_tiny_forecasting_dataset)

from rai_test_utils.datasets.tabular import (
@@ -1089,11 +1091,7 @@ def predict_proba(self, test_data_pandas):
    def test_optional_method_not_present(self):
        X_train, X_test, y_train, y_test = create_tiny_forecasting_dataset()

-        class ForecastModelWithoutQuantiles():
-            def forecast(self, X):
-                return [random.random() for _ in range(len(X))]
-
-        model = ForecastModelWithoutQuantiles()
+        model = RandomForecastingModel()
        X_train = X_train.copy()
        X_test = X_test.copy()
        X_train[TARGET] = y_train
@@ -1114,17 +1112,7 @@ def forecast(self, X):
    def test_optional_method_present(self):
        X_train, X_test, y_train, y_test = create_tiny_forecasting_dataset()

-        class ForecastModelWithoutQuantiles():
-            def forecast(self, X):
-                return [random.random() for _ in range(len(X))]
-
-            def forecast_quantiles(self, X, quantiles):
-                return [
-                    [random.random() for _ in range(len(X))],
-                    [random.random() for _ in range(len(X))]
-                ]
-
-        model = ForecastModelWithoutQuantiles()
+        model = RandomForecastingModelWithQuantiles()
        X_train = X_train.copy()
        X_test = X_test.copy()
        X_train[TARGET] = y_train
82 changes: 82 additions & 0 deletions responsibleai/tests/rai_insights/test_served_model.py
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import json
import random
from unittest import mock

import pytest
import requests
from tests.common_utils import (RandomForecastingModel,
                                create_tiny_forecasting_dataset)

from responsibleai import FeatureMetadata, RAIInsights

RAI_INSIGHTS_DIR_NAME = "rai_insights_test_served_model"


# create a pytest fixture
@pytest.fixture(scope="session")
def rai_forecasting_insights_for_served_model():
    X_train, X_test, y_train, y_test = create_tiny_forecasting_dataset()
    train = X_train.copy()
    train["target"] = y_train
    test = X_test.copy()
    test["target"] = y_test
    model = RandomForecastingModel()

    # create RAI Insights and save it
    rai_insights = RAIInsights(
        model=model,
        train=train,
        test=test,
        target_column="target",
        task_type='forecasting',
        feature_metadata=FeatureMetadata(
            datetime_features=['time'],
            time_series_id_features=['id']
        ),
        forecasting_enabled=True)
    rai_insights.save(RAI_INSIGHTS_DIR_NAME)


@mock.patch("requests.post")
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
def test_served_model(
        mock_post,
        rai_forecasting_insights_for_served_model):
    X_train, X_test, _, _ = create_tiny_forecasting_dataset()

    mock_post.return_value = mock.Mock(
        status_code=200,
        content=json.dumps({
            "predictions": [random.random() for _ in range(len(X_train))]
        })
    )

    rai_insights = RAIInsights.load(RAI_INSIGHTS_DIR_NAME)
    forecasts = rai_insights.model.forecast(X_test)
    assert len(forecasts) == len(X_test)
    assert mock_post.call_count == 1


@mock.patch("requests.post")
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
def test_served_model_failed(
        mock_post,
        rai_forecasting_insights_for_served_model):
    _, X_test, _, _ = create_tiny_forecasting_dataset()

    response = requests.Response()
    response.status_code = 400
    response._content = b"Could not connect to host since it actively " \
                        b"refuses the connection."
    mock_post.return_value = response

    rai_insights = RAIInsights.load(RAI_INSIGHTS_DIR_NAME)
    with pytest.raises(
            Exception,
            match="Could not retrieve predictions. "
                  "Model server returned status code 400 "
                  f"and the following response: {response.content}"):
        rai_insights.model.forecast(X_test)
6 changes: 6 additions & 0 deletions responsibleai/tests/test_serialization_utilities.py
@@ -89,6 +89,12 @@ def test_serialize_timestamp(self):
        result = serialize_json_safe(datetime_object)
        assert datetime_str in result

    def test_serialize_pandas_timestamp(self):
        datetime_str = "2020-10-10"
        datetime_object = pd.Timestamp(datetime_str)
        result = serialize_json_safe(datetime_object)
        assert datetime_str in result

    def test_serialize_via_json_timestamp(self):
        timestamp_obj = pd.Timestamp(2020, 1, 1)
        assert isinstance(timestamp_obj, pd.Timestamp)
