/*
 * algif_skcipher: User-space interface for skcipher algorithms
 *
 * This file provides the user-space API for symmetric key ciphers.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 */
15#include <crypto/scatterwalk.h>
16#include <crypto/skcipher.h>
17#include <crypto/if_alg.h>
18#include <linux/init.h>
19#include <linux/list.h>
20#include <linux/kernel.h>
21#include <linux/mm.h>
22#include <linux/module.h>
23#include <linux/net.h>
24#include <net/sock.h>
25
26struct skcipher_sg_list {
27 struct list_head list;
28
29 int cur;
30
31 struct scatterlist sg[0];
32};
33
/*
 * Wrapper pairing the cipher transform with a flag recording whether
 * setkey() has succeeded; consulted by the nokey accept/IO paths.
 */
struct ablkcipher_tfm_keycheck {
	struct crypto_ablkcipher *ablkcipher;
	bool has_key;	/* set by skcipher_setkey() on success */
};
38
/* Per-socket state for an skcipher data socket. */
struct skcipher_ctx {
	struct list_head tsgl;		/* queue of skcipher_sg_list nodes */
	struct af_alg_sgl rsgl;		/* receive-side user-page sg list */

	void *iv;			/* IV buffer, ivsize bytes */

	struct af_alg_completion completion;	/* async op completion */

	unsigned used;			/* bytes currently queued on tsgl */

	unsigned int len;		/* total allocation size of this ctx */
	bool more;			/* MSG_MORE seen: more data expected */
	bool merge;			/* last tsgl page has room to append */
	bool enc;			/* true = encrypt, false = decrypt */

	/* must remain last: the ablkcipher request context is allocated
	 * directly behind this struct (see skcipher_accept_parent_nokey) */
	struct ablkcipher_request req;
};
56
/*
 * Usable scatterlist entries per skcipher_sg_list node so that one node
 * fits in a 4096-byte allocation; one extra slot is reserved for
 * chaining to the next node.
 */
#define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
		      sizeof(struct scatterlist) - 1)
59
60static inline int skcipher_sndbuf(struct sock *sk)
61{
62 struct alg_sock *ask = alg_sk(sk);
63 struct skcipher_ctx *ctx = ask->private;
64
65 return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
66 ctx->used, 0);
67}
68
69static inline bool skcipher_writable(struct sock *sk)
70{
71 return PAGE_SIZE <= skcipher_sndbuf(sk);
72}
73
74static int skcipher_alloc_sgl(struct sock *sk)
75{
76 struct alg_sock *ask = alg_sk(sk);
77 struct skcipher_ctx *ctx = ask->private;
78 struct skcipher_sg_list *sgl;
79 struct scatterlist *sg = NULL;
80
81 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
82 if (!list_empty(&ctx->tsgl))
83 sg = sgl->sg;
84
85 if (!sg || sgl->cur >= MAX_SGL_ENTS) {
86 sgl = sock_kmalloc(sk, sizeof(*sgl) +
87 sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
88 GFP_KERNEL);
89 if (!sgl)
90 return -ENOMEM;
91
92 sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
93 sgl->cur = 0;
94
95 if (sg)
96 scatterwalk_sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
97
98 list_add_tail(&sgl->list, &ctx->tsgl);
99 }
100
101 return 0;
102}
103
/*
 * Release the first @used bytes of queued transmit data: advance
 * offsets, drop page references on fully-consumed entries and free
 * exhausted skcipher_sg_list nodes.
 */
static void skcipher_pull_sgl(struct sock *sk, int used)
{
	struct alg_sock *ask = alg_sk(sk);
	struct skcipher_ctx *ctx = ask->private;
	struct skcipher_sg_list *sgl;
	struct scatterlist *sg;
	int i;

	while (!list_empty(&ctx->tsgl)) {
		sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
				       list);
		sg = sgl->sg;

		for (i = 0; i < sgl->cur; i++) {
			int plen = min_t(int, used, sg[i].length);

			/* slot already released on an earlier call */
			if (!sg_page(sg + i))
				continue;

			sg[i].length -= plen;
			sg[i].offset += plen;

			used -= plen;
			ctx->used -= plen;

			/* entry only partially consumed: stop here */
			if (sg[i].length)
				return;

			put_page(sg_page(sg + i));
			sg_assign_page(sg + i, NULL);
		}

		/* every entry of this node consumed: free the node */
		list_del(&sgl->list);
		sock_kfree_s(sk, sgl,
			     sizeof(*sgl) + sizeof(sgl->sg[0]) *
			     (MAX_SGL_ENTS + 1));
	}

	/* queue drained: next sendmsg must start on a fresh page */
	if (!ctx->used)
		ctx->merge = 0;
}
145
146static void skcipher_free_sgl(struct sock *sk)
147{
148 struct alg_sock *ask = alg_sk(sk);
149 struct skcipher_ctx *ctx = ask->private;
150
151 skcipher_pull_sgl(sk, ctx->used);
152}
153
/*
 * Sleep until send buffer space becomes available.  Returns 0 on
 * success, -EAGAIN for non-blocking callers and -ERESTARTSYS when a
 * signal interrupts the wait.
 */
static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
{
	long timeout;
	DEFINE_WAIT(wait);
	int err = -ERESTARTSYS;

	if (flags & MSG_DONTWAIT)
		return -EAGAIN;

	/* request an async space notification while we sleep */
	set_bit(SOCK_ASYNC_NOSPACE, &sk->sk_socket->flags);

	for (;;) {
		if (signal_pending(current))
			break;
		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
		timeout = MAX_SCHEDULE_TIMEOUT;
		if (sk_wait_event(sk, &timeout, skcipher_writable(sk))) {
			err = 0;
			break;
		}
	}
	finish_wait(sk_sleep(sk), &wait);

	/*
	 * NOTE(review): SOCK_ASYNC_NOSPACE is not cleared here, unlike
	 * SOCK_ASYNC_WAITDATA in skcipher_wait_for_data(); presumably it
	 * is consumed on the wakeup side — confirm before changing.
	 */
	return err;
}
179
/*
 * Wake tasks waiting for write space, called after recvmsg has
 * consumed queued data.
 */
static void skcipher_wmem_wakeup(struct sock *sk)
{
	struct socket_wq *wq;

	if (!skcipher_writable(sk))
		return;

	rcu_read_lock();
	wq = rcu_dereference(sk->sk_wq);
	if (wq_has_sleeper(wq))
		wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
							   POLLRDNORM |
							   POLLRDBAND);
	/*
	 * NOTE(review): POLLIN/SOCK_WAKE_WAITD look inverted for a
	 * write-space wakeup, but this mirrors the af_alg convention —
	 * verify against the waiter side before "fixing".
	 */
	sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
	rcu_read_unlock();
}
196
/*
 * Sleep until the sender has queued data (ctx->used != 0).  Returns 0
 * on success, -EAGAIN for non-blocking callers and -ERESTARTSYS when
 * interrupted by a signal.
 */
static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
{
	struct alg_sock *ask = alg_sk(sk);
	struct skcipher_ctx *ctx = ask->private;
	long timeout;
	DEFINE_WAIT(wait);
	int err = -ERESTARTSYS;

	if (flags & MSG_DONTWAIT) {
		return -EAGAIN;
	}

	/* request async notification while we wait for data */
	set_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);

	for (;;) {
		if (signal_pending(current))
			break;
		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
		timeout = MAX_SCHEDULE_TIMEOUT;
		if (sk_wait_event(sk, &timeout, ctx->used)) {
			err = 0;
			break;
		}
	}
	finish_wait(sk_sleep(sk), &wait);

	clear_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);

	return err;
}
227
/* Wake tasks blocked in recvmsg once transmit data has been queued. */
static void skcipher_data_wakeup(struct sock *sk)
{
	struct alg_sock *ask = alg_sk(sk);
	struct skcipher_ctx *ctx = ask->private;
	struct socket_wq *wq;

	if (!ctx->used)
		return;

	rcu_read_lock();
	wq = rcu_dereference(sk->sk_wq);
	if (wq_has_sleeper(wq))
		wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
							   POLLRDNORM |
							   POLLRDBAND);
	sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
	rcu_read_unlock();
}
246
247static int skcipher_sendmsg(struct kiocb *unused, struct socket *sock,
248 struct msghdr *msg, size_t size)
249{
250 struct sock *sk = sock->sk;
251 struct alg_sock *ask = alg_sk(sk);
252 struct skcipher_ctx *ctx = ask->private;
253 struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);
254 unsigned ivsize = crypto_ablkcipher_ivsize(tfm);
255 struct skcipher_sg_list *sgl;
256 struct af_alg_control con = {};
257 long copied = 0;
258 bool enc = 0;
259 bool init = 0;
260 int err;
261 int i;
262
263 if (msg->msg_controllen) {
264 err = af_alg_cmsg_send(msg, &con);
265 if (err)
266 return err;
267
268 init = 1;
269 switch (con.op) {
270 case ALG_OP_ENCRYPT:
271 enc = 1;
272 break;
273 case ALG_OP_DECRYPT:
274 enc = 0;
275 break;
276 default:
277 return -EINVAL;
278 }
279
280 if (con.iv && con.iv->ivlen != ivsize)
281 return -EINVAL;
282 }
283
284 err = -EINVAL;
285
286 lock_sock(sk);
287 if (!ctx->more && ctx->used)
288 goto unlock;
289
290 if (init) {
291 ctx->enc = enc;
292 if (con.iv)
293 memcpy(ctx->iv, con.iv->iv, ivsize);
294 }
295
296 while (size) {
297 struct scatterlist *sg;
298 unsigned long len = size;
299 int plen;
300
301 if (ctx->merge) {
302 sgl = list_entry(ctx->tsgl.prev,
303 struct skcipher_sg_list, list);
304 sg = sgl->sg + sgl->cur - 1;
305 len = min_t(unsigned long, len,
306 PAGE_SIZE - sg->offset - sg->length);
307
308 err = memcpy_from_msg(page_address(sg_page(sg)) +
309 sg->offset + sg->length,
310 msg, len);
311 if (err)
312 goto unlock;
313
314 sg->length += len;
315 ctx->merge = (sg->offset + sg->length) &
316 (PAGE_SIZE - 1);
317
318 ctx->used += len;
319 copied += len;
320 size -= len;
321 continue;
322 }
323
324 if (!skcipher_writable(sk)) {
325 err = skcipher_wait_for_wmem(sk, msg->msg_flags);
326 if (err)
327 goto unlock;
328 }
329
330 len = min_t(unsigned long, len, skcipher_sndbuf(sk));
331
332 err = skcipher_alloc_sgl(sk);
333 if (err)
334 goto unlock;
335
336 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
337 sg = sgl->sg;
338 do {
339 i = sgl->cur;
340 plen = min_t(int, len, PAGE_SIZE);
341
342 sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
343 err = -ENOMEM;
344 if (!sg_page(sg + i))
345 goto unlock;
346
347 err = memcpy_from_msg(page_address(sg_page(sg + i)),
348 msg, plen);
349 if (err) {
350 __free_page(sg_page(sg + i));
351 sg_assign_page(sg + i, NULL);
352 goto unlock;
353 }
354
355 sg[i].length = plen;
356 len -= plen;
357 ctx->used += plen;
358 copied += plen;
359 size -= plen;
360 sgl->cur++;
361 } while (len && sgl->cur < MAX_SGL_ENTS);
362
363 ctx->merge = plen & (PAGE_SIZE - 1);
364 }
365
366 err = 0;
367
368 ctx->more = msg->msg_flags & MSG_MORE;
369 if (!ctx->more && !list_empty(&ctx->tsgl))
370 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
371
372unlock:
373 skcipher_data_wakeup(sk);
374 release_sock(sk);
375
376 return copied ?: err;
377}
378
379static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
380 int offset, size_t size, int flags)
381{
382 struct sock *sk = sock->sk;
383 struct alg_sock *ask = alg_sk(sk);
384 struct skcipher_ctx *ctx = ask->private;
385 struct skcipher_sg_list *sgl;
386 int err = -EINVAL;
387
388 if (flags & MSG_SENDPAGE_NOTLAST)
389 flags |= MSG_MORE;
390
391 lock_sock(sk);
392 if (!ctx->more && ctx->used)
393 goto unlock;
394
395 if (!size)
396 goto done;
397
398 if (!skcipher_writable(sk)) {
399 err = skcipher_wait_for_wmem(sk, flags);
400 if (err)
401 goto unlock;
402 }
403
404 err = skcipher_alloc_sgl(sk);
405 if (err)
406 goto unlock;
407
408 ctx->merge = 0;
409 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
410
411 get_page(page);
412 sg_set_page(sgl->sg + sgl->cur, page, size, offset);
413 sgl->cur++;
414 ctx->used += size;
415
416done:
417 ctx->more = flags & MSG_MORE;
418 if (!ctx->more && !list_empty(&ctx->tsgl))
419 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
420
421unlock:
422 skcipher_data_wakeup(sk);
423 release_sock(sk);
424
425 return err ?: size;
426}
427
428static int skcipher_recvmsg(struct kiocb *unused, struct socket *sock,
429 struct msghdr *msg, size_t ignored, int flags)
430{
431 struct sock *sk = sock->sk;
432 struct alg_sock *ask = alg_sk(sk);
433 struct skcipher_ctx *ctx = ask->private;
434 unsigned bs = crypto_ablkcipher_blocksize(crypto_ablkcipher_reqtfm(
435 &ctx->req));
436 struct skcipher_sg_list *sgl;
437 struct scatterlist *sg;
438 unsigned long iovlen;
439 struct iovec *iov;
440 int err = -EAGAIN;
441 int used;
442 long copied = 0;
443
444 lock_sock(sk);
445 for (iov = msg->msg_iov, iovlen = msg->msg_iovlen; iovlen > 0;
446 iovlen--, iov++) {
447 unsigned long seglen = iov->iov_len;
448 char __user *from = iov->iov_base;
449
450 while (seglen) {
451 used = ctx->used;
452 if (!used) {
453 err = skcipher_wait_for_data(sk, flags);
454 if (err)
455 goto unlock;
456 }
457
458 used = min_t(unsigned long, used, seglen);
459
460 used = af_alg_make_sg(&ctx->rsgl, from, used, 1);
461 err = used;
462 if (err < 0)
463 goto unlock;
464
465 if (ctx->more || used < ctx->used)
466 used -= used % bs;
467
468 err = -EINVAL;
469 if (!used)
470 goto free;
471
472 sgl = list_first_entry(&ctx->tsgl,
473 struct skcipher_sg_list, list);
474 sg = sgl->sg;
475
476 while (!sg->length)
477 sg++;
478
479 ablkcipher_request_set_crypt(&ctx->req, sg,
480 ctx->rsgl.sg, used,
481 ctx->iv);
482
483 err = af_alg_wait_for_completion(
484 ctx->enc ?
485 crypto_ablkcipher_encrypt(&ctx->req) :
486 crypto_ablkcipher_decrypt(&ctx->req),
487 &ctx->completion);
488
489free:
490 af_alg_free_sg(&ctx->rsgl);
491
492 if (err)
493 goto unlock;
494
495 copied += used;
496 from += used;
497 seglen -= used;
498 skcipher_pull_sgl(sk, used);
499 }
500 }
501
502 err = 0;
503
504unlock:
505 skcipher_wmem_wakeup(sk);
506 release_sock(sk);
507
508 return copied ?: err;
509}
510
511
512static unsigned int skcipher_poll(struct file *file, struct socket *sock,
513 poll_table *wait)
514{
515 struct sock *sk = sock->sk;
516 struct alg_sock *ask = alg_sk(sk);
517 struct skcipher_ctx *ctx = ask->private;
518 unsigned int mask;
519
520 sock_poll_wait(file, sk_sleep(sk), wait);
521 mask = 0;
522
523 if (ctx->used)
524 mask |= POLLIN | POLLRDNORM;
525
526 if (skcipher_writable(sk))
527 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
528
529 return mask;
530}
531
532static struct proto_ops algif_skcipher_ops = {
533 .family = PF_ALG,
534
535 .connect = sock_no_connect,
536 .socketpair = sock_no_socketpair,
537 .getname = sock_no_getname,
538 .ioctl = sock_no_ioctl,
539 .listen = sock_no_listen,
540 .shutdown = sock_no_shutdown,
541 .getsockopt = sock_no_getsockopt,
542 .mmap = sock_no_mmap,
543 .bind = sock_no_bind,
544 .accept = sock_no_accept,
545 .setsockopt = sock_no_setsockopt,
546
547 .release = af_alg_release,
548 .sendmsg = skcipher_sendmsg,
549 .sendpage = skcipher_sendpage,
550 .recvmsg = skcipher_recvmsg,
551 .poll = skcipher_poll,
552};
553
554static int skcipher_check_key(struct socket *sock)
555{
556 int err;
557 struct sock *psk;
558 struct alg_sock *pask;
559 struct ablkcipher_tfm_keycheck *tfm;
560 struct sock *sk = sock->sk;
561 struct alg_sock *ask = alg_sk(sk);
562
563 if (ask->refcnt)
564 return 0;
565
566 psk = ask->parent;
567 pask = alg_sk(ask->parent);
568 tfm = pask->private;
569
570 err = -ENOKEY;
571 lock_sock(psk);
572 if (!tfm->has_key)
573 goto unlock;
574
575 if (!pask->refcnt++)
576 sock_hold(psk);
577
578 ask->refcnt = 1;
579 sock_put(psk);
580
581 err = 0;
582
583unlock:
584 release_sock(psk);
585
586 return err;
587}
588
589static int skcipher_sendmsg_nokey(struct kiocb *unused, struct socket *sock,
590 struct msghdr *msg, size_t size)
591{
592 int err;
593
594 err = skcipher_check_key(sock);
595 if (err)
596 return err;
597
598 return skcipher_sendmsg(NULL, sock, msg, size);
599}
600
601static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
602 int offset, size_t size, int flags)
603{
604 int err;
605
606 err = skcipher_check_key(sock);
607 if (err)
608 return err;
609
610 return skcipher_sendpage(sock, page, offset, size, flags);
611}
612
613static int skcipher_recvmsg_nokey(struct kiocb *unused, struct socket *sock,
614 struct msghdr *msg, size_t ignored, int flags)
615{
616 int err;
617
618 err = skcipher_check_key(sock);
619 if (err)
620 return err;
621
622 return skcipher_recvmsg(NULL, sock, msg, ignored, flags);
623}
624
625static struct proto_ops algif_skcipher_ops_nokey = {
626 .family = PF_ALG,
627
628 .connect = sock_no_connect,
629 .socketpair = sock_no_socketpair,
630 .getname = sock_no_getname,
631 .ioctl = sock_no_ioctl,
632 .listen = sock_no_listen,
633 .shutdown = sock_no_shutdown,
634 .getsockopt = sock_no_getsockopt,
635 .mmap = sock_no_mmap,
636 .bind = sock_no_bind,
637 .accept = sock_no_accept,
638 .setsockopt = sock_no_setsockopt,
639
640 .release = af_alg_release,
641 .sendmsg = skcipher_sendmsg_nokey,
642 .sendpage = skcipher_sendpage_nokey,
643 .recvmsg = skcipher_recvmsg_nokey,
644 .poll = skcipher_poll,
645};
646
647static void *skcipher_bind(const char *name, u32 type, u32 mask)
648{
649 struct ablkcipher_tfm_keycheck *tfm;
650 struct crypto_ablkcipher *ablkcipher;
651
652 tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
653 if (!tfm)
654 return ERR_PTR(-ENOMEM);
655
656 ablkcipher = crypto_alloc_ablkcipher(name, type, mask);
657 if (IS_ERR(ablkcipher)) {
658 kfree(tfm);
659 return ERR_CAST(ablkcipher);
660 }
661
662 tfm->ablkcipher = ablkcipher;
663
664 return tfm;
665}
666
667static void skcipher_release(void *private)
668{
669 struct ablkcipher_tfm_keycheck *tfm = private;
670
671 crypto_free_ablkcipher(tfm->ablkcipher);
672 kfree(tfm);
673}
674
675static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
676{
677 struct ablkcipher_tfm_keycheck *tfm = private;
678 int err;
679
680 err = crypto_ablkcipher_setkey(tfm->ablkcipher, key, keylen);
681 tfm->has_key = !err;
682
683 return err;
684}
685
/*
 * Socket destructor: release queued transmit pages, the IV and the
 * context allocation, then drop the reference on the parent socket.
 */
