8.3 点乘(Vector Dot Product)
这部分函数主要用于点乘,公式描述如下:
sum = pSrcA[0]*pSrcB[0] + pSrcA[1]*pSrcB[1] + ... + pSrcA[blockSize-1]*pSrcB[blockSize-1]
8.3.1 arm_dot_prod_f32
这个函数用于求32位浮点数的点乘,源代码分析如下:
void arm_dot_prod_f32(
float32_t * pSrcA,
float32_t * pSrcB,
uint32_t blockSize,
float32_t * result)
{
float32_t sum = 0.0f; (1)
uint32_t blkCnt;
#ifndef ARM_MATH_CM0_FAMILY
blkCnt = blockSize >> 2u;
while(blkCnt > 0u)
{
sum += (*pSrcA++) * (*pSrcB++); (2)
sum += (*pSrcA++) * (*pSrcB++);
sum += (*pSrcA++) * (*pSrcB++);
sum += (*pSrcA++) * (*pSrcB++);
blkCnt--;
}
blkCnt = blockSize % 0x4u;
#else
blkCnt = blockSize;
#endif
while(blkCnt > 0u)
{
sum += (*pSrcA++) * (*pSrcB++);
blkCnt--;
}
*result = sum;
}
1. 由于CM4上带的FPU是单精度的,所以初始化float32_t类型的浮点数时需要在数据的末尾加上f。
2. 类似函数sum += (*pSrcA++) * (*pSrcB++)最终会通过浮点的MAC(乘累加)实现,从而加快执行时间。
8.3.2 arm_dot_prod_q31
这个函数用于求32位定点数的点乘,源代码分析如下:
void arm_dot_prod_q31(
q31_t * pSrcA,
q31_t * pSrcB,
uint32_t blockSize,
q63_t * result)
{
q63_t sum = 0;
uint32_t blkCnt;
#ifndef ARM_MATH_CM0_FAMILY
q31_t inA1, inA2, inA3, inA4;
q31_t inB1, inB2, inB3, inB4;
blkCnt = blockSize >> 2u;
while(blkCnt > 0u)
{
inA1 = *pSrcA++;
inA2 = *pSrcA++;
inA3 = *pSrcA++;
inA4 = *pSrcA++;
inB1 = *pSrcB++;
inB2 = *pSrcB++;
inB3 = *pSrcB++;
inB4 = *pSrcB++;
sum += ((q63_t) inA1 * inB1) >> 14u; (2)
sum += ((q63_t) inA2 * inB2) >> 14u;
sum += ((q63_t) inA3 * inB3) >> 14u;
sum += ((q63_t) inA4 * inB4) >> 14u;
blkCnt--;
}
blkCnt = blockSize % 0x4u;
#else
blkCnt = blockSize;
#endif
while(blkCnt > 0u)
{
sum += ((q63_t) * pSrcA++ * *pSrcB++) >> 14u;
blkCnt--;
}
*result = sum;
}
1. 两个Q31格式的32位数相乘,那么输出结果的格式是1.31*1.31 = 2.62。实际应用中基本不需要这么高的精度,这个函数将低14位的数据截取掉,反应在函数中就是两个数的乘积左移14位,也就是定点数的小数点也左移14位,那么最终的结果的格式是16.48。所以只要乘累加的个数小于2^16就没有输出结果溢出的危险(不知道这里为什么不是2^14,留作以后解决)。
2. 将获取的结果左移14位。
8.3.3 arm_dot_prod_q15
这个函数用于求16位定点数的点乘,源代码分析如下:
void arm_dot_prod_q15(
q15_t * pSrcA,
q15_t * pSrcB,
uint32_t blockSize,
q63_t * result)
{
q63_t sum = 0;
uint32_t blkCnt;
#ifndef ARM_MATH_CM0_FAMILY
blkCnt = blockSize >> 2u;
while(blkCnt > 0u)
{
(2)
sum = __SMLALD(*__SIMD32(pSrcA)++, *__SIMD32(pSrcB)++, sum);
sum = __SMLALD(*__SIMD32(pSrcA)++, *__SIMD32(pSrcB)++, sum);
blkCnt--;
}
blkCnt = blockSize % 0x4u;
while(blkCnt > 0u)
{
sum = __SMLALD(*pSrcA++, *pSrcB++, sum);
blkCnt--;
}
#else
blkCnt = blockSize;
while(blkCnt > 0u)
{
sum += (q63_t) ((q31_t) * pSrcA++ * *pSrcB++);
blkCnt--;
}
#endif
*result = sum;
}
1. 两个Q15格式的数据相乘,那么输出结果的格式是1.15*1.15 = 2.30,这个函数将输出结果赋值给了64位变量,那么输出结果就是34.30格式。所以基本没有溢出的危险。
2. __SMLALD也是SIMD指令,实现两个16位数相乘,并把结果累加给64位变量。
8.3.4 arm_dot_prod_q7
这个函数用于求8位定点数的点乘,源代码分析如下:
void arm_dot_prod_q7(
q7_t * pSrcA,
q7_t * pSrcB,
uint32_t blockSize,
q31_t * result)
{
uint32_t blkCnt;
q31_t sum = 0;
#ifndef ARM_MATH_CM0_FAMILY
q31_t input1, input2;
q31_t inA1, inA2, inB1, inB2;
blkCnt = blockSize >> 2u;
while(blkCnt > 0u)
{
(2)
input1 = *__SIMD32(pSrcA)++;
input2 = *__SIMD32(pSrcB)++;
inA1 = __SXTB16(__ROR(input1, 8)); (3)
inA2 = __SXTB16(input1);
inB1 = __SXTB16(__ROR(input2, 8));
inB2 = __SXTB16(input2);
sum = __SMLAD(inA1, inB1, sum); (4)
sum = __SMLAD(inA2, inB2, sum);
blkCnt--;
}
blkCnt = blockSize % 0x4u;
while(blkCnt > 0u)
{
sum = __SMLAD(*pSrcA++, *pSrcB++, sum);
blkCnt--;
}
#else
blkCnt = blockSize;
while(blkCnt > 0u)
{
sum += (q31_t) ((q15_t) * pSrcA++ * *pSrcB++);
blkCnt--;
}
#endif
*result = sum;
}
1. 两个Q8格式的数据相乘,那么输出结果就是1.7*1.7 = 2.14格式。这里将最终结果赋值给了32位的变量,那么最终的格式就是18.14。如果乘累加的个数小于2^18那么就不会有溢出的危险(感觉这里应该是2^16)。
2. 一次读取4个8位的数据。
3. __SXTB16也是SIMD指令,用于将两个8位的有符号数扩展成16位。__ROR用于实现数据的循环右移。
4. __SMLAD也是SIMD指令,用于实现如下功能:
sum = __SMLAD(x, y, z)
sum = z + ((short)(x>>16) * (short)(y>>16)) + ((short)x * (short)y)
8.3.5 实例讲解
实验目的:
1. 四种类型数据的点乘。
实验内容:
1. 按下按键K3, 串口打印输出结果
实验现象:
通过窗口上位机软件SecureCRT(V5光盘里面有此软件)查看打印信息现象如下:

