forked from microsoft/responsible-ai-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Forecasting: consume locally served model (microsoft#2319)
- Loading branch information
Showing 9 changed files with 186 additions and 20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
responsibleai/responsibleai/_internal/_served_model_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) Microsoft Corporation | ||
# Licensed under the MIT License. | ||
|
||
import json | ||
|
||
import requests | ||
|
||
from responsibleai.serialization_utilities import serialize_json_safe | ||
|
||
|
||
class ServedModelWrapper:
    """Wrapper for a locally served model.

    The purpose of ServedModelWrapper is to provide an abstraction
    for locally served models. This allows RAIInsights to use the same
    code for models loaded into the current environment and for models
    that are served locally.

    Locally served in this case means on localhost via HTTP. This could
    be in a separate conda environment, or even in a Docker container.

    :param port: The port on which the model is served.
    :type port: int
    :param timeout: Optional timeout (in seconds) forwarded to
        ``requests.post``. ``None`` (the default) waits indefinitely,
        which matches the behavior before this parameter existed.
    :type timeout: float or tuple or None
    """

    # Class-level default so that instances lacking the attribute
    # (e.g. created without going through __init__) still behave as if
    # no timeout had been configured.
    timeout = None

    def __init__(self, port, timeout=None):
        self.port = port
        self.timeout = timeout

    def forecast(self, X):
        """Get forecasts from the model.

        :param X: The input data.
        :type X: pandas.DataFrame
        :return: The model's forecasts based on the input data.
        :rtype: List[float]
        :raises RuntimeError: If the model server responds with an
            HTTP error status.
        """
        # Request formatting according to the mlflow docs:
        # https://mlflow.org/docs/latest/cli.html#mlflow-models-serve
        # JSON-safe serialization takes care of datetime columns.
        payload = json.dumps(
            {"dataframe_split": X.to_dict(orient='split')},
            default=serialize_json_safe)
        response = requests.post(
            url=f"http://localhost:{self.port}/invocations",
            headers={"Content-Type": "application/json"},
            data=payload,
            timeout=self.timeout)
        try:
            response.raise_for_status()
        except requests.HTTPError as err:
            # Chain the original HTTPError so the root cause stays visible.
            raise RuntimeError(
                "Could not retrieve predictions. "
                f"Model server returned status code {response.status_code} "
                f"and the following response: {response.content}") from err

        # json.loads accepts the raw (byte string) response body.
        # The response is a dictionary with a single entry 'predictions'.
        return json.loads(response.content)['predictions']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) Microsoft Corporation | ||
# Licensed under the MIT License. | ||
|
||
import json | ||
import random | ||
from unittest import mock | ||
|
||
import pytest | ||
import requests | ||
from tests.common_utils import (RandomForecastingModel, | ||
create_tiny_forecasting_dataset) | ||
|
||
from responsibleai import FeatureMetadata, RAIInsights | ||
|
||
# Relative directory where the fixture saves the RAIInsights that the
# served-model tests reload with RAIInsights.load.
RAI_INSIGHTS_DIR_NAME: str = "rai_insights_test_served_model"
|
||
|
||
@pytest.fixture(scope="session")
def rai_forecasting_insights_for_served_model():
    """Build and save a forecasting RAIInsights for the served-model tests.

    The insights are saved to ``RAI_INSIGHTS_DIR_NAME`` so that the tests
    can reload them via ``RAIInsights.load``. The saved directory is removed
    when the test session ends so repeated runs start from a clean state.

    :return: The directory the insights were saved to.
    :rtype: str
    """
    import shutil

    X_train, X_test, y_train, y_test = create_tiny_forecasting_dataset()
    train = X_train.copy()
    train["target"] = y_train
    test = X_test.copy()
    test["target"] = y_test
    model = RandomForecastingModel()

    # create RAI Insights and save it
    rai_insights = RAIInsights(
        model=model,
        train=train,
        test=test,
        target_column="target",
        task_type='forecasting',
        feature_metadata=FeatureMetadata(
            datetime_features=['time'],
            time_series_id_features=['id']
        ),
        forecasting_enabled=True)
    rai_insights.save(RAI_INSIGHTS_DIR_NAME)

    yield RAI_INSIGHTS_DIR_NAME

    # Teardown: don't leave saved artifacts behind in the working directory.
    shutil.rmtree(RAI_INSIGHTS_DIR_NAME, ignore_errors=True)
|
||
|
||
@mock.patch("requests.post")
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
def test_served_model(
        mock_post,
        rai_forecasting_insights_for_served_model):
    """Forecasts are retrieved from the served model over HTTP.

    ``requests.post`` is mocked, so no model server is needed. The mocked
    response contains one prediction per row of the data being forecast.
    """
    _, X_test, _, _ = create_tiny_forecasting_dataset()

    mock_post.return_value = mock.Mock(
        status_code=200,
        content=json.dumps({
            # Size the fake predictions by the data actually being
            # forecast, not by an unrelated split.
            "predictions": [random.random() for _ in range(len(X_test))]
        })
    )

    rai_insights = RAIInsights.load(RAI_INSIGHTS_DIR_NAME)
    forecasts = rai_insights.model.forecast(X_test)
    assert len(forecasts) == len(X_test)
    assert mock_post.call_count == 1
    # The port from RAI_MODEL_SERVING_PORT must be the one that is queried.
    assert mock_post.call_args[1]["url"] == \
        "http://localhost:5123/invocations"
|
||
|
||
@mock.patch("requests.post")
@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
def test_served_model_failed(
        mock_post,
        rai_forecasting_insights_for_served_model):
    """An HTTP error from the model server surfaces as a RuntimeError.

    The error message must include the status code and the raw response
    body returned by the server.
    """
    import re

    _, X_test, _, _ = create_tiny_forecasting_dataset()

    response = requests.Response()
    response.status_code = 400
    response._content = b"Could not connect to host since it actively " \
        b"refuses the connection."
    mock_post.return_value = response

    rai_insights = RAIInsights.load(RAI_INSIGHTS_DIR_NAME)
    expected_message = (
        "Could not retrieve predictions. "
        "Model server returned status code 400 "
        f"and the following response: {response.content}")
    # ``match`` is a regular expression searched in str(exception); escape
    # the literal message so characters such as "." must match exactly.
    with pytest.raises(RuntimeError, match=re.escape(expected_message)):
        rai_insights.model.forecast(X_test)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters