Skip to content

Commit ed413f3

Browse files
FiloSottilegopherbot
authored andcommitted
crypto/internal/fips/mlkem: implement ML-KEM-1024
Decided to automatically duplicate the high-level code to avoid growing the ML-KEM-768 data structures. For #70122 Change-Id: I5c705b71ee1e23adba9113d5cf6b6e505c028967 Reviewed-on: https://go-review.googlesource.com/c/go/+/621983 Auto-Submit: Filippo Valsorda <filippo@golang.org> Reviewed-by: Roland Shoemaker <roland@golang.org> Reviewed-by: Daniel McCarney <daniel@binaryparadox.net> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
1 parent e82308c commit ed413f3

File tree

6 files changed

+740
-19
lines changed

6 files changed

+740
-19
lines changed

src/crypto/internal/fips/mlkem/field.go

+96-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func ringCompressAndEncode10(s []byte, f ringElement) []byte {
263263
s, b := sliceForAppend(s, encodingSize10)
264264
for i := 0; i < n; i += 4 {
265265
var x uint64
266-
x |= uint64(compress(f[i+0], 10))
266+
x |= uint64(compress(f[i], 10))
267267
x |= uint64(compress(f[i+1], 10)) << 10
268268
x |= uint64(compress(f[i+2], 10)) << 20
269269
x |= uint64(compress(f[i+3], 10)) << 30
@@ -296,6 +296,101 @@ func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
296296
return f
297297
}
298298

299+
// ringCompressAndEncode appends an encoding of a ring element to s,
300+
// compressing each coefficient to d bits.
301+
//
302+
// It implements Compress, according to FIPS 203, Definition 4.7,
303+
// followed by ByteEncode, according to FIPS 203, Algorithm 5.
304+
func ringCompressAndEncode(s []byte, f ringElement, d uint8) []byte {
305+
var b byte
306+
var bIdx uint8
307+
for i := 0; i < n; i++ {
308+
c := compress(f[i], d)
309+
var cIdx uint8
310+
for cIdx < d {
311+
b |= byte(c>>cIdx) << bIdx
312+
bits := min(8-bIdx, d-cIdx)
313+
bIdx += bits
314+
cIdx += bits
315+
if bIdx == 8 {
316+
s = append(s, b)
317+
b = 0
318+
bIdx = 0
319+
}
320+
}
321+
}
322+
if bIdx != 0 {
323+
panic("mlkem: internal error: bitsFilled != 0")
324+
}
325+
return s
326+
}
327+
328+
// ringDecodeAndDecompress decodes an encoding of a ring element where
329+
// each d bits are mapped to an equidistant distribution.
330+
//
331+
// It implements ByteDecode, according to FIPS 203, Algorithm 6,
332+
// followed by Decompress, according to FIPS 203, Definition 4.8.
333+
func ringDecodeAndDecompress(b []byte, d uint8) ringElement {
334+
var f ringElement
335+
var bIdx uint8
336+
for i := 0; i < n; i++ {
337+
var c uint16
338+
var cIdx uint8
339+
for cIdx < d {
340+
c |= uint16(b[0]>>bIdx) << cIdx
341+
c &= (1 << d) - 1
342+
bits := min(8-bIdx, d-cIdx)
343+
bIdx += bits
344+
cIdx += bits
345+
if bIdx == 8 {
346+
b = b[1:]
347+
bIdx = 0
348+
}
349+
}
350+
f[i] = fieldElement(decompress(c, d))
351+
}
352+
if len(b) != 0 {
353+
panic("mlkem: internal error: leftover bytes")
354+
}
355+
return f
356+
}
357+
358+
// ringCompressAndEncode5 appends a 160-byte encoding of a ring element to s,
359+
// compressing eight coefficients per five bytes.
360+
//
361+
// It implements Compress₅, according to FIPS 203, Definition 4.7,
362+
// followed by ByteEncode₅, according to FIPS 203, Algorithm 5.
363+
func ringCompressAndEncode5(s []byte, f ringElement) []byte {
364+
return ringCompressAndEncode(s, f, 5)
365+
}
366+
367+
// ringDecodeAndDecompress5 decodes a 160-byte encoding of a ring element where
368+
// each five bits are mapped to an equidistant distribution.
369+
//
370+
// It implements ByteDecode₅, according to FIPS 203, Algorithm 6,
371+
// followed by Decompress₅, according to FIPS 203, Definition 4.8.
372+
func ringDecodeAndDecompress5(bb *[encodingSize5]byte) ringElement {
373+
return ringDecodeAndDecompress(bb[:], 5)
374+
}
375+
376+
// ringCompressAndEncode11 appends a 352-byte encoding of a ring element to s,
377+
// compressing eight coefficients per eleven bytes.
378+
//
379+
// It implements Compress₁₁, according to FIPS 203, Definition 4.7,
380+
// followed by ByteEncode₁₁, according to FIPS 203, Algorithm 5.
381+
func ringCompressAndEncode11(s []byte, f ringElement) []byte {
382+
return ringCompressAndEncode(s, f, 11)
383+
}
384+
385+
// ringDecodeAndDecompress11 decodes a 352-byte encoding of a ring element where
386+
// each eleven bits are mapped to an equidistant distribution.
387+
//
388+
// It implements ByteDecode₁₁, according to FIPS 203, Algorithm 6,
389+
// followed by Decompress₁₁, according to FIPS 203, Definition 4.8.
390+
func ringDecodeAndDecompress11(bb *[encodingSize11]byte) ringElement {
391+
return ringDecodeAndDecompress(bb[:], 11)
392+
}
393+
299394
// samplePolyCBD draws a ringElement from the special Dη distribution given a
300395
// stream of random bytes generated by the PRF function, according to FIPS 203,
301396
// Algorithm 8 and Definition 4.3.

src/crypto/internal/fips/mlkem/field_test.go

+78
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package mlkem
66

77
import (
8+
"bytes"
9+
"crypto/rand"
810
"math/big"
11+
mathrand "math/rand/v2"
912
"strconv"
1013
"testing"
1114
)
@@ -151,6 +154,81 @@ func TestDecompress(t *testing.T) {
151154
}
152155
}
153156

