Skip to content

Commit

Permalink
fix label column incorrectly added by feature extractors in RAI Vision dashboard for automl models (#2532)
Browse files Browse the repository at this point in the history

* fix label column incorrectly added by feature extractors in RAI Vision dashboard for automl models

* fixup
  • Loading branch information
imatiach-msft committed Feb 9, 2024
1 parent 8d2a117 commit 973631d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from tqdm import tqdm

from responsibleai.feature_metadata import FeatureMetadata
from responsibleai_vision.common.constants import ExtractedFeatures
from responsibleai_vision.common.constants import (ExtractedFeatures,
ImageColumns)
from responsibleai_vision.utils.image_reader import (
get_all_exif_feature_names, get_image_from_path,
get_image_pointer_from_path)

MEAN_PIXEL_VALUE = ExtractedFeatures.MEAN_PIXEL_VALUE.value
MAX_CUSTOM_LEN = 100
IMAGE_DETAILS = ImageColumns.IMAGE_DETAILS.value


def extract_features(image_dataset: pd.DataFrame,
Expand Down Expand Up @@ -58,6 +60,8 @@ def extract_features(image_dataset: pd.DataFrame,
start_meta_index = 2
if isinstance(target_column, list):
start_meta_index = len(target_column) + 1
if IMAGE_DETAILS in column_names:
start_meta_index += 1
for j in range(start_meta_index, image_dataset.shape[1]):
if has_dropped_features and column_names[j] in dropped_features:
continue
Expand Down
7 changes: 7 additions & 0 deletions responsibleai_vision/tests/rai_vision_insights_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def validate_rai_vision_insights(
pd.testing.assert_frame_equal(rai_vision_test, test_data)
assert rai_vision_insights.target_column == target_column
assert rai_vision_insights.task_type == task_type
# make sure label column not in _ext_test extracted features data
assert target_column not in rai_vision_insights._ext_features
# also not in last column of _ext_test, which is prone to happen
# if incorrect number of metadata columns specified in
# feature_extractors call
first_row = rai_vision_insights._ext_test[0]
assert not isinstance(first_row[len(first_row) - 1], list)


def run_and_validate_serialization(
Expand Down
7 changes: 5 additions & 2 deletions responsibleai_vision/tests/test_feature_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import pytest
from common_vision_utils import (load_flowers_dataset, load_fridge_dataset,
load_fridge_object_detection_dataset,
load_imagenet_dataset)
Expand Down Expand Up @@ -39,8 +40,10 @@ def extract_dataset_features(data, feature_metadata=None):


class TestFeatureExtractors(object):
def test_extract_features_fridge_object_detection(self):
data = load_fridge_object_detection_dataset(automl_format=False)
@pytest.mark.parametrize("automl_format", [True, False])
def test_extract_features_fridge_object_detection(self, automl_format):
data = load_fridge_object_detection_dataset(
automl_format=automl_format)
extracted_features, feature_names = extract_dataset_features(data)
expected_feature_names = [MEAN_PIXEL_VALUE] + FRIDGE_METADATA_FEATURES
validate_extracted_features(extracted_features, feature_names,
Expand Down

0 comments on commit 973631d

Please sign in to comment.