谷歌Metrax为JAX引入了预定义的模型评估指标
Source: InfoQ - Backend
Metrax是一个JAX库,最近由谷歌开源,为分类、回归、自然语言处理(NLP)、视觉和音频模型提供了标准化的性能指标实现。
谷歌解释说,Metrax解决了JAX生态系统中的一个空白,这个空白迫使许多团队从TensorFlow迁移到JAX,以实现他们自己的通用评估指标版本,如准确性、F1分数、RMS误差等:
虽然在某些人看来,创建指标似乎是一个相当简单和直接的话题,但当考虑到跨数据中心规模的分布式计算环境中的大规模训练和评估时,它就变得不那么简单了。
Metrax为一系列机器学习模型提供了预定义的评估度量指标,包括分类、回归、推荐、视觉和音频,特别支持分布式和大规模的训练环境。对于视觉模型,该库包括诸如交并比(IoU)、信噪比(SNR)和结构相似性指数(SSIM)等指标,Metrax还包括鲁棒的NLP相关度量指标,包括困惑度(Perplexity)、BLEU和ROUGE。
谷歌指出,Metrax的目标之一是确保所有度量指标都得到很好的实施并遵循最佳实践。在度量指标定义支持的地方,Metrax使用JAX的高级功能,如vmap和jit来提高性能。例如,这些特性用于实现新的“at K”指标,以支持并行计算多个K值。这使我们能够更全面、更快地评估模型。
你可以使用PrecisionAtK来确定多个K值(比如K=1、K=8和K=20)下模型的精度,所有这些都是在模型的一次前向传递中进行的,而不需要对每个参数多次调用PrecisionAtK。
名为Neural Foundry的DevOps工程师在Substack上写道:
Metrax支持在单次传递中计算多个K值,这对排名系统来说是一个巨大的胜利。我每次切换项目时都需要重写度量工具,这种标准化早就应该实现了。API看起来也很干净。好奇他们是否针对特定用例(如大规模推荐管道)的自定义实现进行了基准测试。
下面的代码片段展示了如何根据预测结果和标签计算精度度量指标。可以指定一个可选的阈值,将概率预测转换为二元预测:
import metrax # 直接计算度量状态。
metric_state = metrax.Precision.from_model_output(
predictions=predictions,
labels=labels,
threshold=0.5
)
# 然后通过调用compute()即可获得结果。
result = metric_state.compute()
result谷歌还发布了一个笔记本,包含了一系列综合示例,包括多设备扩展和与Flax NNX的集成,Flax NNX是一个简化的API,使得在JAX中创建、检查、调试和分析神经网络变得更加容易。
JAX是一个开源的Python库,用于高性能数值计算和机器学习。
原文链接:
https://www.infoq.com/news/2025/12/metrax-jax-evaluation-metrics/