static void skcipher_sock_destruct(struct sock *sk)
{
	struct alg_sock *ask = alg_sk(sk);
	struct skcipher_ctx *ctx = ask->private;
	struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);

	skcipher_free_sgl(sk);
	/* the IV may hold sensitive state: zeroed on free */
	sock_kzfree_s(sk, ctx->iv, crypto_ablkcipher_ivsize(tfm));
	sock_kfree_s(sk, ctx, ctx->len);
	af_alg_release_parent(sk);
}
697
/*
 * Set up the per-connection skcipher context on the freshly accepted
 * socket @sk: allocate the context (with trailing space for the
 * ablkcipher request state) plus a zeroed IV buffer, and install the
 * destructor.  Returns 0 or -ENOMEM.
 */
static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
{
	struct skcipher_ctx *ctx;
	struct alg_sock *ask = alg_sk(sk);
	struct ablkcipher_tfm_keycheck *tfm = private;
	struct crypto_ablkcipher *ablkcipher = tfm->ablkcipher;
	/* request context lives directly behind struct skcipher_ctx */
	unsigned int len = sizeof(*ctx) +
			   crypto_ablkcipher_reqsize(ablkcipher);

	ctx = sock_kmalloc(sk, len, GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	ctx->iv = sock_kmalloc(sk, crypto_ablkcipher_ivsize(ablkcipher),
			       GFP_KERNEL);
	if (!ctx->iv) {
		sock_kfree_s(sk, ctx, len);
		return -ENOMEM;
	}

	memset(ctx->iv, 0, crypto_ablkcipher_ivsize(ablkcipher));

	INIT_LIST_HEAD(&ctx->tsgl);
	ctx->len = len;
	ctx->used = 0;
	ctx->more = 0;
	ctx->merge = 0;
	ctx->enc = 0;
	af_alg_init_completion(&ctx->completion);

	ask->private = ctx;

	ablkcipher_request_set_tfm(&ctx->req, ablkcipher);
	ablkcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
					af_alg_complete, &ctx->completion);

	sk->sk_destruct = skcipher_sock_destruct;

	return 0;
}
738
739static int skcipher_accept_parent(void *private, struct sock *sk)
740{
741 struct ablkcipher_tfm_keycheck *tfm = private;
742 struct crypto_tfm *ctfm = crypto_ablkcipher_tfm(tfm->ablkcipher);
743 struct crypto_alg *calg = ctfm->__crt_alg;
744
745 if (!tfm->has_key && calg->cra_u.ablkcipher.max_keysize)
746 return -ENOKEY;
747
748 return skcipher_accept_parent_nokey(private, sk);
749}
750
/* Registration record tying the "skcipher" AF_ALG type to its ops. */
static const struct af_alg_type algif_type_skcipher = {
	.bind = skcipher_bind,
	.release = skcipher_release,
	.setkey = skcipher_setkey,
	.accept = skcipher_accept_parent,
	.accept_nokey = skcipher_accept_parent_nokey,
	.ops = &algif_skcipher_ops,
	.ops_nokey = &algif_skcipher_ops_nokey,
	.name = "skcipher",
	.owner = THIS_MODULE
};
762
/* Register the "skcipher" AF_ALG type at module load. */
static int __init algif_skcipher_init(void)
{
	return af_alg_register_type(&algif_type_skcipher);
}
767
/*
 * Unregister on module unload; a failure here would indicate a
 * type-registration refcount bug, hence the BUG_ON.
 */
static void __exit algif_skcipher_exit(void)
{
	int err = af_alg_unregister_type(&algif_type_skcipher);
	BUG_ON(err);
}
773
774module_init(algif_skcipher_init);
775module_exit(algif_skcipher_exit);
776MODULE_LICENSE("GPL");
777