Merge "st-hal: Support different mic config on LPI and Non-LPI"
diff --git a/sml_model_parser.h b/sml_model_parser.h
index ee154c0..8691709 100644
--- a/sml_model_parser.h
+++ b/sml_model_parser.h
@@ -90,6 +90,8 @@
ST_SM_ID_SVA_GMM = 0x0001,
ST_SM_ID_SVA_CNN = 0x0002,
ST_SM_ID_SVA_VOP = 0x0004,
+ ST_SM_ID_SVA_RNN = 0x0008,
+ ST_SM_ID_SVA_KWD = 0x000A, //ST_SM_ID_SVA_CNN | ST_SM_ID_SVA_RNN
ST_SM_ID_SVA_END = 0x00F0,
ST_SM_ID_CUSTOM_START = 0x0100,
ST_SM_ID_CUSTOM_END = 0xF000,
diff --git a/st_hw_session_lsm.c b/st_hw_session_lsm.c
index 0276d09..a49adb5 100644
--- a/st_hw_session_lsm.c
+++ b/st_hw_session_lsm.c
@@ -2751,8 +2751,8 @@
if (param_tag_tracker & PARAM_OPERATION_MODE_BIT) {
op_params = &param_info[param_count++];
- /* CNN supports only keyword detection */
- if (ss_cfg->params->common_params.sm_id == ST_SM_ID_SVA_CNN)
+ /* CNN and RNN only support keyword detection */
+ if (ss_cfg->params->common_params.sm_id & ST_SM_ID_SVA_KWD)
det_mode.mode = LSM_MODE_KEYWORD_ONLY_DETECTION;
op_params->param_size = sizeof(det_mode);
diff --git a/st_second_stage.c b/st_second_stage.c
index f27b045..6a5c0ca 100644
--- a/st_second_stage.c
+++ b/st_second_stage.c
@@ -77,7 +77,7 @@
stream_input->buf_ptr->data_ptr = (int8_t *)frame;
ALOGV("%s: Issuing capi_process", __func__);
- ATRACE_BEGIN("sthal:second_stage: process keyword detection (CNN)");
+ ATRACE_BEGIN("sthal:second_stage: process keyword detection (CNN/RNN)");
rc = ss_session->capi_handle->vtbl_ptr->process(ss_session->capi_handle,
&stream_input, NULL);
ATRACE_END();
@@ -211,9 +211,9 @@
exit:
/*
- * The CNN algorithm doesn't set reject because it is continuously called
- * until the keyword has passed. So if a detection success has not been
- * declared inside the above loop, it is set to detection reject.
+ * The CNN/RNN algorithm doesn't set reject because it is continuously
+ * called until the keyword has passed. So if a detection success has not
+ * been declared inside the above loop, it is set to detection reject.
*/
pthread_mutex_unlock(&ss_session->lock);
pthread_mutex_lock(&ss_session->st_ses->ss_detections_lock);
@@ -578,7 +578,8 @@
capi_buf.max_data_len = sizeof(sva_threshold_config_t);
threshold_cfg = (sva_threshold_config_t *)capi_buf.data_ptr;
threshold_cfg->smm_threshold = ss_session->confidence_threshold;
- ALOGD("%s: Keyword detection (CNN) confidence level = %d", __func__,
+ ALOGD("%s: Keyword detection %s confidence level = %d", __func__,
+ st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_CNN ? "(CNN)" : "(RNN)",
ss_session->confidence_threshold);
ALOGV("%s: Issuing capi_set_param for param %d", __func__,
diff --git a/st_session.c b/st_session.c
index f7b0617..31f415b 100644
--- a/st_session.c
+++ b/st_session.c
@@ -65,6 +65,17 @@
#define IS_SS_DETECTION_SUCCESS(det)\
!(det & (KEYWORD_DETECTION_REJECT | USER_VERIFICATION_REJECT))
+/* Bitmask tests on sm_id: keyword detection, user verification, either. */
+#define IS_KEYWORD_DETECTION_MODEL(sm_id) ((sm_id) & ST_SM_ID_SVA_KWD)
+#define IS_USER_VERIFICATION_MODEL(sm_id) ((sm_id) & ST_SM_ID_SVA_VOP)
+#define IS_SECOND_STAGE_MODEL(sm_id)\
+    (((sm_id) & ST_SM_ID_SVA_KWD) || ((sm_id) & ST_SM_ID_SVA_VOP))
+/* Nonzero if the ids share a bit, or usecase has RNN and levels has CNN. */
+#define IS_MATCHING_SS_MODEL(usecase_sm_id, levels_sm_id)\
+    (((usecase_sm_id) & (levels_sm_id)) ||\
+     (((usecase_sm_id) & ST_SM_ID_SVA_RNN) &&\
+      ((levels_sm_id) & ST_SM_ID_SVA_CNN)))
+
#define STATE_TRANSITION(st_session, new_state_fn)\
do {\
if (st_session->current_state != new_state_fn) {\
@@ -2033,16 +2044,16 @@
(void *)sm_levels, out_conf_levels, out_num_conf_levels,
stc_ses->conf_levels_intf_version);
gmm_conf_found = true;
- } else if ((sm_levels->sm_id == ST_SM_ID_SVA_CNN) ||
- (sm_levels->sm_id == ST_SM_ID_SVA_VOP)) {
- confidence_level = (sm_levels->sm_id == ST_SM_ID_SVA_CNN) ?
+ } else if (IS_SECOND_STAGE_MODEL(sm_levels->sm_id)) {
+ confidence_level = IS_KEYWORD_DETECTION_MODEL(sm_levels->sm_id) ?
sm_levels->kw_levels[0].kw_level:
sm_levels->kw_levels[0].user_levels[0].level;
if (arm_second_stage) {
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t,
list_node);
- if (st_sec_stage->ss_info->sm_id == sm_levels->sm_id)
+ if (IS_MATCHING_SS_MODEL(st_sec_stage->ss_info->sm_id,
+ sm_levels->sm_id))
st_sec_stage->ss_session->confidence_threshold =
confidence_level;
}
@@ -2050,7 +2061,8 @@
list_for_each(node, &st_hw_ses->lsm_ss_cfg_list) {
ss_cfg = node_to_item(node, st_lsm_ss_config_t,
list_node);
- if (ss_cfg->ss_info->sm_id == sm_levels->sm_id)
+ if (IS_MATCHING_SS_MODEL(ss_cfg->ss_info->sm_id,
+ sm_levels->sm_id))
ss_cfg->confidence_threshold = confidence_level;
}
}
@@ -2090,18 +2102,17 @@
(void *)sm_levels_v2, out_conf_levels,
out_num_conf_levels, stc_ses->conf_levels_intf_version);
gmm_conf_found = true;
- } else if ((sm_levels_v2->sm_id == ST_SM_ID_SVA_CNN) ||
- (sm_levels_v2->sm_id == ST_SM_ID_SVA_VOP)) {
+ } else if (IS_SECOND_STAGE_MODEL(sm_levels_v2->sm_id)) {
confidence_level_v2 =
- (sm_levels_v2->sm_id == ST_SM_ID_SVA_CNN) ?
+ (IS_KEYWORD_DETECTION_MODEL(sm_levels_v2->sm_id)) ?
sm_levels_v2->kw_levels[0].kw_level:
sm_levels_v2->kw_levels[0].user_levels[0].level;
if (arm_second_stage) {
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t,
list_node);
- if (st_sec_stage->ss_info->sm_id ==
- sm_levels_v2->sm_id)
+ if (IS_MATCHING_SS_MODEL(st_sec_stage->ss_info->sm_id,
+ sm_levels_v2->sm_id))
st_sec_stage->ss_session->confidence_threshold =
confidence_level_v2;
}
@@ -2109,7 +2120,8 @@
list_for_each(node, &st_hw_ses->lsm_ss_cfg_list) {
ss_cfg = node_to_item(node, st_lsm_ss_config_t,
list_node);
- if (ss_cfg->ss_info->sm_id == sm_levels_v2->sm_id)
+ if (IS_MATCHING_SS_MODEL(ss_cfg->ss_info->sm_id,
+ sm_levels_v2->sm_id))
ss_cfg->confidence_threshold = confidence_level_v2;
}
}
@@ -2864,9 +2876,9 @@
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t, list_node);
- if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_CNN) {
+ if (IS_KEYWORD_DETECTION_MODEL(st_sec_stage->ss_info->sm_id)) {
kw_level = st_sec_stage->ss_session->confidence_score;
- } else if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_VOP) {
+ } else if (IS_USER_VERIFICATION_MODEL(st_sec_stage->ss_info->sm_id)) {
user_level = st_sec_stage->ss_session->confidence_score;
}
}
@@ -2897,9 +2909,9 @@
payload_size, user_id);
}
}
- } else if (conf_levels->conf_levels[i].sm_id == ST_SM_ID_SVA_CNN) {
+ } else if (IS_KEYWORD_DETECTION_MODEL(conf_levels->conf_levels[i].sm_id)) {
conf_levels->conf_levels[i].kw_levels[0].kw_level = kw_level;
- } else if (conf_levels->conf_levels[i].sm_id == ST_SM_ID_SVA_VOP) {
+ } else if (IS_USER_VERIFICATION_MODEL(conf_levels->conf_levels[i].sm_id)) {
/*
* Fill both the keyword and user confidence level with the
* confidence score returned from the voiceprint algorithm.
@@ -2937,11 +2949,9 @@
payload_size, user_id);
}
}
- } else if (conf_levels_v2->conf_levels[i].sm_id ==
- ST_SM_ID_SVA_CNN) {
+ } else if (IS_KEYWORD_DETECTION_MODEL(conf_levels_v2->conf_levels[i].sm_id)) {
conf_levels_v2->conf_levels[i].kw_levels[0].kw_level = kw_level;
- } else if (conf_levels_v2->conf_levels[i].sm_id ==
- ST_SM_ID_SVA_VOP) {
+ } else if (IS_USER_VERIFICATION_MODEL(conf_levels_v2->conf_levels[i].sm_id)) {
/*
* Fill both the keyword and user confidence level with the
* confidence score returned from the voiceprint algorithm.
@@ -3002,10 +3012,10 @@
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t, list_node);
- if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_CNN) {
+ if (IS_KEYWORD_DETECTION_MODEL(st_sec_stage->ss_info->sm_id)) {
local_event->phrase_extras[0].confidence_level =
(uint8_t)st_sec_stage->ss_session->confidence_score;
- } else if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_VOP) {
+ } else if (IS_USER_VERIFICATION_MODEL(st_sec_stage->ss_info->sm_id)) {
local_event->phrase_extras[0].levels[0].level =
(uint8_t)st_sec_stage->ss_session->confidence_score;
}
@@ -3078,7 +3088,7 @@
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t,
list_node);
- if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_CNN) {
+ if (IS_KEYWORD_DETECTION_MODEL(st_sec_stage->ss_info->sm_id)) {
kw_indices->start_index =
st_sec_stage->ss_session->kw_start_idx;
kw_indices->end_index =
@@ -3323,7 +3333,7 @@
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t, list_node);
- if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_CNN) {
+ if (IS_KEYWORD_DETECTION_MODEL(st_sec_stage->ss_info->sm_id)) {
enable_kw_indices = true;
opaque_size += sizeof(struct st_param_header) +
sizeof(struct st_keyword_indices_info);
@@ -3393,7 +3403,7 @@
list_for_each(node, &stc_ses->second_stage_list) {
st_sec_stage = node_to_item(node, st_arm_second_stage_t,
list_node);
- if (st_sec_stage->ss_info->sm_id == ST_SM_ID_SVA_CNN) {
+ if (IS_KEYWORD_DETECTION_MODEL(st_sec_stage->ss_info->sm_id)) {
kw_indices->start_index =
st_sec_stage->ss_session->kw_start_idx;
kw_indices->end_index =