Starting from TF 1.13, it looks like a native tf.keras precision metric (tf.keras.metrics.Precision) exists. TF 1.10 does not have one, so here is a custom precision metric function that can be used with TF 1.10. The same approach of writing a custom metric function should also work in other TF versions that lack an officially supported metric.

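import logging

import tensorflow as tf

# dl is the author's debug logger; a standard logging.Logger is assumed here.
dl = logging.getLogger(__name__)


def precision_metric(y_true, y_pred):
    # Binarize the scores: anything above 0.8 counts as a positive prediction.
    mod_y_pred = tf.cast(y_pred > 0.8, dtype=tf.float32)
    # Labels may arrive as ints; cast so the multiplication below is well typed.
    y_true = tf.cast(y_true, dtype=tf.float32)

    # In graph mode these calls log the symbolic tensor, not its values.
    dl.debug("mod_y_pred: {}".format(mod_y_pred))

    # 1 only where the prediction and the label are both 1 (true positives).
    compare_matrix = mod_y_pred * y_true
    dl.debug("compare_matrix: {}".format(compare_matrix))

    # Per-row (per-sample) counts of true positives and of predicted positives.
    match_count_matrix = tf.reduce_sum(compare_matrix, axis=1)
    predicted_count_matrix = tf.reduce_sum(mod_y_pred, axis=1)

    # eps avoids division by zero for rows with no positive predictions.
    eps = 1e-8
    predicted_count_matrix += eps

    # Per-sample precision, averaged over the batch into one scalar.
    precision_matrix = match_count_matrix / predicted_count_matrix
    precision = tf.reduce_mean(precision_matrix)
    dl.debug("precision mean : {}".format(precision))

    return precision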

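As you can see, the function is a graph builder, not an eagerly executed function: it returns a symbolic tensor, and the actual precision values are computed only when the graph is run.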
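Since the metric only builds graph ops, one way to check its arithmetic by hand is to mirror the same steps in NumPy. This is a minimal sketch, not part of the original post: the name precision_metric_np, its threshold and eps parameters, and the sample arrays are illustrative assumptions. It keeps the 0.8 threshold and the per-row-then-mean averaging of precision_metric.

```python
import numpy as np


def precision_metric_np(y_true, y_pred, threshold=0.8, eps=1e-8):
    """Hypothetical NumPy mirror of precision_metric: per-row precision, then the batch mean."""
    y_true = np.asarray(y_true, dtype=np.float32)
    # Binarize scores the same way as the TF version (strictly greater than threshold).
    mod_y_pred = (np.asarray(y_pred) > threshold).astype(np.float32)
    # True positives per row: prediction and label are both 1.
    match_count = (mod_y_pred * y_true).sum(axis=1)
    # Predicted positives per row; eps keeps empty rows from dividing by zero.
    predicted_count = mod_y_pred.sum(axis=1) + eps
    return float(np.mean(match_count / predicted_count))


y_true = [[1, 0, 1, 0],
          [0, 1, 0, 0]]
y_pred = [[0.9, 0.85, 0.3, 0.1],  # cols 0 and 1 predicted, only col 0 correct: 0.5
          [0.2, 0.6, 0.5, 0.7]]    # nothing above 0.8: 0 / eps = 0.0
print(precision_metric_np(y_true, y_pred))  # ≈ 0.25
```

The second row has no score above 0.8, so it contributes 0 to the mean; with this per-row averaging, such rows pull the batch precision down.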



Yasir Hussain · October 17, 2019 at 8:09 am

Thanks for sharing such a good example. I am wondering how we can calculate precision, recall, and F-measure for a multi-class problem. Also, your example does not seem to use a batched approach.

    chadrick_author · October 29, 2019 at 8:50 pm

    The example can be modified to handle multi-class problems too. It's just a matter of using the right NumPy-style array operations.
    The downside is that one metric function can output only one value, so calculating precision, recall, and F-score requires three separate metric functions. Since the F-score needs the precision and recall values, those get computed again inside the F-score function, which wastes computing resources.
    The example does not report a precision value for each item (row) of a given batch. It takes the mean of the per-row values so that a single number represents the overall precision of the batch.
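The multi-class case and the redundant-computation concern from the reply can both be illustrated outside Keras. Below is a minimal NumPy sketch, not from the original post, assuming single-label data with integer class ids and taking the argmax of the scores as the predicted class; the helper name multiclass_prf and the sample data are hypothetical. Precision, recall, and F1 are all derived from the same per-class true-positive, false-positive, and false-negative counts, so nothing is computed twice.

```python
import numpy as np


def multiclass_prf(y_true, y_pred, num_classes, eps=1e-8):
    """Hypothetical helper: per-class precision, recall and F1.

    y_true: integer class ids, shape (batch,)
    y_pred: class scores, shape (batch, num_classes); argmax is the prediction
    """
    true_ids = np.asarray(y_true)
    pred_ids = np.argmax(np.asarray(y_pred), axis=1)
    # One-hot encode both so each per-class count is a column sum.
    true_oh = np.eye(num_classes)[true_ids]
    pred_oh = np.eye(num_classes)[pred_ids]

    tp = (true_oh * pred_oh).sum(axis=0)  # predicted c and labelled c
    fp = pred_oh.sum(axis=0) - tp         # predicted c, labelled otherwise
    fn = true_oh.sum(axis=0) - tp         # labelled c, predicted otherwise

    # All three metrics reuse tp/fp/fn; eps guards against empty classes.
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1


labels = [0, 1, 2, 2]
scores = [[0.7, 0.2, 0.1],   # predicted 0, correct
          [0.1, 0.3, 0.6],   # predicted 2, label 1
          [0.2, 0.2, 0.6],   # predicted 2, correct
          [0.1, 0.5, 0.4]]   # predicted 1, label 2
precision, recall, f1 = multiclass_prf(labels, scores, num_classes=3)
print(precision, recall, f1)  # each ≈ [1.0, 0.0, 0.5]
```

Averaging each array over classes (for example precision.mean()) gives macro-averaged values. As the reply notes, each Keras metric function in TF 1.10 still returns a single value, so this shared-count approach fits more naturally in a separate evaluation step.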