程序设计:
static void DSP_DotProduct(void)
{
static float32_t pSrcA[5] = {1.0f,1.0f,1.0f,1.0f,1.0f};
static float32_t pSrcB[5] = {1.0f,1.0f,1.0f,1.0f,1.0f};
static float32_t result;
static q31_t pSrcA1[5] = {0x7ffffff0,1,1,1,1};
static q31_t pSrcB1[5] = {1,1,1,1,1};
static q63_t result1;
static q15_t pSrcA2[5] = {1,1,1,1,1};
static q15_t pSrcB2[5] = {1,1,1,1,1};
static q63_t result2;
static q7_t pSrcA3[5] = {1,1,1,1,1};
static q7_t pSrcB3[5] = {1,1,1,1,1};
static q31_t result3;
pSrcA[0] -= 1.1f;
arm_dot_prod_f32(pSrcA, pSrcB, 5, &result);
printf("arm_dot_prod_f32 = %f\r\n", result);
pSrcA1[0] -= 0xffff;
arm_dot_prod_q31(pSrcA1, pSrcB1, 5, &result1);
printf("arm_dot_prod_q31 = %lld\r\n", result1);
pSrcA2[0] -= 1;
arm_dot_prod_q15(pSrcA2, pSrcB2, 5, &result2);
printf("arm_dot_prod_q15 = %lld\r\n", result2);
pSrcA3[0] -= 1;
arm_dot_prod_q7(pSrcA3, pSrcB3, 5, &result3);
printf("arm_dot_prod_q7 = %d\r\n", result3);
printf("***********************************\r\n");
}