GitHub Repository: torvalds/linux
Path: blob/master/arch/arm64/lib/csum.c
// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

/* Looks dumb, but generates nice-ish code */
static u64 accumulate(u64 sum, u64 data)
{
        __uint128_t tmp = (__uint128_t)sum + data;
        return tmp + (tmp >> 64);
}
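/*
 * Note on accumulate(): doing the addition in 128 bits keeps the carry out
 * of the 64-bit sum inside the wide value, and tmp + (tmp >> 64) folds that
 * carry back into the low 64 bits (end-around carry), which is what ones'
 * complement accumulation requires. On arm64 this typically lowers to an
 * add-with-carry sequence rather than a flag-to-register move.
 */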

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
        unsigned int offset, shift, sum;
        const u64 *ptr;
        u64 data, sum64 = 0;

        if (unlikely(len <= 0))
                return 0;

        offset = (unsigned long)buff & 7;
        /*
         * This is to all intents and purposes safe, since rounding down cannot
         * result in a different page or cache line being accessed, and @buff
         * should absolutely not be pointing to anything read-sensitive. We do,
         * however, have to be careful not to piss off KASAN, which means using
         * unchecked reads to accommodate the head and tail, for which we'll
         * compensate with an explicit check up-front.
         */
        kasan_check_read(buff, len);
        ptr = (u64 *)(buff - offset);
        len = len + offset - 8;
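        /*
         * ptr now points at the aligned word containing buff[0], and len
         * counts the bytes remaining beyond that first word (it may already
         * be negative if the whole buffer fits inside it).
         */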

        /*
         * Head: zero out any excess leading bytes. Shifting back by the same
         * amount should be at least as fast as any other way of handling the
         * odd/even alignment, and means we can ignore it until the very end.
         */
        shift = offset * 8;
        data = *ptr++;
#ifdef __LITTLE_ENDIAN
        data = (data >> shift) << shift;
#else
        data = (data << shift) >> shift;
#endif
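        /*
         * The offset bytes that precede buff[0] in the aligned word are the
         * least significant ones on little-endian (most significant on
         * big-endian), so the shift pair above clears exactly those bytes
         * while leaving the rest of the word intact.
         */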

        /*
         * Body: straightforward aligned loads from here on (the paired loads
         * underlying the quadword type still only need dword alignment). The
         * main loop strictly excludes the tail, so the second loop will always
         * run at least once.
         */
        while (unlikely(len > 64)) {
                __uint128_t tmp1, tmp2, tmp3, tmp4;

                tmp1 = *(__uint128_t *)ptr;
                tmp2 = *(__uint128_t *)(ptr + 2);
                tmp3 = *(__uint128_t *)(ptr + 4);
                tmp4 = *(__uint128_t *)(ptr + 6);

                len -= 64;
                ptr += 8;

                /* This is the "don't dump the carry flag into a GPR" idiom */
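                /*
                 * Each tmp += (tmp >> 64) | (tmp << 64) adds the swapped
                 * halves back onto the value, leaving the 64-bit end-around-
                 * carry sum of the two halves in the upper 64 bits; the carry
                 * is absorbed by the 128-bit addition itself rather than read
                 * back out of the flags. The shift/OR steps then pack pairs
                 * of those partial sums for the next round of folding.
                 */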
                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
                tmp2 += (tmp2 >> 64) | (tmp2 << 64);
                tmp3 += (tmp3 >> 64) | (tmp3 << 64);
                tmp4 += (tmp4 >> 64) | (tmp4 << 64);
                tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
                tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
                tmp3 += (tmp3 >> 64) | (tmp3 << 64);
                tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
                tmp1 = ((tmp1 >> 64) << 64) | sum64;
                tmp1 += (tmp1 >> 64) | (tmp1 << 64);
                sum64 = tmp1 >> 64;
        }
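        /*
         * From here on, data always holds one word that has been loaded but
         * not yet accumulated; deferring it by one step lets the final word
         * have its over-read bytes cleared before it is added in.
         */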
        while (len > 8) {
                __uint128_t tmp;

                sum64 = accumulate(sum64, data);
                tmp = *(__uint128_t *)ptr;

                len -= 16;
                ptr += 2;

#ifdef __LITTLE_ENDIAN
                data = tmp >> 64;
                sum64 = accumulate(sum64, tmp);
#else
                data = tmp;
                sum64 = accumulate(sum64, tmp >> 64);
#endif
        }
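        /*
         * If 1-8 bytes remain beyond the words consumed so far, fold in the
         * pending word and make one final (possibly over-reading) load;
         * otherwise the pending word is itself the over-read final word and
         * goes straight to the tail masking below.
         */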
        if (len > 0) {
                sum64 = accumulate(sum64, data);
                data = *ptr;
                len -= 8;
        }
        /*
         * Tail: zero any over-read bytes similarly to the head, again
         * preserving odd/even alignment.
         */
        shift = len * -8;
#ifdef __LITTLE_ENDIAN
        data = (data << shift) >> shift;
#else
        data = (data >> shift) << shift;
#endif
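        /*
         * At this point len is in [-7, 0], so shift = -8 * len is the number
         * of bits of data that lie past the end of the buffer; the shift
         * pair above clears exactly those trailing bytes.
         */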
        sum64 = accumulate(sum64, data);

        /* Finally, folding */
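        /*
         * Fold the 64-bit ones' complement sum down to 32 and then 16 bits.
         * If buff started on an odd address, every byte was accumulated in
         * the opposite lane of its 16-bit column, so byte-swap the result to
         * put it back in the order the caller expects.
         */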
        sum64 += (sum64 >> 32) | (sum64 << 32);
        sum = sum64 >> 32;
        sum += (sum >> 16) | (sum << 16);
        if (offset & 1)
                return (u16)swab32(sum);

        return sum >> 16;
}

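/*
 * csum_ipv6_magic() folds the IPv6 pseudo-header (source and destination
 * addresses, upper-layer payload length and next-header/protocol value)
 * into the running checksum, as needed for TCP/UDP/ICMPv6 checksums.
 */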
__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
                        const struct in6_addr *daddr,
                        __u32 len, __u8 proto, __wsum csum)
{
        __uint128_t src, dst;
        u64 sum = (__force u64)csum;

        src = *(const __uint128_t *)saddr->s6_addr;
        dst = *(const __uint128_t *)daddr->s6_addr;

        sum += (__force u32)htonl(len);
#ifdef __LITTLE_ENDIAN
        sum += (u32)proto << 24;
#else
        sum += proto;
#endif
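        /*
         * proto occupies the last byte of the pseudo-header's 32-bit
         * next-header field. Loaded as a native word, that byte is the most
         * significant one on little-endian hosts (hence the << 24) and the
         * least significant one on big-endian hosts.
         */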
        src += (src >> 64) | (src << 64);
        dst += (dst >> 64) | (dst << 64);
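        /*
         * Same swap-and-add folding as in do_csum(): the upper 64 bits of
         * src and dst now hold the end-around-carry sum of each address's
         * two halves, ready to be accumulated below.
         */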

        sum = accumulate(sum, src >> 64);
        sum = accumulate(sum, dst >> 64);

        sum += ((sum >> 32) | (sum << 32));
        return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);