linux/drivers/iommu/virtio-iommu.c
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * Virtio driver for the paravirtualized IOMMU
   4 *
   5 * Copyright (C) 2019 Arm Limited
   6 */
   7
   8#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
   9
  10#include <linux/amba/bus.h>
  11#include <linux/delay.h>
  12#include <linux/dma-iommu.h>
  13#include <linux/dma-map-ops.h>
  14#include <linux/freezer.h>
  15#include <linux/interval_tree.h>
  16#include <linux/iommu.h>
  17#include <linux/module.h>
  18#include <linux/of_platform.h>
  19#include <linux/pci.h>
  20#include <linux/platform_device.h>
  21#include <linux/virtio.h>
  22#include <linux/virtio_config.h>
  23#include <linux/virtio_ids.h>
  24#include <linux/wait.h>
  25
  26#include <uapi/linux/virtio_iommu.h>
  27
  28#define MSI_IOVA_BASE                   0x8000000
  29#define MSI_IOVA_LENGTH                 0x100000
  30
  31#define VIOMMU_REQUEST_VQ               0
  32#define VIOMMU_EVENT_VQ                 1
  33#define VIOMMU_NR_VQS                   2
  34
  35struct viommu_dev {
  36        struct iommu_device             iommu;
  37        struct device                   *dev;
  38        struct virtio_device            *vdev;
  39
  40        struct ida                      domain_ids;
  41
  42        struct virtqueue                *vqs[VIOMMU_NR_VQS];
  43        spinlock_t                      request_lock;
  44        struct list_head                requests;
  45        void                            *evts;
  46
  47        /* Device configuration */
  48        struct iommu_domain_geometry    geometry;
  49        u64                             pgsize_bitmap;
  50        u32                             first_domain;
  51        u32                             last_domain;
  52        /* Supported MAP flags */
  53        u32                             map_flags;
  54        u32                             probe_size;
  55};
  56
  57struct viommu_mapping {
  58        phys_addr_t                     paddr;
  59        struct interval_tree_node       iova;
  60        u32                             flags;
  61};
  62
  63struct viommu_domain {
  64        struct iommu_domain             domain;
  65        struct viommu_dev               *viommu;
  66        struct mutex                    mutex; /* protects viommu pointer */
  67        unsigned int                    id;
  68        u32                             map_flags;
  69
  70        spinlock_t                      mappings_lock;
  71        struct rb_root_cached           mappings;
  72
  73        unsigned long                   nr_endpoints;
  74};
  75
  76struct viommu_endpoint {
  77        struct device                   *dev;
  78        struct viommu_dev               *viommu;
  79        struct viommu_domain            *vdomain;
  80        struct list_head                resv_regions;
  81};
  82
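/*
 * A request that has been added to the request virtqueue but not yet
 * completed. @buf holds the request: the device-readable part copied from the
 * caller (head and payload), followed by the device-writable part starting at
 * @write_offset (probe output, if any, and the status tail). @writeback, when
 * set, points into the caller's buffer, where the device-writable part is
 * copied back once the request completes.
 */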
  83struct viommu_request {
  84        struct list_head                list;
  85        void                            *writeback;
  86        unsigned int                    write_offset;
  87        unsigned int                    len;
  88        char                            buf[];
  89};
  90
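/*
 * Event buffers returned by the device are treated as fault reports, unless
 * any of the bits below are set in the head word, in which case
 * viommu_event_handler() ignores the event.
 */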
  91#define VIOMMU_FAULT_RESV_MASK          0xffffff00
  92
  93struct viommu_event {
  94        union {
  95                u32                     head;
  96                struct virtio_iommu_fault fault;
  97        };
  98};
  99
 100#define to_viommu_domain(domain)        \
 101        container_of(domain, struct viommu_domain, domain)
 102
 103static int viommu_get_req_errno(void *buf, size_t len)
 104{
 105        struct virtio_iommu_req_tail *tail = buf + len - sizeof(*tail);
 106
 107        switch (tail->status) {
 108        case VIRTIO_IOMMU_S_OK:
 109                return 0;
 110        case VIRTIO_IOMMU_S_UNSUPP:
 111                return -ENOSYS;
 112        case VIRTIO_IOMMU_S_INVAL:
 113                return -EINVAL;
 114        case VIRTIO_IOMMU_S_RANGE:
 115                return -ERANGE;
 116        case VIRTIO_IOMMU_S_NOENT:
 117                return -ENOENT;
 118        case VIRTIO_IOMMU_S_FAULT:
 119                return -EFAULT;
 120        case VIRTIO_IOMMU_S_NOMEM:
 121                return -ENOMEM;
 122        case VIRTIO_IOMMU_S_IOERR:
 123        case VIRTIO_IOMMU_S_DEVERR:
 124        default:
 125                return -EIO;
 126        }
 127}
 128
 129static void viommu_set_req_status(void *buf, size_t len, int status)
 130{
 131        struct virtio_iommu_req_tail *tail = buf + len - sizeof(*tail);
 132
 133        tail->status = status;
 134}
 135
 136static off_t viommu_get_write_desc_offset(struct viommu_dev *viommu,
 137                                          struct virtio_iommu_req_head *req,
 138                                          size_t len)
 139{
 140        size_t tail_size = sizeof(struct virtio_iommu_req_tail);
 141
 142        if (req->type == VIRTIO_IOMMU_T_PROBE)
 143                return len - viommu->probe_size - tail_size;
 144
 145        return len - tail_size;
 146}
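
/*
 * For illustration: a MAP request of len == sizeof(struct virtio_iommu_req_map)
 * is split into
 *
 *      [ head | body ]                         device-readable
 *      [ struct virtio_iommu_req_tail ]        device-writable
 *
 * so its write descriptor starts at len - sizeof(tail), while a PROBE request
 * keeps an extra probe_size bytes of device-writable property space in front
 * of the tail.
 */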
 147
 148/*
 149 * __viommu_sync_req - Complete all in-flight requests
 150 *
 151 * Wait for all added requests to complete. When this function returns, all
 152 * requests that were in-flight at the time of the call have completed.
 153 */
 154static int __viommu_sync_req(struct viommu_dev *viommu)
 155{
 156        unsigned int len;
 157        size_t write_len;
 158        struct viommu_request *req;
 159        struct virtqueue *vq = viommu->vqs[VIOMMU_REQUEST_VQ];
 160
 161        assert_spin_locked(&viommu->request_lock);
 162
 163        virtqueue_kick(vq);
 164
 165        while (!list_empty(&viommu->requests)) {
 166                len = 0;
 167                req = virtqueue_get_buf(vq, &len);
 168                if (!req)
 169                        continue;
 170
 171                if (!len)
 172                        viommu_set_req_status(req->buf, req->len,
 173                                              VIRTIO_IOMMU_S_IOERR);
 174
 175                write_len = req->len - req->write_offset;
 176                if (req->writeback && len == write_len)
 177                        memcpy(req->writeback, req->buf + req->write_offset,
 178                               write_len);
 179
 180                list_del(&req->list);
 181                kfree(req);
 182        }
 183
 184        return 0;
 185}
 186
 187static int viommu_sync_req(struct viommu_dev *viommu)
 188{
 189        int ret;
 190        unsigned long flags;
 191
 192        spin_lock_irqsave(&viommu->request_lock, flags);
 193        ret = __viommu_sync_req(viommu);
 194        if (ret)
 195                dev_dbg(viommu->dev, "could not sync requests (%d)\n", ret);
 196        spin_unlock_irqrestore(&viommu->request_lock, flags);
 197
 198        return ret;
 199}
 200
 201/*
  202 * __viommu_add_req - Add one request to the queue
 203 * @buf: pointer to the request buffer
 204 * @len: length of the request buffer
 205 * @writeback: copy data back to the buffer when the request completes.
 206 *
 207 * Add a request to the queue. Only synchronize the queue if it's already full.
  208 * Otherwise don't kick the queue or wait for requests to complete.
 209 *
 210 * When @writeback is true, data written by the device, including the request
 211 * status, is copied into @buf after the request completes. This is unsafe if
 212 * the caller allocates @buf on stack and drops the lock between add_req() and
 213 * sync_req().
 214 *
 215 * Return 0 if the request was successfully added to the queue.
 216 */
 217static int __viommu_add_req(struct viommu_dev *viommu, void *buf, size_t len,
 218                            bool writeback)
 219{
 220        int ret;
 221        off_t write_offset;
 222        struct viommu_request *req;
 223        struct scatterlist top_sg, bottom_sg;
 224        struct scatterlist *sg[2] = { &top_sg, &bottom_sg };
 225        struct virtqueue *vq = viommu->vqs[VIOMMU_REQUEST_VQ];
 226
 227        assert_spin_locked(&viommu->request_lock);
 228
 229        write_offset = viommu_get_write_desc_offset(viommu, buf, len);
 230        if (write_offset <= 0)
 231                return -EINVAL;
 232
 233        req = kzalloc(sizeof(*req) + len, GFP_ATOMIC);
 234        if (!req)
 235                return -ENOMEM;
 236
 237        req->len = len;
 238        if (writeback) {
 239                req->writeback = buf + write_offset;
 240                req->write_offset = write_offset;
 241        }
 242        memcpy(&req->buf, buf, write_offset);
 243
 244        sg_init_one(&top_sg, req->buf, write_offset);
 245        sg_init_one(&bottom_sg, req->buf + write_offset, len - write_offset);
 246
 247        ret = virtqueue_add_sgs(vq, sg, 1, 1, req, GFP_ATOMIC);
 248        if (ret == -ENOSPC) {
 249                /* If the queue is full, sync and retry */
 250                if (!__viommu_sync_req(viommu))
 251                        ret = virtqueue_add_sgs(vq, sg, 1, 1, req, GFP_ATOMIC);
 252        }
 253        if (ret)
 254                goto err_free;
 255
 256        list_add_tail(&req->list, &viommu->requests);
 257        return 0;
 258
 259err_free:
 260        kfree(req);
 261        return ret;
 262}
 263
 264static int viommu_add_req(struct viommu_dev *viommu, void *buf, size_t len)
 265{
 266        int ret;
 267        unsigned long flags;
 268
 269        spin_lock_irqsave(&viommu->request_lock, flags);
 270        ret = __viommu_add_req(viommu, buf, len, false);
 271        if (ret)
 272                dev_dbg(viommu->dev, "could not add request: %d\n", ret);
 273        spin_unlock_irqrestore(&viommu->request_lock, flags);
 274
 275        return ret;
 276}
 277
 278/*
 279 * Send a request and wait for it to complete. Return the request status (as an
 280 * errno)
 281 */
 282static int viommu_send_req_sync(struct viommu_dev *viommu, void *buf,
 283                                size_t len)
 284{
 285        int ret;
 286        unsigned long flags;
 287
 288        spin_lock_irqsave(&viommu->request_lock, flags);
 289
 290        ret = __viommu_add_req(viommu, buf, len, true);
 291        if (ret) {
 292                dev_dbg(viommu->dev, "could not add request (%d)\n", ret);
 293                goto out_unlock;
 294        }
 295
 296        ret = __viommu_sync_req(viommu);
 297        if (ret) {
 298                dev_dbg(viommu->dev, "could not sync requests (%d)\n", ret);
 299                /* Fall-through (get the actual request status) */
 300        }
 301
 302        ret = viommu_get_req_errno(buf, len);
 303out_unlock:
 304        spin_unlock_irqrestore(&viommu->request_lock, flags);
 305        return ret;
 306}
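
/*
 * Callers typically build a fixed-size request on the stack and let the
 * device fill in the tail status before viommu_get_req_errno() decodes it.
 * As a sketch (mirroring viommu_attach_dev()):
 *
 *      struct virtio_iommu_req_attach req = {
 *              .head.type      = VIRTIO_IOMMU_T_ATTACH,
 *              .domain         = cpu_to_le32(vdomain->id),
 *              .endpoint       = cpu_to_le32(fwspec->ids[i]),
 *      };
 *
 *      ret = viommu_send_req_sync(viommu, &req, sizeof(req));
 *
 * The stack buffer is safe here because the request lock is held until
 * __viommu_sync_req() has drained the queue.
 */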
 307
 308/*
 309 * viommu_add_mapping - add a mapping to the internal tree
 310 *
  311 * On success, return 0. Otherwise return -ENOMEM.
 312 */
 313static int viommu_add_mapping(struct viommu_domain *vdomain, unsigned long iova,
 314                              phys_addr_t paddr, size_t size, u32 flags)
 315{
 316        unsigned long irqflags;
 317        struct viommu_mapping *mapping;
 318
 319        mapping = kzalloc(sizeof(*mapping), GFP_ATOMIC);
 320        if (!mapping)
 321                return -ENOMEM;
 322
 323        mapping->paddr          = paddr;
 324        mapping->iova.start     = iova;
 325        mapping->iova.last      = iova + size - 1;
 326        mapping->flags          = flags;
 327
 328        spin_lock_irqsave(&vdomain->mappings_lock, irqflags);
 329        interval_tree_insert(&mapping->iova, &vdomain->mappings);
 330        spin_unlock_irqrestore(&vdomain->mappings_lock, irqflags);
 331
 332        return 0;
 333}
 334
 335/*
 336 * viommu_del_mappings - remove mappings from the internal tree
 337 *
 338 * @vdomain: the domain
 339 * @iova: start of the range
 340 * @size: size of the range. A size of 0 corresponds to the entire address
 341 *      space.
 342 *
 343 * On success, returns the number of unmapped bytes (>= size)
 344 */
 345static size_t viommu_del_mappings(struct viommu_domain *vdomain,
 346                                  unsigned long iova, size_t size)
 347{
 348        size_t unmapped = 0;
 349        unsigned long flags;
 350        unsigned long last = iova + size - 1;
 351        struct viommu_mapping *mapping = NULL;
 352        struct interval_tree_node *node, *next;
 353
 354        spin_lock_irqsave(&vdomain->mappings_lock, flags);
 355        next = interval_tree_iter_first(&vdomain->mappings, iova, last);
 356        while (next) {
 357                node = next;
 358                mapping = container_of(node, struct viommu_mapping, iova);
 359                next = interval_tree_iter_next(node, iova, last);
 360
 361                /* Trying to split a mapping? */
 362                if (mapping->iova.start < iova)
 363                        break;
 364
 365                /*
 366                 * Virtio-iommu doesn't allow UNMAP to split a mapping created
 367                 * with a single MAP request, so remove the full mapping.
 368                 */
 369                unmapped += mapping->iova.last - mapping->iova.start + 1;
 370
 371                interval_tree_remove(node, &vdomain->mappings);
 372                kfree(mapping);
 373        }
 374        spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
 375
 376        return unmapped;
 377}
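
/*
 * For example, with two 4k mappings at 0x1000 and 0x2000, a call with
 * iova == 0x1000 and size == 0x2000 removes both and returns 0x2000, whereas
 * iova == 0x1800 falls inside the first mapping and removes nothing, because
 * a mapping created by a single MAP request cannot be split.
 */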
 378
 379/*
 380 * viommu_replay_mappings - re-send MAP requests
 381 *
 382 * When reattaching a domain that was previously detached from all endpoints,
 383 * mappings were deleted from the device. Re-create the mappings available in
 384 * the internal tree.
 385 */
 386static int viommu_replay_mappings(struct viommu_domain *vdomain)
 387{
 388        int ret = 0;
 389        unsigned long flags;
 390        struct viommu_mapping *mapping;
 391        struct interval_tree_node *node;
 392        struct virtio_iommu_req_map map;
 393
 394        spin_lock_irqsave(&vdomain->mappings_lock, flags);
 395        node = interval_tree_iter_first(&vdomain->mappings, 0, -1UL);
 396        while (node) {
 397                mapping = container_of(node, struct viommu_mapping, iova);
 398                map = (struct virtio_iommu_req_map) {
 399                        .head.type      = VIRTIO_IOMMU_T_MAP,
 400                        .domain         = cpu_to_le32(vdomain->id),
 401                        .virt_start     = cpu_to_le64(mapping->iova.start),
 402                        .virt_end       = cpu_to_le64(mapping->iova.last),
 403                        .phys_start     = cpu_to_le64(mapping->paddr),
 404                        .flags          = cpu_to_le32(mapping->flags),
 405                };
 406
 407                ret = viommu_send_req_sync(vdomain->viommu, &map, sizeof(map));
 408                if (ret)
 409                        break;
 410
 411                node = interval_tree_iter_next(node, 0, -1UL);
 412        }
 413        spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
 414
 415        return ret;
 416}
 417
 418static int viommu_add_resv_mem(struct viommu_endpoint *vdev,
 419                               struct virtio_iommu_probe_resv_mem *mem,
 420                               size_t len)
 421{
 422        size_t size;
 423        u64 start64, end64;
 424        phys_addr_t start, end;
 425        struct iommu_resv_region *region = NULL;
 426        unsigned long prot = IOMMU_WRITE | IOMMU_NOEXEC | IOMMU_MMIO;
 427
 428        start = start64 = le64_to_cpu(mem->start);
 429        end = end64 = le64_to_cpu(mem->end);
 430        size = end64 - start64 + 1;
 431
 432        /* Catch any overflow, including the unlikely end64 - start64 + 1 = 0 */
 433        if (start != start64 || end != end64 || size < end64 - start64)
 434                return -EOVERFLOW;
 435
 436        if (len < sizeof(*mem))
 437                return -EINVAL;
 438
 439        switch (mem->subtype) {
 440        default:
 441                dev_warn(vdev->dev, "unknown resv mem subtype 0x%x\n",
 442                         mem->subtype);
 443                fallthrough;
 444        case VIRTIO_IOMMU_RESV_MEM_T_RESERVED:
 445                region = iommu_alloc_resv_region(start, size, 0,
 446                                                 IOMMU_RESV_RESERVED);
 447                break;
 448        case VIRTIO_IOMMU_RESV_MEM_T_MSI:
 449                region = iommu_alloc_resv_region(start, size, prot,
 450                                                 IOMMU_RESV_MSI);
 451                break;
 452        }
 453        if (!region)
 454                return -ENOMEM;
 455
 456        list_add(&region->list, &vdev->resv_regions);
 457        return 0;
 458}
 459
 460static int viommu_probe_endpoint(struct viommu_dev *viommu, struct device *dev)
 461{
 462        int ret;
 463        u16 type, len;
 464        size_t cur = 0;
 465        size_t probe_len;
 466        struct virtio_iommu_req_probe *probe;
 467        struct virtio_iommu_probe_property *prop;
 468        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
 469        struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
 470
 471        if (!fwspec->num_ids)
 472                return -EINVAL;
 473
 474        probe_len = sizeof(*probe) + viommu->probe_size +
 475                    sizeof(struct virtio_iommu_req_tail);
 476        probe = kzalloc(probe_len, GFP_KERNEL);
 477        if (!probe)
 478                return -ENOMEM;
 479
 480        probe->head.type = VIRTIO_IOMMU_T_PROBE;
 481        /*
 482         * For now, assume that properties of an endpoint that outputs multiple
 483         * IDs are consistent. Only probe the first one.
 484         */
 485        probe->endpoint = cpu_to_le32(fwspec->ids[0]);
 486
 487        ret = viommu_send_req_sync(viommu, probe, probe_len);
 488        if (ret)
 489                goto out_free;
 490
 491        prop = (void *)probe->properties;
 492        type = le16_to_cpu(prop->type) & VIRTIO_IOMMU_PROBE_T_MASK;
 493
 494        while (type != VIRTIO_IOMMU_PROBE_T_NONE &&
 495               cur < viommu->probe_size) {
 496                len = le16_to_cpu(prop->length) + sizeof(*prop);
 497
 498                switch (type) {
 499                case VIRTIO_IOMMU_PROBE_T_RESV_MEM:
 500                        ret = viommu_add_resv_mem(vdev, (void *)prop, len);
 501                        break;
 502                default:
 503                        dev_err(dev, "unknown viommu prop 0x%x\n", type);
 504                }
 505
 506                if (ret)
 507                        dev_err(dev, "failed to parse viommu prop 0x%x\n", type);
 508
 509                cur += len;
 510                if (cur >= viommu->probe_size)
 511                        break;
 512
 513                prop = (void *)probe->properties + cur;
 514                type = le16_to_cpu(prop->type) & VIRTIO_IOMMU_PROBE_T_MASK;
 515        }
 516
 517out_free:
 518        kfree(probe);
 519        return ret;
 520}
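
/*
 * The probe reply written by the device is a sequence of properties, each a
 * (type, length) header followed by a value, terminated by a NONE type or by
 * running out of the probe_size area:
 *
 *      [ req_probe | prop | prop | ... | tail ]
 *
 * Only RESV_MEM properties are consumed above; unknown types are reported and
 * skipped using their advertised length.
 */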
 521
 522static int viommu_fault_handler(struct viommu_dev *viommu,
 523                                struct virtio_iommu_fault *fault)
 524{
 525        char *reason_str;
 526
 527        u8 reason       = fault->reason;
 528        u32 flags       = le32_to_cpu(fault->flags);
 529        u32 endpoint    = le32_to_cpu(fault->endpoint);
 530        u64 address     = le64_to_cpu(fault->address);
 531
 532        switch (reason) {
 533        case VIRTIO_IOMMU_FAULT_R_DOMAIN:
 534                reason_str = "domain";
 535                break;
 536        case VIRTIO_IOMMU_FAULT_R_MAPPING:
 537                reason_str = "page";
 538                break;
 539        case VIRTIO_IOMMU_FAULT_R_UNKNOWN:
 540        default:
 541                reason_str = "unknown";
 542                break;
 543        }
 544
 545        /* TODO: find EP by ID and report_iommu_fault */
 546        if (flags & VIRTIO_IOMMU_FAULT_F_ADDRESS)
 547                dev_err_ratelimited(viommu->dev, "%s fault from EP %u at %#llx [%s%s%s]\n",
 548                                    reason_str, endpoint, address,
 549                                    flags & VIRTIO_IOMMU_FAULT_F_READ ? "R" : "",
 550                                    flags & VIRTIO_IOMMU_FAULT_F_WRITE ? "W" : "",
 551                                    flags & VIRTIO_IOMMU_FAULT_F_EXEC ? "X" : "");
 552        else
 553                dev_err_ratelimited(viommu->dev, "%s fault from EP %u\n",
 554                                    reason_str, endpoint);
 555        return 0;
 556}
 557
 558static void viommu_event_handler(struct virtqueue *vq)
 559{
 560        int ret;
 561        unsigned int len;
 562        struct scatterlist sg[1];
 563        struct viommu_event *evt;
 564        struct viommu_dev *viommu = vq->vdev->priv;
 565
 566        while ((evt = virtqueue_get_buf(vq, &len)) != NULL) {
 567                if (len > sizeof(*evt)) {
 568                        dev_err(viommu->dev,
 569                                "invalid event buffer (len %u != %zu)\n",
 570                                len, sizeof(*evt));
 571                } else if (!(evt->head & VIOMMU_FAULT_RESV_MASK)) {
 572                        viommu_fault_handler(viommu, &evt->fault);
 573                }
 574
 575                sg_init_one(sg, evt, sizeof(*evt));
 576                ret = virtqueue_add_inbuf(vq, sg, 1, evt, GFP_ATOMIC);
 577                if (ret)
 578                        dev_err(viommu->dev, "could not add event buffer\n");
 579        }
 580
 581        virtqueue_kick(vq);
 582}
 583
 584/* IOMMU API */
 585
 586static struct iommu_domain *viommu_domain_alloc(unsigned type)
 587{
 588        struct viommu_domain *vdomain;
 589
 590        if (type != IOMMU_DOMAIN_UNMANAGED && type != IOMMU_DOMAIN_DMA)
 591                return NULL;
 592
 593        vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
 594        if (!vdomain)
 595                return NULL;
 596
 597        mutex_init(&vdomain->mutex);
 598        spin_lock_init(&vdomain->mappings_lock);
 599        vdomain->mappings = RB_ROOT_CACHED;
 600
 601        return &vdomain->domain;
 602}
 603
 604static int viommu_domain_finalise(struct viommu_endpoint *vdev,
 605                                  struct iommu_domain *domain)
 606{
 607        int ret;
 608        unsigned long viommu_page_size;
 609        struct viommu_dev *viommu = vdev->viommu;
 610        struct viommu_domain *vdomain = to_viommu_domain(domain);
 611
 612        viommu_page_size = 1UL << __ffs(viommu->pgsize_bitmap);
 613        if (viommu_page_size > PAGE_SIZE) {
 614                dev_err(vdev->dev,
 615                        "granule 0x%lx larger than system page size 0x%lx\n",
 616                        viommu_page_size, PAGE_SIZE);
 617                return -EINVAL;
 618        }
 619
 620        ret = ida_alloc_range(&viommu->domain_ids, viommu->first_domain,
 621                              viommu->last_domain, GFP_KERNEL);
 622        if (ret < 0)
 623                return ret;
 624
 625        vdomain->id             = (unsigned int)ret;
 626
 627        domain->pgsize_bitmap   = viommu->pgsize_bitmap;
 628        domain->geometry        = viommu->geometry;
 629
 630        vdomain->map_flags      = viommu->map_flags;
 631        vdomain->viommu         = viommu;
 632
 633        return 0;
 634}
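
/*
 * Without VIRTIO_IOMMU_F_DOMAIN_RANGE the driver assumes the full 32-bit
 * domain ID space: viommu_probe() leaves first_domain at 0 and initialises
 * last_domain to ~0U, and ida_alloc_range() hands out IDs from that interval.
 */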
 635
 636static void viommu_domain_free(struct iommu_domain *domain)
 637{
 638        struct viommu_domain *vdomain = to_viommu_domain(domain);
 639
 640        /* Free all remaining mappings (size 2^64) */
 641        viommu_del_mappings(vdomain, 0, 0);
 642
 643        if (vdomain->viommu)
 644                ida_free(&vdomain->viommu->domain_ids, vdomain->id);
 645
 646        kfree(vdomain);
 647}
 648
 649static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
 650{
 651        int i;
 652        int ret = 0;
 653        struct virtio_iommu_req_attach req;
 654        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
 655        struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
 656        struct viommu_domain *vdomain = to_viommu_domain(domain);
 657
 658        mutex_lock(&vdomain->mutex);
 659        if (!vdomain->viommu) {
 660                /*
 661                 * Properly initialize the domain now that we know which viommu
 662                 * owns it.
 663                 */
 664                ret = viommu_domain_finalise(vdev, domain);
 665        } else if (vdomain->viommu != vdev->viommu) {
 666                dev_err(dev, "cannot attach to foreign vIOMMU\n");
 667                ret = -EXDEV;
 668        }
 669        mutex_unlock(&vdomain->mutex);
 670
 671        if (ret)
 672                return ret;
 673
 674        /*
 675         * In the virtio-iommu device, when attaching the endpoint to a new
  676 * domain, it is detached from the old one and, if as a result the
 677         * old domain isn't attached to any endpoint, all mappings are removed
 678         * from the old domain and it is freed.
 679         *
 680         * In the driver the old domain still exists, and its mappings will be
 681         * recreated if it gets reattached to an endpoint. Otherwise it will be
 682         * freed explicitly.
 683         *
 684         * vdev->vdomain is protected by group->mutex
 685         */
 686        if (vdev->vdomain)
 687                vdev->vdomain->nr_endpoints--;
 688
 689        req = (struct virtio_iommu_req_attach) {
 690                .head.type      = VIRTIO_IOMMU_T_ATTACH,
 691                .domain         = cpu_to_le32(vdomain->id),
 692        };
 693
 694        for (i = 0; i < fwspec->num_ids; i++) {
 695                req.endpoint = cpu_to_le32(fwspec->ids[i]);
 696
 697                ret = viommu_send_req_sync(vdomain->viommu, &req, sizeof(req));
 698                if (ret)
 699                        return ret;
 700        }
 701
 702        if (!vdomain->nr_endpoints) {
 703                /*
 704                 * This endpoint is the first to be attached to the domain.
 705                 * Replay existing mappings (e.g. SW MSI).
 706                 */
 707                ret = viommu_replay_mappings(vdomain);
 708                if (ret)
 709                        return ret;
 710        }
 711
 712        vdomain->nr_endpoints++;
 713        vdev->vdomain = vdomain;
 714
 715        return 0;
 716}
 717
 718static int viommu_map(struct iommu_domain *domain, unsigned long iova,
 719                      phys_addr_t paddr, size_t size, int prot, gfp_t gfp)
 720{
 721        int ret;
 722        u32 flags;
 723        struct virtio_iommu_req_map map;
 724        struct viommu_domain *vdomain = to_viommu_domain(domain);
 725
 726        flags = (prot & IOMMU_READ ? VIRTIO_IOMMU_MAP_F_READ : 0) |
 727                (prot & IOMMU_WRITE ? VIRTIO_IOMMU_MAP_F_WRITE : 0) |
 728                (prot & IOMMU_MMIO ? VIRTIO_IOMMU_MAP_F_MMIO : 0);
 729
 730        if (flags & ~vdomain->map_flags)
 731                return -EINVAL;
 732
 733        ret = viommu_add_mapping(vdomain, iova, paddr, size, flags);
 734        if (ret)
 735                return ret;
 736
 737        map = (struct virtio_iommu_req_map) {
 738                .head.type      = VIRTIO_IOMMU_T_MAP,
 739                .domain         = cpu_to_le32(vdomain->id),
 740                .virt_start     = cpu_to_le64(iova),
 741                .phys_start     = cpu_to_le64(paddr),
 742                .virt_end       = cpu_to_le64(iova + size - 1),
 743                .flags          = cpu_to_le32(flags),
 744        };
 745
 746        if (!vdomain->nr_endpoints)
 747                return 0;
 748
 749        ret = viommu_send_req_sync(vdomain->viommu, &map, sizeof(map));
 750        if (ret)
 751                viommu_del_mappings(vdomain, iova, size);
 752
 753        return ret;
 754}
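
/*
 * Note that virt_end is inclusive: mapping 16kB at IOVA 0x10000, for example,
 * sends virt_start == 0x10000 and virt_end == 0x13fff. The mapping is always
 * recorded in the tree first so that it can be replayed later, even when no
 * endpoint is attached yet and no MAP request is sent.
 */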
 755
 756static size_t viommu_unmap(struct iommu_domain *domain, unsigned long iova,
 757                           size_t size, struct iommu_iotlb_gather *gather)
 758{
 759        int ret = 0;
 760        size_t unmapped;
 761        struct virtio_iommu_req_unmap unmap;
 762        struct viommu_domain *vdomain = to_viommu_domain(domain);
 763
 764        unmapped = viommu_del_mappings(vdomain, iova, size);
 765        if (unmapped < size)
 766                return 0;
 767
 768        /* Device already removed all mappings after detach. */
 769        if (!vdomain->nr_endpoints)
 770                return unmapped;
 771
 772        unmap = (struct virtio_iommu_req_unmap) {
 773                .head.type      = VIRTIO_IOMMU_T_UNMAP,
 774                .domain         = cpu_to_le32(vdomain->id),
 775                .virt_start     = cpu_to_le64(iova),
 776                .virt_end       = cpu_to_le64(iova + unmapped - 1),
 777        };
 778
 779        ret = viommu_add_req(vdomain->viommu, &unmap, sizeof(unmap));
 780        return ret ? 0 : unmapped;
 781}
 782
 783static phys_addr_t viommu_iova_to_phys(struct iommu_domain *domain,
 784                                       dma_addr_t iova)
 785{
 786        u64 paddr = 0;
 787        unsigned long flags;
 788        struct viommu_mapping *mapping;
 789        struct interval_tree_node *node;
 790        struct viommu_domain *vdomain = to_viommu_domain(domain);
 791
 792        spin_lock_irqsave(&vdomain->mappings_lock, flags);
 793        node = interval_tree_iter_first(&vdomain->mappings, iova, iova);
 794        if (node) {
 795                mapping = container_of(node, struct viommu_mapping, iova);
 796                paddr = mapping->paddr + (iova - mapping->iova.start);
 797        }
 798        spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
 799
 800        return paddr;
 801}
 802
 803static void viommu_iotlb_sync(struct iommu_domain *domain,
 804                              struct iommu_iotlb_gather *gather)
 805{
 806        struct viommu_domain *vdomain = to_viommu_domain(domain);
 807
 808        viommu_sync_req(vdomain->viommu);
 809}
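
/*
 * UNMAP requests are only added to the queue by viommu_unmap(); the IOMMU core
 * is expected to follow up with iotlb_sync(), at which point the request queue
 * is kicked and drained so that all pending UNMAP requests have completed.
 */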
 810
 811static void viommu_get_resv_regions(struct device *dev, struct list_head *head)
 812{
 813        struct iommu_resv_region *entry, *new_entry, *msi = NULL;
 814        struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
 815        int prot = IOMMU_WRITE | IOMMU_NOEXEC | IOMMU_MMIO;
 816
 817        list_for_each_entry(entry, &vdev->resv_regions, list) {
 818                if (entry->type == IOMMU_RESV_MSI)
 819                        msi = entry;
 820
 821                new_entry = kmemdup(entry, sizeof(*entry), GFP_KERNEL);
 822                if (!new_entry)
 823                        return;
 824                list_add_tail(&new_entry->list, head);
 825        }
 826
 827        /*
 828         * If the device didn't register any bypass MSI window, add a
 829         * software-mapped region.
 830         */
 831        if (!msi) {
 832                msi = iommu_alloc_resv_region(MSI_IOVA_BASE, MSI_IOVA_LENGTH,
 833                                              prot, IOMMU_RESV_SW_MSI);
 834                if (!msi)
 835                        return;
 836
 837                list_add_tail(&msi->list, head);
 838        }
 839
 840        iommu_dma_get_resv_regions(dev, head);
 841}
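
/*
 * With the defines at the top of this file, the fallback software-MSI window
 * spans [0x8000000, 0x80fffff] (1MB), the same region other IOMMU drivers such
 * as the Arm SMMU reserve when the device doesn't advertise an MSI region.
 */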
 842
 843static struct iommu_ops viommu_ops;
 844static struct virtio_driver virtio_iommu_drv;
 845
 846static int viommu_match_node(struct device *dev, const void *data)
 847{
 848        return dev->parent->fwnode == data;
 849}
 850
 851static struct viommu_dev *viommu_get_by_fwnode(struct fwnode_handle *fwnode)
 852{
 853        struct device *dev = driver_find_device(&virtio_iommu_drv.driver, NULL,
 854                                                fwnode, viommu_match_node);
 855        put_device(dev);
 856
 857        return dev ? dev_to_virtio(dev)->priv : NULL;
 858}
 859
 860static struct iommu_device *viommu_probe_device(struct device *dev)
 861{
 862        int ret;
 863        struct viommu_endpoint *vdev;
 864        struct viommu_dev *viommu = NULL;
 865        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
 866
 867        if (!fwspec || fwspec->ops != &viommu_ops)
 868                return ERR_PTR(-ENODEV);
 869
 870        viommu = viommu_get_by_fwnode(fwspec->iommu_fwnode);
 871        if (!viommu)
 872                return ERR_PTR(-ENODEV);
 873
 874        vdev = kzalloc(sizeof(*vdev), GFP_KERNEL);
 875        if (!vdev)
 876                return ERR_PTR(-ENOMEM);
 877
 878        vdev->dev = dev;
 879        vdev->viommu = viommu;
 880        INIT_LIST_HEAD(&vdev->resv_regions);
 881        dev_iommu_priv_set(dev, vdev);
 882
 883        if (viommu->probe_size) {
 884                /* Get additional information for this endpoint */
 885                ret = viommu_probe_endpoint(viommu, dev);
 886                if (ret)
 887                        goto err_free_dev;
 888        }
 889
 890        return &viommu->iommu;
 891
 892err_free_dev:
 893        generic_iommu_put_resv_regions(dev, &vdev->resv_regions);
 894        kfree(vdev);
 895
 896        return ERR_PTR(ret);
 897}
 898
 899static void viommu_probe_finalize(struct device *dev)
 900{
 901#ifndef CONFIG_ARCH_HAS_SETUP_DMA_OPS
 902        /* First clear the DMA ops in case we're switching from a DMA domain */
 903        set_dma_ops(dev, NULL);
 904        iommu_setup_dma_ops(dev, 0, U64_MAX);
 905#endif
 906}
 907
 908static void viommu_release_device(struct device *dev)
 909{
 910        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
 911        struct viommu_endpoint *vdev;
 912
 913        if (!fwspec || fwspec->ops != &viommu_ops)
 914                return;
 915
 916        vdev = dev_iommu_priv_get(dev);
 917
 918        generic_iommu_put_resv_regions(dev, &vdev->resv_regions);
 919        kfree(vdev);
 920}
 921
 922static struct iommu_group *viommu_device_group(struct device *dev)
 923{
 924        if (dev_is_pci(dev))
 925                return pci_device_group(dev);
 926        else
 927                return generic_device_group(dev);
 928}
 929
 930static int viommu_of_xlate(struct device *dev, struct of_phandle_args *args)
 931{
 932        return iommu_fwspec_add_ids(dev, args->args, 1);
 933}
 934
 935static struct iommu_ops viommu_ops = {
 936        .domain_alloc           = viommu_domain_alloc,
 937        .domain_free            = viommu_domain_free,
 938        .attach_dev             = viommu_attach_dev,
 939        .map                    = viommu_map,
 940        .unmap                  = viommu_unmap,
 941        .iova_to_phys           = viommu_iova_to_phys,
 942        .iotlb_sync             = viommu_iotlb_sync,
 943        .probe_device           = viommu_probe_device,
 944        .probe_finalize         = viommu_probe_finalize,
 945        .release_device         = viommu_release_device,
 946        .device_group           = viommu_device_group,
 947        .get_resv_regions       = viommu_get_resv_regions,
 948        .put_resv_regions       = generic_iommu_put_resv_regions,
 949        .of_xlate               = viommu_of_xlate,
 950        .owner                  = THIS_MODULE,
 951};
 952
 953static int viommu_init_vqs(struct viommu_dev *viommu)
 954{
 955        struct virtio_device *vdev = dev_to_virtio(viommu->dev);
 956        const char *names[] = { "request", "event" };
 957        vq_callback_t *callbacks[] = {
 958                NULL, /* No async requests */
 959                viommu_event_handler,
 960        };
 961
 962        return virtio_find_vqs(vdev, VIOMMU_NR_VQS, viommu->vqs, callbacks,
 963                               names, NULL);
 964}
 965
 966static int viommu_fill_evtq(struct viommu_dev *viommu)
 967{
 968        int i, ret;
 969        struct scatterlist sg[1];
 970        struct viommu_event *evts;
 971        struct virtqueue *vq = viommu->vqs[VIOMMU_EVENT_VQ];
 972        size_t nr_evts = vq->num_free;
 973
 974        viommu->evts = evts = devm_kmalloc_array(viommu->dev, nr_evts,
 975                                                 sizeof(*evts), GFP_KERNEL);
 976        if (!evts)
 977                return -ENOMEM;
 978
 979        for (i = 0; i < nr_evts; i++) {
 980                sg_init_one(sg, &evts[i], sizeof(*evts));
 981                ret = virtqueue_add_inbuf(vq, sg, 1, &evts[i], GFP_KERNEL);
 982                if (ret)
 983                        return ret;
 984        }
 985
 986        return 0;
 987}
 988
 989static int viommu_probe(struct virtio_device *vdev)
 990{
 991        struct device *parent_dev = vdev->dev.parent;
 992        struct viommu_dev *viommu = NULL;
 993        struct device *dev = &vdev->dev;
 994        u64 input_start = 0;
 995        u64 input_end = -1UL;
 996        int ret;
 997
 998        if (!virtio_has_feature(vdev, VIRTIO_F_VERSION_1) ||
 999            !virtio_has_feature(vdev, VIRTIO_IOMMU_F_MAP_UNMAP))
1000                return -ENODEV;
1001
1002        viommu = devm_kzalloc(dev, sizeof(*viommu), GFP_KERNEL);
1003        if (!viommu)
1004                return -ENOMEM;
1005
1006        spin_lock_init(&viommu->request_lock);
1007        ida_init(&viommu->domain_ids);
1008        viommu->dev = dev;
1009        viommu->vdev = vdev;
1010        INIT_LIST_HEAD(&viommu->requests);
1011
1012        ret = viommu_init_vqs(viommu);
1013        if (ret)
1014                return ret;
1015
1016        virtio_cread_le(vdev, struct virtio_iommu_config, page_size_mask,
1017                        &viommu->pgsize_bitmap);
1018
1019        if (!viommu->pgsize_bitmap) {
1020                ret = -EINVAL;
1021                goto err_free_vqs;
1022        }
1023
1024        viommu->map_flags = VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE;
1025        viommu->last_domain = ~0U;
1026
1027        /* Optional features */
1028        virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_INPUT_RANGE,
1029                                struct virtio_iommu_config, input_range.start,
1030                                &input_start);
1031
1032        virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_INPUT_RANGE,
1033                                struct virtio_iommu_config, input_range.end,
1034                                &input_end);
1035
1036        virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_DOMAIN_RANGE,
1037                                struct virtio_iommu_config, domain_range.start,
1038                                &viommu->first_domain);
1039
1040        virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_DOMAIN_RANGE,
1041                                struct virtio_iommu_config, domain_range.end,
1042                                &viommu->last_domain);
1043
1044        virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_PROBE,
1045                                struct virtio_iommu_config, probe_size,
1046                                &viommu->probe_size);
1047
1048        viommu->geometry = (struct iommu_domain_geometry) {
1049                .aperture_start = input_start,
1050                .aperture_end   = input_end,
1051                .force_aperture = true,
1052        };
1053
1054        if (virtio_has_feature(vdev, VIRTIO_IOMMU_F_MMIO))
1055                viommu->map_flags |= VIRTIO_IOMMU_MAP_F_MMIO;
1056
1057        viommu_ops.pgsize_bitmap = viommu->pgsize_bitmap;
1058
1059        virtio_device_ready(vdev);
1060
1061        /* Populate the event queue with buffers */
1062        ret = viommu_fill_evtq(viommu);
1063        if (ret)
1064                goto err_free_vqs;
1065
1066        ret = iommu_device_sysfs_add(&viommu->iommu, dev, NULL, "%s",
1067                                     virtio_bus_name(vdev));
1068        if (ret)
1069                goto err_free_vqs;
1070
1071        iommu_device_register(&viommu->iommu, &viommu_ops, parent_dev);
1072
1073#ifdef CONFIG_PCI
1074        if (pci_bus_type.iommu_ops != &viommu_ops) {
1075                ret = bus_set_iommu(&pci_bus_type, &viommu_ops);
1076                if (ret)
1077                        goto err_unregister;
1078        }
1079#endif
1080#ifdef CONFIG_ARM_AMBA
1081        if (amba_bustype.iommu_ops != &viommu_ops) {
1082                ret = bus_set_iommu(&amba_bustype, &viommu_ops);
1083                if (ret)
1084                        goto err_unregister;
1085        }
1086#endif
1087        if (platform_bus_type.iommu_ops != &viommu_ops) {
1088                ret = bus_set_iommu(&platform_bus_type, &viommu_ops);
1089                if (ret)
1090                        goto err_unregister;
1091        }
1092
1093        vdev->priv = viommu;
1094
1095        dev_info(dev, "input address: %u bits\n",
1096                 order_base_2(viommu->geometry.aperture_end));
1097        dev_info(dev, "page mask: %#llx\n", viommu->pgsize_bitmap);
1098
1099        return 0;
1100
1101err_unregister:
1102        iommu_device_sysfs_remove(&viommu->iommu);
1103        iommu_device_unregister(&viommu->iommu);
1104err_free_vqs:
1105        vdev->config->del_vqs(vdev);
1106
1107        return ret;
1108}
1109
1110static void viommu_remove(struct virtio_device *vdev)
1111{
1112        struct viommu_dev *viommu = vdev->priv;
1113
1114        iommu_device_sysfs_remove(&viommu->iommu);
1115        iommu_device_unregister(&viommu->iommu);
1116
1117        /* Stop all virtqueues */
1118        vdev->config->reset(vdev);
1119        vdev->config->del_vqs(vdev);
1120
1121        dev_info(&vdev->dev, "device removed\n");
1122}
1123
1124static void viommu_config_changed(struct virtio_device *vdev)
1125{
1126        dev_warn(&vdev->dev, "config changed\n");
1127}
1128
1129static unsigned int features[] = {
1130        VIRTIO_IOMMU_F_MAP_UNMAP,
1131        VIRTIO_IOMMU_F_INPUT_RANGE,
1132        VIRTIO_IOMMU_F_DOMAIN_RANGE,
1133        VIRTIO_IOMMU_F_PROBE,
1134        VIRTIO_IOMMU_F_MMIO,
1135};
1136
1137static struct virtio_device_id id_table[] = {
1138        { VIRTIO_ID_IOMMU, VIRTIO_DEV_ANY_ID },
1139        { 0 },
1140};
1141MODULE_DEVICE_TABLE(virtio, id_table);
1142
1143static struct virtio_driver virtio_iommu_drv = {
1144        .driver.name            = KBUILD_MODNAME,
1145        .driver.owner           = THIS_MODULE,
1146        .id_table               = id_table,
1147        .feature_table          = features,
1148        .feature_table_size     = ARRAY_SIZE(features),
1149        .probe                  = viommu_probe,
1150        .remove                 = viommu_remove,
1151        .config_changed         = viommu_config_changed,
1152};
1153
1154module_virtio_driver(virtio_iommu_drv);
1155
1156MODULE_DESCRIPTION("Virtio IOMMU driver");
1157MODULE_AUTHOR("Jean-Philippe Brucker <jean-philippe.brucker@arm.com>");
1158MODULE_LICENSE("GPL v2");
1159