Merge branch 'x86-urgent-for-linus' of git://git.kernel.org/pub/scm/linux/kernel...
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index ddb3291d49c90b955c03aa7c09b3c47319a4c110..70e95d097ef104ac489a41a17dce2296dad5123c 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -41,6 +41,7 @@
 #include <asm/cmpxchg.h>
 #include <asm/io.h>
 #include <asm/vmx.h>
+#include <asm/kvm_page_track.h>
 
 /*
  * When setting this variable to true it enables Two-Dimensional-Paging
@@ -631,12 +632,12 @@ static void walk_shadow_page_lockless_begin(struct kvm_vcpu *vcpu)
         * kvm_flush_remote_tlbs() IPI to all active vcpus.
         */
        local_irq_disable();
-       vcpu->mode = READING_SHADOW_PAGE_TABLES;
+
        /*
         * Make sure a following spte read is not reordered ahead of the write
         * to vcpu->mode.
         */
-       smp_mb();
+       smp_store_mb(vcpu->mode, READING_SHADOW_PAGE_TABLES);
 }
 
 static void walk_shadow_page_lockless_end(struct kvm_vcpu *vcpu)
@@ -646,8 +647,7 @@ static void walk_shadow_page_lockless_end(struct kvm_vcpu *vcpu)
         * reads to sptes.  If it does, kvm_commit_zap_page() can see us
         * OUTSIDE_GUEST_MODE and proceed to free the shadow page table.
         */
-       smp_mb();
-       vcpu->mode = OUTSIDE_GUEST_MODE;
+       smp_store_release(&vcpu->mode, OUTSIDE_GUEST_MODE);
        local_irq_enable();
 }
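
The two hunks above fold the explicit smp_mb() calls into combined store-and-barrier helpers. A minimal sketch of what the helpers guarantee, assuming the generic Linux memory-model definitions rather than anything in this patch:

    /* walk_shadow_page_lockless_begin(): */
    smp_store_mb(vcpu->mode, READING_SHADOW_PAGE_TABLES);
    /* equivalent to WRITE_ONCE(vcpu->mode, ...) followed by smp_mb(),
     * so later spte reads cannot be reordered before the mode update.
     */

    /* walk_shadow_page_lockless_end(): */
    smp_store_release(&vcpu->mode, OUTSIDE_GUEST_MODE);
    /* all earlier spte reads complete before another CPU can observe
     * OUTSIDE_GUEST_MODE; kvm_mmu_commit_zap_page() pairs with this via
     * the barrier in kvm_flush_remote_tlbs() (see the comment rewritten
     * further down in this patch).
     */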
 
@@ -776,62 +776,85 @@ static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
        return &slot->arch.lpage_info[level - 2][idx];
 }
 
+static void update_gfn_disallow_lpage_count(struct kvm_memory_slot *slot,
+                                           gfn_t gfn, int count)
+{
+       struct kvm_lpage_info *linfo;
+       int i;
+
+       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
+               linfo = lpage_info_slot(gfn, slot, i);
+               linfo->disallow_lpage += count;
+               WARN_ON(linfo->disallow_lpage < 0);
+       }
+}
+
+void kvm_mmu_gfn_disallow_lpage(struct kvm_memory_slot *slot, gfn_t gfn)
+{
+       update_gfn_disallow_lpage_count(slot, gfn, 1);
+}
+
+void kvm_mmu_gfn_allow_lpage(struct kvm_memory_slot *slot, gfn_t gfn)
+{
+       update_gfn_disallow_lpage_count(slot, gfn, -1);
+}
+
 static void account_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
-       struct kvm_lpage_info *linfo;
        gfn_t gfn;
-       int i;
 
+       kvm->arch.indirect_shadow_pages++;
        gfn = sp->gfn;
        slots = kvm_memslots_for_spte_role(kvm, sp->role);
        slot = __gfn_to_memslot(slots, gfn);
-       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
-               linfo = lpage_info_slot(gfn, slot, i);
-               linfo->write_count += 1;
-       }
-       kvm->arch.indirect_shadow_pages++;
+
+       /* The gfn of a non-leaf shadow page is kept write-protected. */
+       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+               return kvm_slot_page_track_add_page(kvm, slot, gfn,
+                                                   KVM_PAGE_TRACK_WRITE);
+
+       kvm_mmu_gfn_disallow_lpage(slot, gfn);
 }
 
 static void unaccount_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
-       struct kvm_lpage_info *linfo;
        gfn_t gfn;
-       int i;
 
+       kvm->arch.indirect_shadow_pages--;
        gfn = sp->gfn;
        slots = kvm_memslots_for_spte_role(kvm, sp->role);
        slot = __gfn_to_memslot(slots, gfn);
-       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
-               linfo = lpage_info_slot(gfn, slot, i);
-               linfo->write_count -= 1;
-               WARN_ON(linfo->write_count < 0);
-       }
-       kvm->arch.indirect_shadow_pages--;
+       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+               return kvm_slot_page_track_remove_page(kvm, slot, gfn,
+                                                      KVM_PAGE_TRACK_WRITE);
+
+       kvm_mmu_gfn_allow_lpage(slot, gfn);
 }
 
-static int __has_wrprotected_page(gfn_t gfn, int level,
-                                 struct kvm_memory_slot *slot)
+static bool __mmu_gfn_lpage_is_disallowed(gfn_t gfn, int level,
+                                         struct kvm_memory_slot *slot)
 {
        struct kvm_lpage_info *linfo;
 
        if (slot) {
                linfo = lpage_info_slot(gfn, slot, level);
-               return linfo->write_count;
+               return !!linfo->disallow_lpage;
        }
 
-       return 1;
+       return true;
 }
 
-static int has_wrprotected_page(struct kvm_vcpu *vcpu, gfn_t gfn, int level)
+static bool mmu_gfn_lpage_is_disallowed(struct kvm_vcpu *vcpu, gfn_t gfn,
+                                       int level)
 {
        struct kvm_memory_slot *slot;
 
        slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
-       return __has_wrprotected_page(gfn, level, slot);
+       return __mmu_gfn_lpage_is_disallowed(gfn, level, slot);
 }
 
 static int host_mapping_level(struct kvm *kvm, gfn_t gfn)
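
For context, a sketch of how the renamed counter is meant to be used. kvm_mmu_gfn_disallow_lpage()/kvm_mmu_gfn_allow_lpage() are made non-static here so that the new page-track code can call them; the sequence below is illustrative, not quoted from that code:

    /* Any user that must keep a gfn mapped 4K-only bumps the counter on
     * every hugepage level covering it ...
     */
    kvm_mmu_gfn_disallow_lpage(slot, gfn);   /* disallow_lpage++ at 2M and 1G */

    /* ... mapping_level() / __mmu_gfn_lpage_is_disallowed() then refuse to
     * install a large spte while the count is non-zero ...
     */

    kvm_mmu_gfn_allow_lpage(slot, gfn);      /* disallow_lpage-- when done */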
@@ -897,7 +920,7 @@ static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn,
        max_level = min(kvm_x86_ops->get_lpage_level(), host_level);
 
        for (level = PT_DIRECTORY_LEVEL; level <= max_level; ++level)
-               if (__has_wrprotected_page(large_gfn, level, slot))
+               if (__mmu_gfn_lpage_is_disallowed(large_gfn, level, slot))
                        break;
 
        return level - 1;
@@ -1323,23 +1346,29 @@ void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                kvm_mmu_write_protect_pt_masked(kvm, slot, gfn_offset, mask);
 }
 
-static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
+bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
+                                   struct kvm_memory_slot *slot, u64 gfn)
 {
-       struct kvm_memory_slot *slot;
        struct kvm_rmap_head *rmap_head;
        int i;
        bool write_protected = false;
 
-       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
-
        for (i = PT_PAGE_TABLE_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
                rmap_head = __gfn_to_rmap(gfn, i, slot);
-               write_protected |= __rmap_write_protect(vcpu->kvm, rmap_head, true);
+               write_protected |= __rmap_write_protect(kvm, rmap_head, true);
        }
 
        return write_protected;
 }
 
+static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
+{
+       struct kvm_memory_slot *slot;
+
+       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
+       return kvm_mmu_slot_gfn_write_protect(vcpu->kvm, slot, gfn);
+}
+
 static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head)
 {
        u64 *sptep;
@@ -1754,7 +1783,7 @@ static void mark_unsync(u64 *spte)
 static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
                               struct kvm_mmu_page *sp)
 {
-       return 1;
+       return 0;
 }
 
 static void nonpaging_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
@@ -1840,13 +1869,16 @@ static int __mmu_unsync_walk(struct kvm_mmu_page *sp,
        return nr_unsync_leaf;
 }
 
+#define INVALID_INDEX (-1)
+
 static int mmu_unsync_walk(struct kvm_mmu_page *sp,
                           struct kvm_mmu_pages *pvec)
 {
+       pvec->nr = 0;
        if (!sp->unsync_children)
                return 0;
 
-       mmu_pages_add(pvec, sp, 0);
+       mmu_pages_add(pvec, sp, INVALID_INDEX);
        return __mmu_unsync_walk(sp, pvec);
 }
 
@@ -1883,37 +1915,35 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
                if ((_sp)->role.direct || (_sp)->role.invalid) {} else
 
 /* @sp->gfn should be write-protected at the call site */
-static int __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-                          struct list_head *invalid_list, bool clear_unsync)
+static bool __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+                           struct list_head *invalid_list)
 {
        if (sp->role.cr4_pae != !!is_pae(vcpu)) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
-               return 1;
+               return false;
        }
 
-       if (clear_unsync)
-               kvm_unlink_unsync_page(vcpu->kvm, sp);
-
-       if (vcpu->arch.mmu.sync_page(vcpu, sp)) {
+       if (vcpu->arch.mmu.sync_page(vcpu, sp) == 0) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
-               return 1;
+               return false;
        }
 
-       kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
-       return 0;
+       return true;
 }
 
-static int kvm_sync_page_transient(struct kvm_vcpu *vcpu,
-                                  struct kvm_mmu_page *sp)
+static void kvm_mmu_flush_or_zap(struct kvm_vcpu *vcpu,
+                                struct list_head *invalid_list,
+                                bool remote_flush, bool local_flush)
 {
-       LIST_HEAD(invalid_list);
-       int ret;
-
-       ret = __kvm_sync_page(vcpu, sp, &invalid_list, false);
-       if (ret)
-               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+       if (!list_empty(invalid_list)) {
+               kvm_mmu_commit_zap_page(vcpu->kvm, invalid_list);
+               return;
+       }
 
-       return ret;
+       if (remote_flush)
+               kvm_flush_remote_tlbs(vcpu->kvm);
+       else if (local_flush)
+               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
 }
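
kvm_mmu_flush_or_zap() replaces the open-coded commit-or-flush sequences; committing queued zaps already flushes remote TLBs, which is why the helper returns early in that case. The typical caller pattern, as used by mmu_sync_children() and kvm_mmu_pte_write() later in this patch (sketch only):

    LIST_HEAD(invalid_list);
    bool flush = false;

    /* collect work; kvm_sync_page() may queue pages on invalid_list */
    flush |= kvm_sync_page(vcpu, sp, &invalid_list);

    /* a single call at the end either commits the queued zaps (which
     * flushes remotely) or performs the requested remote/local flush
     */
    kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);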
 
 #ifdef CONFIG_KVM_MMU_AUDIT
@@ -1923,46 +1953,38 @@ static void kvm_mmu_audit(struct kvm_vcpu *vcpu, int point) { }
 static void mmu_audit_disable(void) { }
 #endif
 
-static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+static bool kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                         struct list_head *invalid_list)
 {
-       return __kvm_sync_page(vcpu, sp, invalid_list, true);
+       kvm_unlink_unsync_page(vcpu->kvm, sp);
+       return __kvm_sync_page(vcpu, sp, invalid_list);
 }
 
 /* @gfn should be write-protected at the call site */
-static void kvm_sync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
+static bool kvm_sync_pages(struct kvm_vcpu *vcpu, gfn_t gfn,
+                          struct list_head *invalid_list)
 {
        struct kvm_mmu_page *s;
-       LIST_HEAD(invalid_list);
-       bool flush = false;
+       bool ret = false;
 
        for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
                if (!s->unsync)
                        continue;
 
                WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
-               kvm_unlink_unsync_page(vcpu->kvm, s);
-               if ((s->role.cr4_pae != !!is_pae(vcpu)) ||
-                       (vcpu->arch.mmu.sync_page(vcpu, s))) {
-                       kvm_mmu_prepare_zap_page(vcpu->kvm, s, &invalid_list);
-                       continue;
-               }
-               flush = true;
+               ret |= kvm_sync_page(vcpu, s, invalid_list);
        }
 
-       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
-       if (flush)
-               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+       return ret;
 }
 
 struct mmu_page_path {
-       struct kvm_mmu_page *parent[PT64_ROOT_LEVEL-1];
-       unsigned int idx[PT64_ROOT_LEVEL-1];
+       struct kvm_mmu_page *parent[PT64_ROOT_LEVEL];
+       unsigned int idx[PT64_ROOT_LEVEL];
 };
 
 #define for_each_sp(pvec, sp, parents, i)                      \
-               for (i = mmu_pages_next(&pvec, &parents, -1),   \
-                       sp = pvec.page[i].sp;                   \
+               for (i = mmu_pages_first(&pvec, &parents);      \
                        i < pvec.nr && ({ sp = pvec.page[i].sp; 1;});   \
                        i = mmu_pages_next(&pvec, &parents, i))
 
@@ -1974,19 +1996,43 @@ static int mmu_pages_next(struct kvm_mmu_pages *pvec,
 
        for (n = i+1; n < pvec->nr; n++) {
                struct kvm_mmu_page *sp = pvec->page[n].sp;
+               unsigned idx = pvec->page[n].idx;
+               int level = sp->role.level;
 
-               if (sp->role.level == PT_PAGE_TABLE_LEVEL) {
-                       parents->idx[0] = pvec->page[n].idx;
-                       return n;
-               }
+               parents->idx[level-1] = idx;
+               if (level == PT_PAGE_TABLE_LEVEL)
+                       break;
 
-               parents->parent[sp->role.level-2] = sp;
-               parents->idx[sp->role.level-1] = pvec->page[n].idx;
+               parents->parent[level-2] = sp;
        }
 
        return n;
 }
 
+static int mmu_pages_first(struct kvm_mmu_pages *pvec,
+                          struct mmu_page_path *parents)
+{
+       struct kvm_mmu_page *sp;
+       int level;
+
+       if (pvec->nr == 0)
+               return 0;
+
+       WARN_ON(pvec->page[0].idx != INVALID_INDEX);
+
+       sp = pvec->page[0].sp;
+       level = sp->role.level;
+       WARN_ON(level == PT_PAGE_TABLE_LEVEL);
+
+       parents->parent[level-2] = sp;
+
+       /* Also set up a sentinel.  Further entries in pvec are all
+        * children of sp, so this element is never overwritten.
+        */
+       parents->parent[level-1] = NULL;
+       return mmu_pages_next(pvec, parents, 0);
+}
+
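
A worked example of how the INVALID_INDEX entry and the NULL sentinel fit together; the level numbers assume a 4-level (PT64_ROOT_LEVEL) walk and are only illustrative:

    /*
     * pvec->page[0] = { root sp, INVALID_INDEX }   from mmu_pages_add()
     * parents->parent[level-2] = root sp           i.e. parent[2]
     * parents->parent[level-1] = NULL              i.e. parent[3], sentinel
     *
     * mmu_pages_clear_parents() walks upwards clearing unsync-child bits
     * and stops either at a parent that still has unsync children or at
     * the NULL sentinel above the root.  Needing index level-1 for the
     * sentinel is why struct mmu_page_path now holds PT64_ROOT_LEVEL
     * entries instead of PT64_ROOT_LEVEL-1.
     */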
 static void mmu_pages_clear_parents(struct mmu_page_path *parents)
 {
        struct kvm_mmu_page *sp;
@@ -1994,22 +2040,14 @@ static void mmu_pages_clear_parents(struct mmu_page_path *parents)
 
        do {
                unsigned int idx = parents->idx[level];
-
                sp = parents->parent[level];
                if (!sp)
                        return;
 
+               WARN_ON(idx == INVALID_INDEX);
                clear_unsync_child_bit(sp, idx);
                level++;
-       } while (level < PT64_ROOT_LEVEL-1 && !sp->unsync_children);
-}
-
-static void kvm_mmu_pages_init(struct kvm_mmu_page *parent,
-                              struct mmu_page_path *parents,
-                              struct kvm_mmu_pages *pvec)
-{
-       parents->parent[parent->role.level-1] = NULL;
-       pvec->nr = 0;
+       } while (!sp->unsync_children);
 }
 
 static void mmu_sync_children(struct kvm_vcpu *vcpu,
@@ -2020,30 +2058,36 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
        struct mmu_page_path parents;
        struct kvm_mmu_pages pages;
        LIST_HEAD(invalid_list);
+       bool flush = false;
 
-       kvm_mmu_pages_init(parent, &parents, &pages);
        while (mmu_unsync_walk(parent, &pages)) {
                bool protected = false;
 
                for_each_sp(pages, sp, parents, i)
                        protected |= rmap_write_protect(vcpu, sp->gfn);
 
-               if (protected)
+               if (protected) {
                        kvm_flush_remote_tlbs(vcpu->kvm);
+                       flush = false;
+               }
 
                for_each_sp(pages, sp, parents, i) {
-                       kvm_sync_page(vcpu, sp, &invalid_list);
+                       flush |= kvm_sync_page(vcpu, sp, &invalid_list);
                        mmu_pages_clear_parents(&parents);
                }
-               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
-               cond_resched_lock(&vcpu->kvm->mmu_lock);
-               kvm_mmu_pages_init(parent, &parents, &pages);
+               if (need_resched() || spin_needbreak(&vcpu->kvm->mmu_lock)) {
+                       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
+                       cond_resched_lock(&vcpu->kvm->mmu_lock);
+                       flush = false;
+               }
        }
+
+       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
 }
 
 static void __clear_sp_write_flooding_count(struct kvm_mmu_page *sp)
 {
-       sp->write_flooding_count = 0;
+       atomic_set(&sp->write_flooding_count, 0);
 }
 
 static void clear_sp_write_flooding_count(u64 *spte)
@@ -2069,6 +2113,8 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        unsigned quadrant;
        struct kvm_mmu_page *sp;
        bool need_sync = false;
+       bool flush = false;
+       LIST_HEAD(invalid_list);
 
        role = vcpu->arch.mmu.base_role;
        role.level = level;
@@ -2092,8 +2138,16 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                if (sp->role.word != role.word)
                        continue;
 
-               if (sp->unsync && kvm_sync_page_transient(vcpu, sp))
-                       break;
+               if (sp->unsync) {
+                       /* The page is good, but __kvm_sync_page might still end
+                        * up zapping it.  If so, break in order to rebuild it.
+                        */
+                       if (!__kvm_sync_page(vcpu, sp, &invalid_list))
+                               break;
+
+                       WARN_ON(!list_empty(&invalid_list));
+                       kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+               }
 
                if (sp->unsync_children)
                        kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
@@ -2112,16 +2166,24 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        hlist_add_head(&sp->hash_link,
                &vcpu->kvm->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)]);
        if (!direct) {
-               if (rmap_write_protect(vcpu, gfn))
+               /*
+                * We should do write protection before syncing pages,
+                * otherwise the content of the synced shadow page may
+                * be inconsistent with the guest page table.
+                */
+               account_shadowed(vcpu->kvm, sp);
+               if (level == PT_PAGE_TABLE_LEVEL &&
+                     rmap_write_protect(vcpu, gfn))
                        kvm_flush_remote_tlbs(vcpu->kvm);
-               if (level > PT_PAGE_TABLE_LEVEL && need_sync)
-                       kvm_sync_pages(vcpu, gfn);
 
-               account_shadowed(vcpu->kvm, sp);
+               if (level > PT_PAGE_TABLE_LEVEL && need_sync)
+                       flush |= kvm_sync_pages(vcpu, gfn, &invalid_list);
        }
        sp->mmu_valid_gen = vcpu->kvm->arch.mmu_valid_gen;
        clear_page(sp->spt);
        trace_kvm_mmu_get_page(sp, true);
+
+       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
        return sp;
 }
 
@@ -2269,7 +2331,6 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
        if (parent->role.level == PT_PAGE_TABLE_LEVEL)
                return 0;
 
-       kvm_mmu_pages_init(parent, &parents, &pages);
        while (mmu_unsync_walk(parent, &pages)) {
                struct kvm_mmu_page *sp;
 
@@ -2278,7 +2339,6 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
                        mmu_pages_clear_parents(&parents);
                        zapped++;
                }
-               kvm_mmu_pages_init(parent, &parents, &pages);
        }
 
        return zapped;
@@ -2329,14 +2389,13 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
                return;
 
        /*
-        * wmb: make sure everyone sees our modifications to the page tables
-        * rmb: make sure we see changes to vcpu->mode
-        */
-       smp_mb();
-
-       /*
-        * Wait for all vcpus to exit guest mode and/or lockless shadow
-        * page table walks.
+        * We need to make sure everyone sees our modifications to
+        * the page tables and sees changes to vcpu->mode here. The barrier
+        * in kvm_flush_remote_tlbs() achieves this. This pairs
+        * with vcpu_enter_guest() and walk_shadow_page_lockless_begin/end.
+        *
+        * In addition, kvm_flush_remote_tlbs waits for all vcpus to exit
+        * guest mode and/or lockless shadow page table walks.
         */
        kvm_flush_remote_tlbs(kvm);
 
@@ -2354,8 +2413,8 @@ static bool prepare_zap_oldest_mmu_page(struct kvm *kvm,
        if (list_empty(&kvm->arch.active_mmu_pages))
                return false;
 
-       sp = list_entry(kvm->arch.active_mmu_pages.prev,
-                       struct kvm_mmu_page, link);
+       sp = list_last_entry(&kvm->arch.active_mmu_pages,
+                            struct kvm_mmu_page, link);
        kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
 
        return true;
@@ -2408,7 +2467,7 @@ int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unprotect_page);
 
-static void __kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+static void kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 {
        trace_kvm_mmu_unsync_page(sp);
        ++vcpu->kvm->stat.mmu_unsync;
@@ -2417,37 +2476,26 @@ static void __kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
        kvm_mmu_mark_parents_unsync(sp);
 }
 
-static void kvm_unsync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
+static bool mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
+                                  bool can_unsync)
 {
-       struct kvm_mmu_page *s;
-
-       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
-               if (s->unsync)
-                       continue;
-               WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
-               __kvm_unsync_page(vcpu, s);
-       }
-}
+       struct kvm_mmu_page *sp;
 
-static int mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
-                                 bool can_unsync)
-{
-       struct kvm_mmu_page *s;
-       bool need_unsync = false;
+       if (kvm_page_track_is_active(vcpu, gfn, KVM_PAGE_TRACK_WRITE))
+               return true;
 
-       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
+       for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn) {
                if (!can_unsync)
-                       return 1;
+                       return true;
 
-               if (s->role.level != PT_PAGE_TABLE_LEVEL)
-                       return 1;
+               if (sp->unsync)
+                       continue;
 
-               if (!s->unsync)
-                       need_unsync = true;
+               WARN_ON(sp->role.level != PT_PAGE_TABLE_LEVEL);
+               kvm_unsync_page(vcpu, sp);
        }
-       if (need_unsync)
-               kvm_unsync_pages(vcpu, gfn);
-       return 0;
+
+       return false;
 }
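
The rewritten helper folds the old unsync logic and the new write-track check into one boolean decision. The outcomes, summarized as a sketch (the WARN_ON replaces the old level check because gfns of non-leaf shadow pages are now caught by the write-track test above):

    /*
     * mmu_need_write_protect(), outcomes:
     *
     *   gfn has a KVM_PAGE_TRACK_WRITE tracker       -> true (stay read-only)
     *   an indirect sp exists for gfn, !can_unsync   -> true
     *   otherwise                                    -> mark every valid
     *       indirect sp for the gfn unsync and return false, allowing the
     *       caller to install a writable spte.
     */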
 
 static bool kvm_is_mmio_pfn(kvm_pfn_t pfn)
@@ -2503,7 +2551,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                 * be fixed if guest refault.
                 */
                if (level > PT_PAGE_TABLE_LEVEL &&
-                   has_wrprotected_page(vcpu, gfn, level))
+                   mmu_gfn_lpage_is_disallowed(vcpu, gfn, level))
                        goto done;
 
                spte |= PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE;
@@ -2768,7 +2816,7 @@ static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
        if (!is_error_noslot_pfn(pfn) && !kvm_is_reserved_pfn(pfn) &&
            level == PT_PAGE_TABLE_LEVEL &&
            PageTransCompound(pfn_to_page(pfn)) &&
-           !has_wrprotected_page(vcpu, gfn, PT_DIRECTORY_LEVEL)) {
+           !mmu_gfn_lpage_is_disallowed(vcpu, gfn, PT_DIRECTORY_LEVEL)) {
                unsigned long mask;
                /*
                 * mmu_notifier_retry was successful and we hold the
@@ -2796,20 +2844,16 @@ static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
 static bool handle_abnormal_pfn(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn,
                                kvm_pfn_t pfn, unsigned access, int *ret_val)
 {
-       bool ret = true;
-
        /* The pfn is invalid, report the error! */
        if (unlikely(is_error_pfn(pfn))) {
                *ret_val = kvm_handle_bad_page(vcpu, gfn, pfn);
-               goto exit;
+               return true;
        }
 
        if (unlikely(is_noslot_pfn(pfn)))
                vcpu_cache_mmio_info(vcpu, gva, gfn, access);
 
-       ret = false;
-exit:
-       return ret;
+       return false;
 }
 
 static bool page_fault_can_be_fast(u32 error_code)
@@ -3273,7 +3317,7 @@ static bool is_shadow_zero_bits_set(struct kvm_mmu *mmu, u64 spte, int level)
        return __is_rsvd_bits_set(&mmu->shadow_zero_check, spte, level);
 }
 
-static bool quickly_check_mmio_pf(struct kvm_vcpu *vcpu, u64 addr, bool direct)
+static bool mmio_info_in_cache(struct kvm_vcpu *vcpu, u64 addr, bool direct)
 {
        if (direct)
                return vcpu_match_mmio_gpa(vcpu, addr);
@@ -3332,7 +3376,7 @@ int handle_mmio_page_fault(struct kvm_vcpu *vcpu, u64 addr, bool direct)
        u64 spte;
        bool reserved;
 
-       if (quickly_check_mmio_pf(vcpu, addr, direct))
+       if (mmio_info_in_cache(vcpu, addr, direct))
                return RET_MMIO_PF_EMULATE;
 
        reserved = walk_shadow_page_get_mmio_spte(vcpu, addr, &spte);
@@ -3362,20 +3406,53 @@ int handle_mmio_page_fault(struct kvm_vcpu *vcpu, u64 addr, bool direct)
 }
 EXPORT_SYMBOL_GPL(handle_mmio_page_fault);
 
+static bool page_fault_handle_page_track(struct kvm_vcpu *vcpu,
+                                        u32 error_code, gfn_t gfn)
+{
+       if (unlikely(error_code & PFERR_RSVD_MASK))
+               return false;
+
+       if (!(error_code & PFERR_PRESENT_MASK) ||
+             !(error_code & PFERR_WRITE_MASK))
+               return false;
+
+       /*
+        * The guest is writing a page that is write-tracked; such a
+        * write cannot be fixed by the page fault handler.
+        */
+       if (kvm_page_track_is_active(vcpu, gfn, KVM_PAGE_TRACK_WRITE))
+               return true;
+
+       return false;
+}
+
+static void shadow_page_table_clear_flood(struct kvm_vcpu *vcpu, gva_t addr)
+{
+       struct kvm_shadow_walk_iterator iterator;
+       u64 spte;
+
+       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+               return;
+
+       walk_shadow_page_lockless_begin(vcpu);
+       for_each_shadow_entry_lockless(vcpu, addr, iterator, spte) {
+               clear_sp_write_flooding_count(iterator.sptep);
+               if (!is_shadow_present_pte(spte))
+                       break;
+       }
+       walk_shadow_page_lockless_end(vcpu);
+}
+
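
How the new helpers are meant to fit together, derived from the hunks in this patch; the emulator plumbing that actually invokes the notifier lives outside this file, so the last step below is an expectation rather than something shown here:

    /*
     * Sketch of the write-tracking slow path introduced here:
     *
     *   guest write to a tracked gfn
     *     -> page fault with PFERR_PRESENT | PFERR_WRITE
     *     -> page_fault_handle_page_track() returns true
     *     -> tdp/nonpaging page fault returns 1
     *     -> kvm_mmu_page_fault() emulates the instruction
     *     -> the emulated write reaches the track_write notifier
     *        registered in kvm_mmu_init_vm() (kvm_mmu_pte_write here)
     */

shadow_page_table_clear_flood() clears the flood counters along a gva's shadow walk without holding mmu_lock, which is presumably why write_flooding_count becomes an atomic_t later in this patch.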
 static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
                                u32 error_code, bool prefault)
 {
-       gfn_t gfn;
+       gfn_t gfn = gva >> PAGE_SHIFT;
        int r;
 
        pgprintk("%s: gva %lx error %x\n", __func__, gva, error_code);
 
-       if (unlikely(error_code & PFERR_RSVD_MASK)) {
-               r = handle_mmio_page_fault(vcpu, gva, true);
-
-               if (likely(r != RET_MMIO_PF_INVALID))
-                       return r;
-       }
+       if (page_fault_handle_page_track(vcpu, error_code, gfn))
+               return 1;
 
        r = mmu_topup_memory_caches(vcpu);
        if (r)
@@ -3383,7 +3460,6 @@ static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
 
        MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
-       gfn = gva >> PAGE_SHIFT;
 
        return nonpaging_map(vcpu, gva & PAGE_MASK,
                             error_code, gfn, prefault);
@@ -3460,12 +3536,8 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
 
        MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
-       if (unlikely(error_code & PFERR_RSVD_MASK)) {
-               r = handle_mmio_page_fault(vcpu, gpa, true);
-
-               if (likely(r != RET_MMIO_PF_INVALID))
-                       return r;
-       }
+       if (page_fault_handle_page_track(vcpu, error_code, gfn))
+               return 1;
 
        r = mmu_topup_memory_caches(vcpu);
        if (r)
@@ -3558,13 +3630,24 @@ static bool sync_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
        return false;
 }
 
-static inline bool is_last_gpte(struct kvm_mmu *mmu, unsigned level, unsigned gpte)
+static inline bool is_last_gpte(struct kvm_mmu *mmu,
+                               unsigned level, unsigned gpte)
 {
-       unsigned index;
+       /*
+        * PT_PAGE_TABLE_LEVEL always terminates.  The RHS has bit 7 set
+        * iff level <= PT_PAGE_TABLE_LEVEL, which for our purpose means
+        * level == PT_PAGE_TABLE_LEVEL; set PT_PAGE_SIZE_MASK in gpte then.
+        */
+       gpte |= level - PT_PAGE_TABLE_LEVEL - 1;
+
+       /*
+        * The RHS has bit 7 set iff level < mmu->last_nonleaf_level.
+        * If it is clear, there are no large pages at this level, so clear
+        * PT_PAGE_SIZE_MASK in gpte if that is the case.
+        */
+       gpte &= level - mmu->last_nonleaf_level;
 
-       index = level - 1;
-       index |= (gpte & PT_PAGE_SIZE_MASK) >> (PT_PAGE_SIZE_SHIFT - 2);
-       return mmu->last_pte_bitmap & (1 << index);
+       return gpte & PT_PAGE_SIZE_MASK;
 }
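
A worked trace of the branch-free check, assuming 64-bit paging (root_level = 4, hence last_nonleaf_level = 4) and PT_PAGE_SIZE_MASK = bit 7; gpte and level are unsigned, so the "negative" intermediate values are simply all-ones masks:

    /*
     *   level = 3, PS (bit 7) set in gpte   -- candidate 1GB leaf:
     *     gpte |= 3 - 1 - 1 = 1             bit 7 unchanged
     *     gpte &= 3 - 4 = 0xff..ff          bit 7 kept      -> leaf
     *
     *   level = 4                           -- a PML4E is never a leaf:
     *     gpte &= 4 - 4 = 0                 bit 7 cleared   -> not a leaf
     *
     *   level = 1                           -- a PTE always terminates:
     *     gpte |= 1 - 1 - 1 = 0xff..ff      bit 7 forced on
     *     gpte &= 1 - 4 = 0xff..fd          bit 7 kept      -> leaf
     */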
 
 #define PTTYPE_EPT 18 /* arbitrary */
@@ -3838,22 +3921,88 @@ static void update_permission_bitmask(struct kvm_vcpu *vcpu,
        }
 }
 
-static void update_last_pte_bitmap(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
-{
-       u8 map;
-       unsigned level, root_level = mmu->root_level;
-       const unsigned ps_set_index = 1 << 2;  /* bit 2 of index: ps */
-
-       if (root_level == PT32E_ROOT_LEVEL)
-               --root_level;
-       /* PT_PAGE_TABLE_LEVEL always terminates */
-       map = 1 | (1 << ps_set_index);
-       for (level = PT_DIRECTORY_LEVEL; level <= root_level; ++level) {
-               if (level <= PT_PDPE_LEVEL
-                   && (mmu->root_level >= PT32E_ROOT_LEVEL || is_pse(vcpu)))
-                       map |= 1 << (ps_set_index | (level - 1));
+/*
+* PKU is an additional mechanism by which paging controls access to
+* user-mode addresses based on the value in the PKRU register.  Protection
+* key violations are reported through a bit in the page fault error code.
+* Unlike other bits of the error code, the PK bit is not known at the
+* call site of e.g. gva_to_gpa; it must be computed directly in
+* permission_fault based on two bits of PKRU, on some machine state (CR4,
+* CR0, EFER, CPL), and on other bits of the error code and the page tables.
+*
+* In particular the following conditions come from the error code, the
+* page tables and the machine state:
+* - PK is always zero unless CR4.PKE=1 and EFER.LMA=1
+* - PK is always zero if RSVD=1 (reserved bit set) or F=1 (instruction fetch)
+* - PK is always zero if U=0 in the page tables
+* - PKRU.WD is ignored if CR0.WP=0 and the access is a supervisor access.
+*
+* The PKRU bitmask caches the result of these four conditions.  The error
+* code (minus the P bit) and the page table's U bit form an index into the
+* PKRU bitmask.  Two bits of the PKRU bitmask are then extracted and ANDed
+* with the two bits of the PKRU register corresponding to the protection key.
+* For the first three conditions above the bits will be 00, thus masking
+* away both AD and WD.  For all reads or if the last condition holds, WD
+* only will be masked away.
+*/
+static void update_pkru_bitmask(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+                               bool ept)
+{
+       unsigned bit;
+       bool wp;
+
+       if (ept) {
+               mmu->pkru_mask = 0;
+               return;
+       }
+
+       /* PKEY is enabled only if CR4.PKE and EFER.LMA are both set. */
+       if (!kvm_read_cr4_bits(vcpu, X86_CR4_PKE) || !is_long_mode(vcpu)) {
+               mmu->pkru_mask = 0;
+               return;
        }
-       mmu->last_pte_bitmap = map;
+
+       wp = is_write_protection(vcpu);
+
+       for (bit = 0; bit < ARRAY_SIZE(mmu->permissions); ++bit) {
+               unsigned pfec, pkey_bits;
+               bool check_pkey, check_write, ff, uf, wf, pte_user;
+
+               pfec = bit << 1;
+               ff = pfec & PFERR_FETCH_MASK;
+               uf = pfec & PFERR_USER_MASK;
+               wf = pfec & PFERR_WRITE_MASK;
+
+               /* PFEC.RSVD is replaced by ACC_USER_MASK. */
+               pte_user = pfec & PFERR_RSVD_MASK;
+
+               /*
+                * We only need to check accesses that are not instruction
+                * fetches and that go to a user page.
+                */
+               check_pkey = (!ff && pte_user);
+               /*
+                * Write access is controlled by PKRU if it is a
+                * user access or CR0.WP = 1.
+                */
+               check_write = check_pkey && wf && (uf || wp);
+
+               /* PKRU.AD stops both read and write access. */
+               pkey_bits = !!check_pkey;
+               /* PKRU.WD stops write access. */
+               pkey_bits |= (!!check_write) << 1;
+
+               mmu->pkru_mask |= (pkey_bits & 3) << pfec;
+       }
+}
+
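
A worked example of how the mask is filled in, using the error-code bit layout assumed by the code above (PFERR_WRITE_MASK = bit 1, PFERR_USER_MASK = bit 2, PFERR_RSVD_MASK = bit 3, PFERR_FETCH_MASK = bit 4; RSVD is reused here to carry the pte's U bit):

    /*
     *   user-mode write, pte U=1:   pfec = W|U|RSVD = 0xe
     *     check_pkey = 1, check_write = 1   -> pkey_bits = 0b11
     *     pkru_mask |= 3 << 0xe             (PKRU.AD and PKRU.WD both apply)
     *
     *   supervisor read, pte U=1:   pfec = RSVD = 0x8
     *     check_pkey = 1, check_write = 0   -> pkey_bits = 0b01
     *     pkru_mask |= 1 << 0x8             (only PKRU.AD applies)
     *
     *   instruction fetch:          ff = 1  -> pkey_bits = 0, PK never set
     */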
+static void update_last_nonleaf_level(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
+{
+       unsigned root_level = mmu->root_level;
+
+       mmu->last_nonleaf_level = root_level;
+       if (root_level == PT32_ROOT_LEVEL && is_pse(vcpu))
+               mmu->last_nonleaf_level++;
 }
 
 static void paging64_init_context_common(struct kvm_vcpu *vcpu,
@@ -3865,7 +4014,8 @@ static void paging64_init_context_common(struct kvm_vcpu *vcpu,
 
        reset_rsvds_bits_mask(vcpu, context);
        update_permission_bitmask(vcpu, context, false);
-       update_last_pte_bitmap(vcpu, context);
+       update_pkru_bitmask(vcpu, context, false);
+       update_last_nonleaf_level(vcpu, context);
 
        MMU_WARN_ON(!is_pae(vcpu));
        context->page_fault = paging64_page_fault;
@@ -3892,7 +4042,8 @@ static void paging32_init_context(struct kvm_vcpu *vcpu,
 
        reset_rsvds_bits_mask(vcpu, context);
        update_permission_bitmask(vcpu, context, false);
-       update_last_pte_bitmap(vcpu, context);
+       update_pkru_bitmask(vcpu, context, false);
+       update_last_nonleaf_level(vcpu, context);
 
        context->page_fault = paging32_page_fault;
        context->gva_to_gpa = paging32_gva_to_gpa;
@@ -3950,7 +4101,8 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        }
 
        update_permission_bitmask(vcpu, context, false);
-       update_last_pte_bitmap(vcpu, context);
+       update_pkru_bitmask(vcpu, context, false);
+       update_last_nonleaf_level(vcpu, context);
        reset_tdp_shadow_zero_bits_mask(vcpu, context);
 }
 
@@ -4002,6 +4154,7 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly)
        context->direct_map = false;
 
        update_permission_bitmask(vcpu, context, true);
+       update_pkru_bitmask(vcpu, context, true);
        reset_rsvds_bits_mask_ept(vcpu, context, execonly);
        reset_ept_shadow_zero_bits_mask(vcpu, context, execonly);
 }
@@ -4056,7 +4209,8 @@ static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
        }
 
        update_permission_bitmask(vcpu, g_context, false);
-       update_last_pte_bitmap(vcpu, g_context);
+       update_pkru_bitmask(vcpu, g_context, false);
+       update_last_nonleaf_level(vcpu, g_context);
 }
 
 static void init_kvm_mmu(struct kvm_vcpu *vcpu)
@@ -4127,18 +4281,6 @@ static bool need_remote_flush(u64 old, u64 new)
        return (old & ~new & PT64_PERM_MASK) != 0;
 }
 
-static void mmu_pte_write_flush_tlb(struct kvm_vcpu *vcpu, bool zap_page,
-                                   bool remote_flush, bool local_flush)
-{
-       if (zap_page)
-               return;
-
-       if (remote_flush)
-               kvm_flush_remote_tlbs(vcpu->kvm);
-       else if (local_flush)
-               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
-}
-
 static u64 mmu_pte_write_fetch_gpte(struct kvm_vcpu *vcpu, gpa_t *gpa,
                                    const u8 *new, int *bytes)
 {
@@ -4188,7 +4330,8 @@ static bool detect_write_flooding(struct kvm_mmu_page *sp)
        if (sp->role.level == PT_PAGE_TABLE_LEVEL)
                return false;
 
-       return ++sp->write_flooding_count >= 3;
+       atomic_inc(&sp->write_flooding_count);
+       return atomic_read(&sp->write_flooding_count) >= 3;
 }
 
 /*
@@ -4250,15 +4393,15 @@ static u64 *get_written_sptes(struct kvm_mmu_page *sp, gpa_t gpa, int *nspte)
        return spte;
 }
 
-void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
-                      const u8 *new, int bytes)
+static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
+                             const u8 *new, int bytes)
 {
        gfn_t gfn = gpa >> PAGE_SHIFT;
        struct kvm_mmu_page *sp;
        LIST_HEAD(invalid_list);
        u64 entry, gentry, *spte;
        int npte;
-       bool remote_flush, local_flush, zap_page;
+       bool remote_flush, local_flush;
        union kvm_mmu_page_role mask = { };
 
        mask.cr0_wp = 1;
@@ -4275,7 +4418,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        if (!ACCESS_ONCE(vcpu->kvm->arch.indirect_shadow_pages))
                return;
 
-       zap_page = remote_flush = local_flush = false;
+       remote_flush = local_flush = false;
 
        pgprintk("%s: gpa %llx bytes %d\n", __func__, gpa, bytes);
 
@@ -4295,8 +4438,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn) {
                if (detect_write_misaligned(sp, gpa, bytes) ||
                      detect_write_flooding(sp)) {
-                       zap_page |= !!kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
-                                                    &invalid_list);
+                       kvm_mmu_prepare_zap_page(vcpu->kvm, sp, &invalid_list);
                        ++vcpu->kvm->stat.mmu_flooded;
                        continue;
                }
@@ -4318,8 +4460,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                        ++spte;
                }
        }
-       mmu_pte_write_flush_tlb(vcpu, zap_page, remote_flush, local_flush);
-       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+       kvm_mmu_flush_or_zap(vcpu, &invalid_list, remote_flush, local_flush);
        kvm_mmu_audit(vcpu, AUDIT_POST_PTE_WRITE);
        spin_unlock(&vcpu->kvm->mmu_lock);
 }
@@ -4356,32 +4497,34 @@ static void make_mmu_pages_available(struct kvm_vcpu *vcpu)
        kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
 }
 
-static bool is_mmio_page_fault(struct kvm_vcpu *vcpu, gva_t addr)
-{
-       if (vcpu->arch.mmu.direct_map || mmu_is_nested(vcpu))
-               return vcpu_match_mmio_gpa(vcpu, addr);
-
-       return vcpu_match_mmio_gva(vcpu, addr);
-}
-
 int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code,
                       void *insn, int insn_len)
 {
        int r, emulation_type = EMULTYPE_RETRY;
        enum emulation_result er;
+       bool direct = vcpu->arch.mmu.direct_map || mmu_is_nested(vcpu);
+
+       if (unlikely(error_code & PFERR_RSVD_MASK)) {
+               r = handle_mmio_page_fault(vcpu, cr2, direct);
+               if (r == RET_MMIO_PF_EMULATE) {
+                       emulation_type = 0;
+                       goto emulate;
+               }
+               if (r == RET_MMIO_PF_RETRY)
+                       return 1;
+               if (r < 0)
+                       return r;
+       }
 
        r = vcpu->arch.mmu.page_fault(vcpu, cr2, error_code, false);
        if (r < 0)
-               goto out;
-
-       if (!r) {
-               r = 1;
-               goto out;
-       }
+               return r;
+       if (!r)
+               return 1;
 
-       if (is_mmio_page_fault(vcpu, cr2))
+       if (mmio_info_in_cache(vcpu, cr2, direct))
                emulation_type = 0;
-
+emulate:
        er = x86_emulate_instruction(vcpu, cr2, emulation_type, insn, insn_len);
 
        switch (er) {
@@ -4395,8 +4538,6 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code,
        default:
                BUG();
        }
-out:
-       return r;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_page_fault);
 
@@ -4465,6 +4606,21 @@ void kvm_mmu_setup(struct kvm_vcpu *vcpu)
        init_kvm_mmu(vcpu);
 }
 
+void kvm_mmu_init_vm(struct kvm *kvm)
+{
+       struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
+
+       node->track_write = kvm_mmu_pte_write;
+       kvm_page_track_register_notifier(kvm, node);
+}
+
+void kvm_mmu_uninit_vm(struct kvm *kvm)
+{
+       struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
+
+       kvm_page_track_unregister_notifier(kvm, node);
+}
+
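
The notifier node registered here uses the same API that other write-track users can hook into; a minimal sketch under the callback signature used above (my_track_node and my_page_written are hypothetical names, not part of this patch):

    static struct kvm_page_track_notifier_node my_track_node;

    /* called after an emulated guest write to a write-tracked gpa */
    static void my_page_written(struct kvm_vcpu *vcpu, gpa_t gpa,
                                const u8 *new, int bytes)
    {
            /* ... react to the modified bytes ... */
    }

    static void my_track_init(struct kvm *kvm)
    {
            my_track_node.track_write = my_page_written;
            kvm_page_track_register_notifier(kvm, &my_track_node);
    }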
 /* The return value indicates if tlb flush on all vcpus is needed. */
 typedef bool (*slot_level_handler) (struct kvm *kvm, struct kvm_rmap_head *rmap_head);