157+
func randomRingElement() ringElement {
158+
var r ringElement
159+
for i := range r {
160+
r[i] = fieldElement(mathrand.IntN(q))
161+
}
162+
return r
163+
}
164+
165+
func TestEncodeDecode(t *testing.T) {
166+
f := randomRingElement()
167+
b := make([]byte, 12*n/8)
168+
rand.Read(b)
169+
170+
// Compare ringCompressAndEncode to ringCompressAndEncodeN.
171+
e1 := ringCompressAndEncode(nil, f, 10)
172+
e2 := ringCompressAndEncode10(nil, f)
173+
if !bytes.Equal(e1, e2) {
174+
t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode10 = %x", e1, e2)
175+
}
176+
e1 = ringCompressAndEncode(nil, f, 4)
177+
e2 = ringCompressAndEncode4(nil, f)
178+
if !bytes.Equal(e1, e2) {
179+
t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode4 = %x", e1, e2)
180+
}
181+
e1 = ringCompressAndEncode(nil, f, 1)
182+
e2 = ringCompressAndEncode1(nil, f)
183+
if !bytes.Equal(e1, e2) {
184+
t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode1 = %x", e1, e2)
185+
}
186+
187+
// Compare ringDecodeAndDecompress to ringDecodeAndDecompressN.
188+
g1 := ringDecodeAndDecompress(b[:encodingSize10], 10)
189+
g2 := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
190+
if g1 != g2 {
191+
t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress10 = %v", g1, g2)
192+
}
193+
g1 = ringDecodeAndDecompress(b[:encodingSize4], 4)
194+
g2 = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
195+
if g1 != g2 {
196+
t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress4 = %v", g1, g2)
197+
}
198+
g1 = ringDecodeAndDecompress(b[:encodingSize1], 1)
199+
g2 = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
200+
if g1 != g2 {
201+
t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress1 = %v", g1, g2)
202+
}
203+
204+
// Round-trip ringCompressAndEncode and ringDecodeAndDecompress.
205+
for d := 1; d < 12; d++ {
206+
encodingSize := d * n / 8
207+
g := ringDecodeAndDecompress(b[:encodingSize], uint8(d))
208+
out := ringCompressAndEncode(nil, g, uint8(d))
209+
if !bytes.Equal(out, b[:encodingSize]) {
210+
t.Errorf("roundtrip failed for d = %d", d)
211+
}
212+
}
213+
214+
// Round-trip ringCompressAndEncodeN and ringDecodeAndDecompressN.
215+
g := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
216+
out := ringCompressAndEncode10(nil, g)
217+
if !bytes.Equal(out, b[:encodingSize10]) {
218+
t.Errorf("roundtrip failed for specialized 10")
219+
}
220+
g = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
221+
out = ringCompressAndEncode4(nil, g)
222+
if !bytes.Equal(out, b[:encodingSize4]) {
223+
t.Errorf("roundtrip failed for specialized 4")
224+
}
225+
g = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
226+
out = ringCompressAndEncode1(nil, g)
227+
if !bytes.Equal(out, b[:encodingSize1]) {
228+
t.Errorf("roundtrip failed for specialized 1")
229+
}
230+
}
231+
154232
func BitRev7(n uint8) uint8 {
155233
if n>>7 != 0 {
156234
panic("not 7 bits")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build ignore
6+
7+
package main
8+
9+
import (
10+
"flag"
11+
"go/ast"
12+
"go/format"
13+
"go/parser"
14+
"go/token"
15+
"log"
16+
"os"
17+
"strings"
18+
)
19+
20+
var replacements = map[string]string{
21+
"k": "k1024",
22+
23+
"CiphertextSize768": "CiphertextSize1024",
24+
"EncapsulationKeySize768": "EncapsulationKeySize1024",
25+
26+
"encryptionKey": "encryptionKey1024",
27+
"decryptionKey": "decryptionKey1024",
28+
29+
"EncapsulationKey768": "EncapsulationKey1024",
30+
"NewEncapsulationKey768": "NewEncapsulationKey1024",
31+
"parseEK": "parseEK1024",
32+
33+
"kemEncaps": "kemEncaps1024",
34+
"pkeEncrypt": "pkeEncrypt1024",
35+
36+
"DecapsulationKey768": "DecapsulationKey1024",
37+
"NewDecapsulationKey768": "NewDecapsulationKey1024",
38+
"newKeyFromSeed": "newKeyFromSeed1024",
39+
40+
"kemDecaps": "kemDecaps1024",
41+
"pkeDecrypt": "pkeDecrypt1024",
42+
43+
"GenerateKey768": "GenerateKey1024",
44+
"generateKey": "generateKey1024",
45+
46+
"kemKeyGen": "kemKeyGen1024",
47+
48+
"encodingSize4": "encodingSize5",
49+
"encodingSize10": "encodingSize11",
50+
"ringCompressAndEncode4": "ringCompressAndEncode5",
51+
"ringCompressAndEncode10": "ringCompressAndEncode11",
52+
"ringDecodeAndDecompress4": "ringDecodeAndDecompress5",
53+
"ringDecodeAndDecompress10": "ringDecodeAndDecompress11",
54+
}
55+
56+
func main() {
57+
inputFile := flag.String("input", "", "")
58+
outputFile := flag.String("output", "", "")
59+
flag.Parse()
60+
61+
fset := token.NewFileSet()
62+
f, err := parser.ParseFile(fset, *inputFile, nil, parser.SkipObjectResolution|parser.ParseComments)
63+
if err != nil {
64+
log.Fatal(err)
65+
}
66+
cmap := ast.NewCommentMap(fset, f, f.Comments)
67+
68+
// Drop header comments.
69+
cmap[ast.Node(f)] = nil
70+
71+
// Remove top-level consts used across the main and generated files.
72+
var newDecls []ast.Decl
73+
for _, decl := range f.Decls {
74+
switch d := decl.(type) {
75+
case *ast.GenDecl:
76+
if d.Tok == token.CONST {
77+
continue // Skip const declarations
78+
}
79+
if d.Tok == token.IMPORT {
80+
cmap[decl] = nil // Drop pre-import comments.
81+
}
82+
}
83+
newDecls = append(newDecls, decl)
84+
}
85+
f.Decls = newDecls
86+
87+
// Replace identifiers.
88+
ast.Inspect(f, func(n ast.Node) bool {
89+
switch x := n.(type) {
90+
case *ast.Ident:
91+
if replacement, ok := replacements[x.Name]; ok {
92+
x.Name = replacement
93+
}
94+
}
95+
return true
96+
})
97+
98+
// Replace identifiers in comments.
99+
for _, c := range f.Comments {
100+
for _, l := range c.List {
101+
for k, v := range replacements {
102+
if k == "k" {
103+
continue
104+
}
105+
l.Text = strings.ReplaceAll(l.Text, k, v)
106+
}
107+
}
108+
}
109+
110+
out, err := os.Create(*outputFile)
111+
if err != nil {
112+
log.Fatal(err)
113+
}
114+
defer out.Close()
115+
116+
out.WriteString("// Code generated by generate1024.go. DO NOT EDIT.\n\n")
117+
118+
f.Comments = cmap.Filter(f).Comments()
119+
err = format.Node(out, fset, f)
120+
if err != nil {
121+
log.Fatal(err)
122+
}
123+
}

0 commit comments

Comments
 (0)