Skip to content

Commit

Permalink
[Model Overview] Box plot data for classification probabilities from SDK endpoint (#1689)
Browse files · Browse the repository at this point in the history

* temp changes for calling box plot endpoint

* resolve import error

* fix import error

* fix import error

* temp change for box plot

* temp change

* temp change to remove local raiwidgets_big_data

* remove unused changes

* remove unused changes

* move to calculateBoxData

* relabel filters

* only call model_overview once

* remove unused comments

* remove unused comments

* make new api call on probability option update

* fix box plot error on tab switch

* fix tab switch error

* add UT

* fix lint

* put behind flight
  • Loading branch information
tongyu-microsoft committed Sep 8, 2022
1 parent 05944f4 commit 16e307f
Show file tree
Hide file tree
Showing 11 changed files with 414 additions and 28 deletions.
12 changes: 11 additions & 1 deletion apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import {
ICausalWhatIfData,
IErrorAnalysisMatrix
IErrorAnalysisMatrix,
IHighchartBoxData
} from "@responsible-ai/core-ui";
import {
ModelAssessmentDashboard,
Expand All @@ -28,6 +29,7 @@ export class ModelAssessment extends React.Component {
| "requestMatrix"
| "requestImportances"
| "requestCausalWhatIf"
| "requestBoxPlotDistribution"
> = {};
if (config.baseUrl) {
callBack.requestExp = async (data: number): Promise<any[]> => {
Expand Down Expand Up @@ -61,6 +63,14 @@ export class ModelAssessment extends React.Component {
abortSignal
);
};
callBack.requestBoxPlotDistribution = async (
data: any
): Promise<IHighchartBoxData> => {
return callFlaskService(
data,
"/model_overview_probability_distribution"
);
};
}

return (
Expand Down
1 change: 1 addition & 0 deletions libs/core-ui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export * from "./lib/Interfaces/ICohort";
export * from "./lib/Interfaces/IPreBuiltCohort";
export * from "./lib/Interfaces/IErrorAnalysisData";
export * from "./lib/Interfaces/IDataBalanceMeasures";
export * from "./lib/Interfaces/IHighchartBoxData";
export * from "./lib/Interfaces/IMetaData";
export * from "./lib/Interfaces/TextExplanationInterfaces";
export * from "./lib/Interfaces/VisionExplanationInterfaces";
Expand Down
5 changes: 5 additions & 0 deletions libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { ICounterfactualData } from "../Interfaces/ICounterfactualData";
import { IDataset } from "../Interfaces/IDataset";
import { IErrorAnalysisData } from "../Interfaces/IErrorAnalysisData";
import { IExplanationModelMetadata } from "../Interfaces/IExplanationContext";
import { IHighchartBoxData } from "../Interfaces/IHighchartBoxData";
import { IModelExplanationData } from "../Interfaces/IModelExplanationData";
import { ITelemetryEvent } from "../util/ITelemetryEvent";
import { JointDataset } from "../util/JointDataset";
Expand Down Expand Up @@ -54,6 +55,10 @@ export interface IModelAssessmentContext {
explanationAlgorithm?: string
) => Promise<any[]>)
| undefined;
requestBoxPlotDistribution?: (
request: any,
abortSignal: AbortSignal
) => Promise<IHighchartBoxData>;
requestExp?:
| ((index: number, abortSignal: AbortSignal) => Promise<any[]>)
| undefined;
Expand Down
224 changes: 223 additions & 1 deletion libs/core-ui/src/lib/util/CalculateBoxPlot.test.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,229 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { calculateBoxPlotData, getPercentile } from "./calculateBoxData";
import { RangeTypes } from "@responsible-ai/mlchartlib";

import { IExplanationModelMetadata } from "../Interfaces/IExplanationContext";

import {
calculateBoxPlotData,
calculateBoxPlotDataFromErrorCohort,
calculateBoxPlotDataFromSDK,
getPercentile
} from "./calculateBoxData";
import { JointDataset } from "./JointDataset";
import { ColumnCategories } from "./JointDatasetUtils";

let mockRequestBoxPlotDistribution: jest.Mock;

// Thirteen features; only the one at position 3 is categorical.
const featureIsCategorical = Array.from(
  { length: 13 },
  (_unused, featureIndex) => featureIndex === 3
);

// Minimal binary-classification metadata. The cast is required because this
// fixture intentionally fills in only a subset of the interface's fields.
const modelMetadata = {
  classNames: ["Class 0"],
  featureIsCategorical,
  featureNames: [],
  featureNamesAbridged: [],
  featureRanges: [],
  modelType: "binary"
} as IExplanationModelMetadata;

const jointDataset = new JointDataset({ dataset: [], metadata: modelMetadata });

// Builds the metadata shared by the categorical outcome columns
// (classification outcome, predicted label, true label). Every call returns
// fresh objects/arrays, so entries never alias one another.
function buildOutcomeMeta(label: string, sortedCategoricalValues: string[]) {
  return {
    abbridgedLabel: label,
    category: ColumnCategories.Outcome,
    isCategorical: true,
    label,
    sortedCategoricalValues,
    treatAsCategorical: true
  };
}

// Column metadata for the joint dataset; presumably consulted when cohort
// filters are relabeled by the code under test.
jointDataset.metaDict = {
  Age: {
    abbridgedLabel: "age",
    category: ColumnCategories.Dataset,
    featureRange: { max: 64, min: 18, rangeType: RangeTypes.Integer },
    index: 0,
    isCategorical: false,
    label: "age"
  },
  ClassificationOutcome: buildOutcomeMeta("Classification outcome", [
    "True negative",
    "False positive",
    "False negative",
    "True positive"
  ]),
  Index: {
    abbridgedLabel: "Index",
    category: ColumnCategories.Index,
    featureRange: { max: 47, min: 0, rangeType: RangeTypes.Integer },
    isCategorical: false,
    label: "Index"
  },
  PredictedY: buildOutcomeMeta("Predicted Y", ["<=50K", ">50K"]),
  TrueY: buildOutcomeMeta("True Y", ["<=50K", ">50K"])
};

describe("calculateBoxPlotDataFromErrorCohort", () => {
  // SDK path: when a requestBoxPlotDistribution callback is supplied, the
  // result must be exactly what the callback resolves with, and the callback
  // must be invoked once with the class index derived from queryClass.
  it.each`
    filters | expectedResult
    ${[]} | ${{ high: 0, low: 0, median: 0, q1: 0, q3: 0 }}
    ${[{ arg: [62], column: "Age", method: "less" }]} | ${{ high: 1, low: 1, median: 1, q1: 1, q3: 1 }}
    ${[{ arg: [62], column: "Age", method: "less" }, { arg: [6], column: "Index", method: "less" }]} | ${{ high: 100000, low: 100000, median: 100000, q1: 100000, q3: 100000 }}
    ${[{ arg: [6], column: "Index", method: "less" }]} | ${{ high: 1, low: 0, median: 0.5, q1: 0, q3: 1 }}
    ${[{ arg: [1], column: "PredictedY", method: "includes" }]} | ${{ high: 1, low: 0, median: 0.5, q1: 0, q3: 1 }}
    ${[{ arg: [0], column: "TrueY", method: "includes" }]} | ${{ high: 2, low: 0, median: 1, q1: 0, q3: 2 }}
    ${[{ arg: ["True negative", "False positive"], column: "ClassificationOutcome", method: "includes" }]} | ${{ high: 10, low: -10, median: 0, q1: -5, q3: 5 }}
  `(
    "should return correct box values from SDK",
    async ({ filters, expectedResult }) => {
      // The callback contract is `(...) => Promise<IHighchartBoxData>`, so
      // resolve the value rather than returning a bare object.
      mockRequestBoxPlotDistribution = jest
        .fn()
        .mockResolvedValue(expectedResult);
      const boxPlotData = await calculateBoxPlotDataFromErrorCohort(
        {
          cohort: {
            cachedAverageImportance: [],
            cachedTransposedLocalFeatureImportances: [],
            cohortIndex: 1,
            compositeFilters: [],
            currentSortKey: "Index",
            currentSortReversed: false,
            filteredData: [],
            filters,
            name: "Cohort Classification outcome"
          },
          jointDataset
        },
        0,
        "",
        "0",
        mockRequestBoxPlotDistribution
      );
      expect(boxPlotData).toEqual(expectedResult);
      expect(mockRequestBoxPlotDistribution).toHaveBeenCalledTimes(1);
      // The request payload is [filters, compositeFilters, classIndex];
      // queryClass "0" must arrive as the number 0.
      expect(mockRequestBoxPlotDistribution.mock.calls[0][0][2]).toBe(0);
    }
  );

  // Local path: without a callback, the statistics are computed from
  // cohort.filteredData. The mock below is deliberately NOT passed in, so the
  // not-called assertion is only a sanity check; the table's expectedResult
  // column is not used by this test.
  it.each`
    filters | expectedResult
    ${[]} | ${{ high: 0, low: 0, median: 0, q1: 0, q3: 0 }}
    ${[{ arg: [62], column: "Age", method: "less" }]} | ${{ high: 1, low: 1, median: 1, q1: 1, q3: 1 }}
    ${[{ arg: [62], column: "Age", method: "less" }, { arg: [6], column: "Index", method: "less" }]} | ${{ high: 100000, low: 100000, median: 100000, q1: 100000, q3: 100000 }}
    ${[{ arg: [6], column: "Index", method: "less" }]} | ${{ high: 1, low: 0, median: 0.5, q1: 0, q3: 1 }}
    ${[{ arg: [1], column: "PredictedY", method: "includes" }]} | ${{ high: 1, low: 0, median: 0.5, q1: 0, q3: 1 }}
    ${[{ arg: [0], column: "TrueY", method: "includes" }]} | ${{ high: 2, low: 0, median: 1, q1: 0, q3: 2 }}
    ${[{ arg: ["True negative", "False positive"], column: "ClassificationOutcome", method: "includes" }]} | ${{ high: 10, low: -10, median: 0, q1: -5, q3: 5 }}
  `(
    "should return correct box values from UI when requestBoxPlotDistribution is undefined",
    async ({ filters }) => {
      mockRequestBoxPlotDistribution = jest.fn();
      const boxPlotData = await calculateBoxPlotDataFromErrorCohort(
        {
          cohort: {
            cachedAverageImportance: [],
            cachedTransposedLocalFeatureImportances: [],
            cohortIndex: 1,
            compositeFilters: [],
            currentSortKey: "Index",
            currentSortReversed: false,
            filteredData: [
              {
                Age: 67,
                ClassificationOutcome: 3,
                Index: 0,
                PredictedY: 1,
                ProbabilityClass0: 0.7510962272030672,
                TrueY: 1
              }
            ],
            filters,
            name: "Cohort Classification outcome"
          },
          jointDataset
        },
        0,
        "ProbabilityClass0"
      );
      expect(boxPlotData).toBeDefined();
      expect(mockRequestBoxPlotDistribution).not.toHaveBeenCalled();
    }
  );
});

describe("calculateBoxPlotDataFromSDK", () => {
  // The SDK helper must return the callback's resolved value unchanged and
  // issue exactly one request whose third element is the numeric class index.
  it.each`
    filters | expectedResult
    ${[]} | ${{ high: 0, low: 0, median: 0, q1: 0, q3: 0 }}
    ${[{ arg: [62], column: "Age", method: "less" }]} | ${{ high: 1, low: 1, median: 1, q1: 1, q3: 1 }}
    ${[{ arg: [62], column: "Age", method: "less" }, { arg: [6], column: "Index", method: "less" }]} | ${{ high: 100000, low: 100000, median: 100000, q1: 100000, q3: 100000 }}
    ${[{ arg: [6], column: "Index", method: "less" }]} | ${{ high: 1, low: 0, median: 0.5, q1: 0, q3: 1 }}
    ${[{ arg: [1], column: "PredictedY", method: "includes" }]} | ${{ high: 1, low: 0, median: 0.5, q1: 0, q3: 1 }}
    ${[{ arg: [0], column: "TrueY", method: "includes" }]} | ${{ high: 2, low: 0, median: 1, q1: 0, q3: 2 }}
    ${[{ arg: ["True negative", "False positive"], column: "ClassificationOutcome", method: "includes" }]} | ${{ high: 10, low: -10, median: 0, q1: -5, q3: 5 }}
  `(
    "should return correct box values from SDK",
    async ({ filters, expectedResult }) => {
      // Resolve (not return) the value to honour the Promise-returning contract.
      mockRequestBoxPlotDistribution = jest
        .fn()
        .mockResolvedValue(expectedResult);
      const boxPlotData = await calculateBoxPlotDataFromSDK(
        {
          cohort: {
            cachedAverageImportance: [],
            cachedTransposedLocalFeatureImportances: [],
            cohortIndex: 1,
            compositeFilters: [],
            currentSortKey: "Index",
            currentSortReversed: false,
            filteredData: [],
            filters,
            name: "Cohort Classification outcome"
          },
          jointDataset
        },
        mockRequestBoxPlotDistribution,
        "0"
      );
      expect(boxPlotData).toEqual(expectedResult);
      expect(mockRequestBoxPlotDistribution).toHaveBeenCalledTimes(1);
      // queryClass "0" is sent to the backend as the number 0.
      expect(mockRequestBoxPlotDistribution.mock.calls[0][0][2]).toBe(0);
    }
  );
});

describe("calculateBoxPlot", () => {
it.each`
Expand Down
48 changes: 46 additions & 2 deletions libs/core-ui/src/lib/util/calculateBoxData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,62 @@
import { ErrorCohort } from "../Cohort/ErrorCohort";
import { IHighchartBoxData } from "../Interfaces/IHighchartBoxData";

/**
 * Compute box-plot statistics for one column of an error cohort.
 *
 * When `requestBoxPlotDistribution` is supplied, the work is delegated to
 * `calculateBoxPlotDataFromSDK`, which sends the cohort's filters and
 * `queryClass` to that callback. Otherwise the statistics are computed
 * locally with `calculateBoxPlotData` from `errorCohort.cohort.filteredData`.
 *
 * @param errorCohort - cohort whose filters / filtered rows are summarized.
 * @param index - forwarded to `calculateBoxPlotData`; used on the local path only.
 * @param key - column key read from each filtered row on the local path
 *   (e.g. a probability column such as "ProbabilityClass0").
 * @param queryClass - class identifier forwarded to the SDK path only.
 * @param requestBoxPlotDistribution - optional backend callback; when it is
 *   absent, the local computation is used instead.
 * @returns a promise of the box-plot statistics from whichever path ran.
 */
export function calculateBoxPlotDataFromErrorCohort(
export async function calculateBoxPlotDataFromErrorCohort(
errorCohort: ErrorCohort,
index: number,
key: string
key: string,
queryClass?: string,
requestBoxPlotDistribution?: (
request: any,
abortSignal: AbortSignal
) => Promise<IHighchartBoxData>
) {
// Prefer the backend computation whenever a callback was provided.
if (requestBoxPlotDistribution) {
return await calculateBoxPlotDataFromSDK(
errorCohort,
requestBoxPlotDistribution,
queryClass
);
}
// key is the identifier for the column (e.g., probability)
// No callback (e.g. compute instance not connected): compute locally from the
// rows in filteredData. NOTE(review): these are described as the first 5k
// rows only; confirm where that cap is applied.
return calculateBoxPlotData(
errorCohort.cohort.filteredData.map((dict) => dict[key]),
index
);
}

/**
 * Request box-plot statistics for an error cohort from the backend (SDK).
 *
 * The cohort's filters and composite filters are relabeled against the
 * cohort's joint dataset (`ErrorCohort.getLabeledFilters` /
 * `ErrorCohort.getLabeledCompositeFilters`). They are sent to the callback,
 * together with the numeric class index, as the positional payload
 * `[filters, compositeFilters, classIndex]`.
 *
 * @param errorCohort - cohort whose filters define the requested subset.
 * @param requestBoxPlotDistribution - backend callback producing the box
 *   statistics for the payload.
 * @param queryClass - class identifier, converted with `Number(...)`.
 *   NOTE(review): when omitted this becomes `NaN` (which `JSON.stringify`
 *   emits as `null`); confirm the backend tolerates that.
 * @param abortSignal - optional signal that lets the caller cancel the
 *   request. It defaults to a fresh signal that is never aborted, which
 *   matches the previous behaviour.
 * @returns the box statistics resolved by the callback, unmodified.
 */
export async function calculateBoxPlotDataFromSDK(
  errorCohort: ErrorCohort,
  requestBoxPlotDistribution: (
    request: any,
    abortSignal: AbortSignal
  ) => Promise<IHighchartBoxData>,
  queryClass?: string,
  abortSignal: AbortSignal = new AbortController().signal
): Promise<IHighchartBoxData> {
  const filtersRelabeled = ErrorCohort.getLabeledFilters(
    errorCohort.cohort.filters,
    errorCohort.jointDataset
  );
  const compositeFiltersRelabeled = ErrorCohort.getLabeledCompositeFilters(
    errorCohort.cohort.compositeFilters,
    errorCohort.jointDataset
  );

  // Positional payload; presumably the backend handler unpacks it in this
  // exact order, so keep it stable.
  const data = [
    filtersRelabeled,
    compositeFiltersRelabeled,
    Number(queryClass)
  ];

  // The callback is a required parameter, so call it directly: optional
  // chaining here would turn a missing callback into a silent `undefined`
  // result despite the declared return type.
  return requestBoxPlotDistribution(data, abortSignal);
}

export function calculateBoxPlotData(
data: number[],
index?: number
Expand Down
Loading

0 comments on commit 16e307f

Please sign in to comment.