Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Object Detection Model Overview Flask template #2004

Merged
merged 9 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
const callBack: Pick<
IModelAssessmentDashboardProps,
| "requestExp"
| "requestObjectDetectionMetrics"
| "requestPredictions"
| "requestDebugML"
| "requestMatrix"
Expand All @@ -54,6 +55,18 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
/** Posts the given index to the Flask `/get_exp` route. */
callBack.requestExp = async (data: number): Promise<any[]> => {
  return callFlaskService(this.props.config, data, "/get_exp");
};
/**
 * Asks the Flask backend (`/get_object_detection_metrics`) to compute
 * object detection metrics.
 *
 * The arguments are sent positionally as a single JSON array in the order
 * `[trueY, predictedY, aggregateMethod, className, iouThresh]`. The Python
 * handler reads them back by index, so this order must not change.
 */
callBack.requestObjectDetectionMetrics = async (
  trueY: number[][][],
  predictedY: number[][][],
  aggregateMethod: string,
  className: string,
  iouThresh: number
): Promise<any[]> => {
  return callFlaskService(
    this.props.config,
    [trueY, predictedY, aggregateMethod, className, iouThresh],
    "/get_object_detection_metrics"
  );
};
/** Posts `data` to the Flask `/predict` route. */
callBack.requestPredictions = async (data: any[]): Promise<any[]> => {
  return callFlaskService(this.props.config, data, "/predict");
};
Expand Down
3 changes: 3 additions & 0 deletions erroranalysis/erroranalysis/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class Metrics(str, Enum):
FALSE_POSITIVE_RATE = 'false_positive_rate'
FALSE_NEGATIVE_RATE = 'false_negative_rate'
SELECTION_RATE = 'selection_rate'
MEAN_AVERAGE_PRECISION = 'mean_average_precision'
AVERAGE_PRECISION = 'average_precision'
AVERAGE_RECALL = 'average_recall'


class MetricKeys(str, Enum):
Expand Down
9 changes: 9 additions & 0 deletions libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ export interface IModelAssessmentContext {
requestExp?:
| ((index: number, abortSignal: AbortSignal) => Promise<any[]>)
| undefined;
requestObjectDetectionMetrics?:
| ((
trueY: number[][][],
predictedY: number[][][],
aggregate_method: string,
class_name: string,
iou_thresh: number
) => Promise<any[]>)
| undefined;
requestSplinePlotDistribution?: (
request: any,
abortSignal: AbortSignal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ export interface IVisionExplanationDashboardProps {
cohorts: ErrorCohort[];
dataSummary: IVisionExplanationDashboardData;
requestExp?: (index: number, abortSignal: AbortSignal) => Promise<any[]>;
requestObjectDetectionMetrics?: (
trueY: number[][][],
predictedY: number[][][],
aggregate_method: string,
class_name: string,
iou_thresh: number
) => Promise<any[]>;
selectedCohort: ErrorCohort;
setSelectedCohort: (cohort: ErrorCohort) => void;
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ export class TabsView extends React.PureComponent<
true_y: this.props.dataset.true_y
}}
requestExp={this.props.requestExp}
requestObjectDetectionMetrics={this.props.requestObjectDetectionMetrics}
cohorts={this.props.cohorts}
setSelectedCohort={this.props.setSelectedCohort}
selectedCohort={this.props.selectedCohort}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ export interface ITabsViewProps {
dataset: IDataset;
onClearCohortSelectionClick: () => void;
requestExp?: (index: number, abortSignal: AbortSignal) => Promise<any[]>;
requestObjectDetectionMetrics?: (
trueY: number[][][],
predictedY: number[][][],
aggregate_method: string,
class_name: string,
iou_thresh: number
) => Promise<any[]>;
requestPredictions?: (
request: any[],
abortSignal: AbortSignal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ export class ModelAssessmentDashboard extends CohortBasedComponent<
requestDatasetAnalysisBoxChart:
this.props.requestDatasetAnalysisBoxChart,
requestExp: this.props.requestExp,
requestObjectDetectionMetrics: this.props.requestObjectDetectionMetrics,
requestForecast: this.props.requestForecast,
requestGlobalCausalEffects: this.props.requestGlobalCausalEffects,
requestGlobalCausalPolicy: this.props.requestGlobalCausalPolicy,
Expand Down Expand Up @@ -129,6 +130,7 @@ export class ModelAssessmentDashboard extends CohortBasedComponent<
dataset={this.props.dataset}
>
requestExp={this.props.requestExp}
requestObjectDetectionMetrics={this.props.requestObjectDetectionMetrics}
requestPredictions={this.props.requestPredictions}
requestDebugML={this.props.requestDebugML}
requestImportances={this.props.requestImportances}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ export interface IModelAssessmentDashboardProps
abortSignal: AbortSignal
) => Promise<any>;
requestExp?: (index: number, abortSignal: AbortSignal) => Promise<any[]>;
requestObjectDetectionMetrics?: (
trueY: number[][][],
predictedY: number[][][],
aggregate_method: string,
class_name: string,
iou_thresh: number
) => Promise<any[]>;
requestBubblePlotData?: (
filter: unknown[],
compositeFilter: unknown[],
Expand Down
5 changes: 5 additions & 0 deletions raiwidgets/raiwidgets/responsibleai_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,8 @@ def get_exp():
data = request.get_json(force=True)
return jsonify(self.input.get_exp(data))
self.add_url_rule(get_exp, '/get_exp', methods=["POST"])

def get_object_detection_metrics():
    """Handle POST /get_object_detection_metrics.

    Reads the JSON request body and returns the result of
    ``self.input.get_object_detection_metrics`` serialized with ``jsonify``.
    """
    data = request.get_json(force=True)
    return jsonify(self.input.get_object_detection_metrics(data))
# Bind the route to the metrics handler defined just above (not get_exp),
# otherwise this endpoint would serve explanation data instead of metrics.
self.add_url_rule(get_object_detection_metrics,
                  '/get_object_detection_metrics', methods=["POST"])
22 changes: 22 additions & 0 deletions raiwidgets/raiwidgets/responsibleai_dashboard_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,25 @@ def get_global_causal_policy(self, post_data):
"inner error: {}".format(e_str),
WidgetRequestResponseConstants.data: []
}

def get_object_detection_metrics(self, post_data):
    """Compute object detection metrics for a widget request.

    :param post_data: Positional request payload
        ``[true_y, predicted_y, aggregate_method, class_name, iou_thresh]``,
        in the order the dashboard front end sends it.
    :return: Dict with the computed metrics under
        ``WidgetRequestResponseConstants.data``; on any failure, an error
        message under ``WidgetRequestResponseConstants.error`` and an
        empty list under ``WidgetRequestResponseConstants.data``.
    """
    try:
        true_y = post_data[0]
        predicted_y = post_data[1]
        aggregate_method = post_data[2]
        class_name = post_data[3]
        iou_thresh = post_data[4]
        metrics = self._analysis.compute_object_detection_metrics(
            true_y,
            predicted_y,
            aggregate_method,
            class_name,
            iou_thresh)
        return {
            WidgetRequestResponseConstants.data: metrics
        }
    except Exception as e:
        print(e)
        traceback.print_exc()
        e_str = _format_exception(e)
        return {
            # Adjacent literals are concatenated before .format(), so the
            # separating space must live inside the first literal.
            WidgetRequestResponseConstants.error:
                "Failed to get object detection metrics, "
                "inner error: {}".format(e_str),
            WidgetRequestResponseConstants.data: []
        }