Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
| Download
Try doing some basic maths questions in the Lean Theorem Prover. Functions, real numbers, equivalence relations and groups. Click on README.md and then on "Open in CoCalc with one click".
Project: Xena
Views: 18536
License: APACHE
/-
Copyright (c) 2018 Alexander Bentkamp. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Alexander Bentkamp
-/
import algebra.pi_instances

/-!
# Basic properties of holors

Holors are indexed collections of tensor coefficients. Confusingly,
they are often called tensors in physics and in the neural network
community.

A holor is simply a multidimensional array of values. The size of a
holor is specified by a `list ℕ`, whose length is called the dimension
of the holor.

The tensor product of `x₁ : holor α ds₁` and `x₂ : holor α ds₂` is the
holor given by `(x₁ ⊗ x₂) (i₁ ++ i₂) = x₁ i₁ * x₂ i₂`. A holor is "of
rank at most 1" if it is a tensor product of one-dimensional holors.
The CP rank of a holor `x` is the smallest N such that `x` is the sum
of N holors of rank at most 1.

Based on the tensor library found in <https://www.isa-afp.org/entries/Deep_Learning.html>

## References

* <https://en.wikipedia.org/wiki/Tensor_rank_decomposition>
-/

universes u

open list

/-- `holor_index ds` is the type of valid index tuples to identify an entry of a holor of
dimensions `ds`: a list of natural numbers `is` together with a proof of `forall₂ (<) is ds`,
i.e. `is` and `ds` are related entrywise by `<`. -/
def holor_index (ds : list ℕ) : Type := { is : list ℕ // forall₂ (<) is ds}

namespace holor_index
variables {ds₁ ds₂ ds₃ : list ℕ}

/-- The leading part of an index into `ds₁ ++ ds₂`: the first `length ds₁` entries of the
underlying list, as an index into `ds₁`. Validity is supplied by `forall₂_take_append`. -/
def take : Π {ds₁ : list ℕ}, holor_index (ds₁ ++ ds₂) → holor_index ds₁
| ds is := ⟨ list.take (length ds) is.1, forall₂_take_append is.1 ds ds₂ is.2 ⟩

/-- The trailing part of an index into `ds₁ ++ ds₂`: the underlying list with its first
`length ds₁` entries dropped, as an index into `ds₂`. Validity is supplied by
`forall₂_drop_append`. -/
def drop : Π {ds₁ : list ℕ}, holor_index (ds₁ ++ ds₂) → holor_index ds₂
| ds is := ⟨ list.drop (length ds) is.1, forall₂_drop_append is.1 ds ds₂ is.2 ⟩

/-- Casting an index along an equality `ds₁ = ds₂` of dimension lists leaves the underlying
list of naturals unchanged. -/
lemma cast_type (is : list ℕ) (eq : ds₁ = ds₂) (h : forall₂ (<) is ds₁) :
  (cast (congr_arg holor_index eq) ⟨is, h⟩).val = is :=
by subst eq; refl

/-- Reinterpret an index into `ds₁ ++ ds₂ ++ ds₃` as an index into `ds₁ ++ (ds₂ ++ ds₃)`,
by casting along `append_assoc`. -/
def assoc_right :
  holor_index (ds₁ ++ ds₂ ++ ds₃) → holor_index (ds₁ ++ (ds₂ ++ ds₃)) :=
cast (congr_arg holor_index (append_assoc ds₁ ds₂ ds₃))

/-- Reinterpret an index into `ds₁ ++ (ds₂ ++ ds₃)` as an index into `ds₁ ++ ds₂ ++ ds₃`,
by casting along the symmetric form of `append_assoc`. -/
def assoc_left :
  holor_index (ds₁ ++ (ds₂ ++ ds₃)) → holor_index (ds₁ ++ ds₂ ++ ds₃) :=
cast (congr_arg holor_index (append_assoc ds₁ ds₂ ds₃).symm)

/-- The `ds₁`-part of an index can be obtained either by reassociating to the right and
taking once, or by taking twice. -/
lemma take_take :
  ∀ t : holor_index (ds₁ ++ ds₂ ++ ds₃),
  t.assoc_right.take = t.take.take
| ⟨ is , h ⟩ := subtype.eq (by simp [assoc_right,take, cast_type, list.take_take, nat.le_add_right, min_eq_left])

/-- The `ds₂`-part of an index can be obtained either by reassociating to the right, dropping
and then taking, or by taking and then dropping. -/
lemma drop_take :
  ∀ t : holor_index (ds₁ ++ ds₂ ++ ds₃),
  t.assoc_right.drop.take = t.take.drop
| ⟨ is , h ⟩ := subtype.eq (by simp [assoc_right, take, drop, cast_type, list.drop_take])

/-- The `ds₃`-part of an index can be obtained either by reassociating to the right and
dropping twice, or by dropping once. -/
lemma drop_drop :
  ∀ t : holor_index (ds₁ ++ ds₂ ++ ds₃),
  t.assoc_right.drop.drop = t.drop
| ⟨ is , h ⟩ := subtype.eq (by simp [assoc_right,drop, cast_type, list.drop_drop])

end holor_index

/-- Holor (indexed collections of tensor coefficients): a function assigning a value in `α`
to every valid index of dimensions `ds`. -/
def holor (α : Type u) (ds:list ℕ) := holor_index ds → α

namespace holor

variables {α : Type} {d : ℕ} {ds : list ℕ} {ds₁ : list ℕ} {ds₂ : list ℕ} {ds₃ : list ℕ}

-- Entrywise structure: the default holor, zero, addition and negation are all defined
-- by applying the corresponding operation of `α` at each index `t`.
instance [inhabited α] : inhabited (holor α ds) :=
  ⟨λ t, default α⟩
instance [has_zero α] : has_zero (holor α ds) :=
  ⟨λ t, 0⟩
instance [has_add α] : has_add (holor α ds) :=
  ⟨λ x y t, x t + y t⟩
instance [has_neg α] : has_neg (holor α ds) :=
  ⟨λ a t, - a t⟩

-- The additive algebraic structures on `α` are transferred to holors by the
-- `pi_instance` tactic.
instance [add_semigroup α] : add_semigroup (holor α ds) := by pi_instance
instance [add_comm_semigroup α] : add_comm_semigroup (holor α ds) := by pi_instance
instance [add_monoid α] : add_monoid (holor α ds) := by pi_instance
instance [add_comm_monoid α] : add_comm_monoid (holor α ds) := by pi_instance
instance [add_group α] : add_group (holor α ds) := by pi_instance
instance [add_comm_group α] : add_comm_group (holor α ds) := by pi_instance

/- scalar product -/
-- `a • x` multiplies every entry of `x` by `a` on the left.
instance [has_mul α] : has_scalar α (holor α ds) :=
  ⟨λ a x, λ t, a * x t⟩

-- Module structure, obtained from `pi.module`.
instance [ring α] : module α (holor α ds) := pi.module α

-- Vector space structure over a discrete field.
-- NOTE(review): the meaning of the two anonymous-constructor arguments depends on how
-- `vector_space` is declared in the imported library; confirm against that declaration.
instance [discrete_field α] : vector_space α (holor α ds) := ⟨α, holor α ds⟩

/-- The tensor product of two holors: the entry at index `t` is the product of the entry of
`x` at `t.take` and the entry of `y` at `t.drop`. -/
def mul [s : has_mul α] (x : holor α ds₁) (y : holor α ds₂) : holor α (ds₁ ++ ds₂) :=
  λ t, x (t.take) * y (t.drop)

-- Notation for the tensor product; `local`, so it is only available within this file.
local infix ` ⊗ ` : 70 := mul

/-- Casting a holor along an equality `ds₁ = ds₂` of dimension lists is the same as evaluating
the original holor at the index cast back along the symmetric equality. -/
lemma cast_type (eq : ds₁ = ds₂) (a : holor α ds₁) :
  cast (congr_arg (holor α) eq) a = (λ t, a (cast (congr_arg holor_index eq.symm) t)) :=
by subst eq; refl

/-- Reinterpret a holor of dimensions `ds₁ ++ ds₂ ++ ds₃` as one of dimensions
`ds₁ ++ (ds₂ ++ ds₃)`, by casting along `append_assoc`. -/
def assoc_right :
  holor α (ds₁ ++ ds₂ ++ ds₃) → holor α (ds₁ ++ (ds₂ ++ ds₃)) :=
cast (congr_arg (holor α) (append_assoc ds₁ ds₂ ds₃))

/-- Reinterpret a holor of dimensions `ds₁ ++ (ds₂ ++ ds₃)` as one of dimensions
`ds₁ ++ ds₂ ++ ds₃`, by casting along the symmetric form of `append_assoc`. -/
def assoc_left :
  holor α (ds₁ ++ (ds₂ ++ ds₃)) → holor α (ds₁ ++ ds₂ ++ ds₃) :=
cast (congr_arg (holor α) (append_assoc ds₁ ds₂ ds₃).symm)

/-- Associativity of the tensor product as a homogeneous equality: the right-nested product
has to be moved to the left-nested dimension list with `assoc_left`. -/
lemma mul_assoc0 [semigroup α] (x : holor α ds₁) (y : holor α ds₂) (z : holor α ds₃) :
  x ⊗ y ⊗ z = (x ⊗ (y ⊗ z)).assoc_left :=
funext (assume t : holor_index (ds₁ ++ ds₂ ++ ds₃),
begin
  rw assoc_left,
  unfold mul,
  -- associativity of the multiplication on `α`
  rw mul_assoc,
  -- express all three index parts through `t.assoc_right`
  rw [←holor_index.take_take, ←holor_index.drop_take, ←holor_index.drop_drop],
  rw cast_type,
  refl,
  rw append_assoc
end)

/-- Associativity of the tensor product, stated as a heterogeneous equality (`==`) because the
two sides have the different dimension lists `ds₁ ++ ds₂ ++ ds₃` and `ds₁ ++ (ds₂ ++ ds₃)`. -/
lemma mul_assoc [semigroup α] (x : holor α ds₁) (y : holor α ds₂) (z : holor α ds₃) :
  mul (mul x y) z == (mul x (mul y z)) :=
by simp [cast_heq, mul_assoc0, assoc_left].
/-- The tensor product distributes over addition in its right argument. -/
lemma mul_left_distrib [distrib α] (x : holor α ds₁) (y : holor α ds₂) (z : holor α ds₂) :
  x ⊗ (y + z) = x ⊗ y + x ⊗ z :=
funext (λt, left_distrib (x (holor_index.take t)) (y (holor_index.drop t)) (z (holor_index.drop t)))

/-- The tensor product distributes over addition in its left argument. -/
lemma mul_right_distrib [distrib α] (x : holor α ds₁) (y : holor α ds₁) (z : holor α ds₂) :
  (x + y) ⊗ z = x ⊗ z + y ⊗ z :=
funext (λt, right_distrib (x (holor_index.take t)) (y (holor_index.take t)) (z (holor_index.drop t)))

/-- The tensor product with the zero holor on the left is zero. -/
@[simp] lemma zero_mul {α : Type} [ring α] (x : holor α ds₂) :
  (0 : holor α ds₁) ⊗ x = 0 :=
funext (λ t, zero_mul (x (holor_index.drop t)))

/-- The tensor product with the zero holor on the right is zero. -/
@[simp] lemma mul_zero {α : Type} [ring α] (x : holor α ds₁) :
  x ⊗ (0 :holor α ds₂) = 0 :=
funext (λ t, mul_zero (x (holor_index.take t)))

/-- Tensoring with a holor `x` of empty dimension list is scalar multiplication by the entry
of `x` at the empty index. -/
lemma mul_scalar_mul [monoid α] (x : holor α []) (y : holor α ds) :
  x ⊗ y = x ⟨[], forall₂.nil⟩ • y :=
by simp [mul, has_scalar.smul, holor_index.take, holor_index.drop]

/- holor slices -/

/-- A slice is a subholor consisting of all entries with initial index i. The hypothesis
`h : i < d` is what makes `i :: is` a valid index into `d :: ds`. -/
def slice (x : holor α (d :: ds)) (i : ℕ) (h : i < d) : holor α ds :=
  (λ is : holor_index ds, x ⟨ i :: is.1, forall₂.cons h is.2⟩)

/-- The 1-dimensional "unit" holor with 1 in the `j`th position and 0 at every other index. -/
def unit_vec [monoid α] [add_monoid α] (d : ℕ) (j : ℕ) : holor α [d] :=
  λ ti, if ti.1 = [j] then 1 else 0

/-- To prove a property `p` of an index `t` into `d :: ds`, it suffices to prove it for every
way of writing the underlying list of `t` as `i :: is`. The case of an empty underlying list
is ruled out, since it cannot be related by `forall₂` to the nonempty list `d :: ds`. -/
lemma holor_index_cons_decomp (p: holor_index (d :: ds) → Prop) :
  Π (t : holor_index (d :: ds)),
    (∀ i is, Π h : t.1 = i :: is, p ⟨ i :: is, begin rw [←h], exact t.2 end ⟩ ) → p t
| ⟨[], hforall₂⟩ hp := absurd (forall₂_nil_left_iff.1 hforall₂) (cons_ne_nil d ds)
| ⟨(i :: is), hforall₂⟩ hp := hp i is rfl

/-- Two holors are equal if all their slices are equal. -/
lemma slice_eq (x : holor α (d :: ds)) (y : holor α (d :: ds)) (h : slice x = slice y) : x = y :=
funext $ λ t : holor_index (d :: ds), holor_index_cons_decomp (λ t, x t = y t) t $ λ i is hiis,
-- recover the validity proofs for the head `i` and the tail `is` of the index
have hiisdds: forall₂ (<) (i :: is) (d :: ds), begin rw [←hiis], exact t.2 end,
have hid: i<d, from (forall₂_cons.1 hiisdds).1,
have hisds: forall₂ (<) is ds, from (forall₂_cons.1 hiisdds).2,
-- go through the `i`-th slice, where `x` and `y` agree by `h`
calc
  x ⟨i :: is, _⟩ = slice x i hid ⟨is, hisds⟩ : congr_arg (λ t, x t) (subtype.eq rfl)
  ... = slice y i hid ⟨is, hisds⟩ : by rw h
  ... = y ⟨i :: is, _⟩ : congr_arg (λ t, y t) (subtype.eq rfl)

/-- The `i`-th slice of `unit_vec d j ⊗ x` is `x` if `i = j` and zero otherwise. -/
lemma slice_unit_vec_mul [ring α] {i : ℕ} {j : ℕ} (hid : i < d) (x : holor α ds) :
  slice (unit_vec d j ⊗ x) i hid = if i=j then x else 0 :=
funext $ λ t : holor_index ds, if h : i = j
  then by simp [slice, mul, holor_index.take, unit_vec, holor_index.drop, h]
  else by simp [slice, mul, holor_index.take, unit_vec, holor_index.drop, h]; refl

/-- Slicing commutes with addition of holors. -/
lemma slice_add [has_add α] (i : ℕ) (hid : i < d) (x : holor α (d :: ds)) (y : holor α (d :: ds)) :
  slice x i hid + slice y i hid = slice (x + y) i hid :=
funext (λ t, by simp [slice,(+)])

/-- Every slice of the zero holor is zero. -/
lemma slice_zero [has_zero α] (i : ℕ) (hid : i < d) :
  slice (0 : holor α (d :: ds)) i hid = 0 :=
funext (λ t, by simp [slice]; refl)

/-- Slicing commutes with finite sums of holors. -/
lemma slice_sum [add_comm_monoid α] {β : Type}
  (i : ℕ) (hid : i < d) (s : finset β) (f : β → holor α (d :: ds)) :
  finset.sum s (λ x, slice (f x) i hid) = slice (finset.sum s f) i hid :=
begin
  -- classical decidable equality on `β`; presumably required by the finset
  -- induction principle and `finset.sum_insert` used below
  letI := classical.dec_eq β,
  refine finset.induction_on s _ _,
  -- empty sum
  { simp [slice_zero] },
  -- inserting an element that is not yet in the set
  { intros _ _ h_not_in ih,
    rw [finset.sum_insert h_not_in, ih, slice_add, finset.sum_insert h_not_in] }
end

/-- The original holor can be recovered from its slices by multiplying with unit vectors and
summing up.

NOTE(review): the bound passed to `slice` is built with `nat.succ_le_of_lt`, although `slice`
expects a proof of `i.1 < d`; presumably this is accepted because `<` on `ℕ` unfolds to a
`succ _ ≤ _` statement — confirm. -/
@[simp] lemma sum_unit_vec_mul_slice [ring α] (x : holor α (d :: ds)) :
  (finset.range d).attach.sum
    (λ i, unit_vec d i.1 ⊗ slice x i.1 (nat.succ_le_of_lt (finset.mem_range.1 i.2))) = x :=
begin
  -- it suffices to compare the slices of both sides
  apply slice_eq _ _ _,
  ext i hid,
  rw [←slice_sum],
  simp only [slice_unit_vec_mul hid],
  -- only the summand belonging to `i` contributes
  rw finset.sum_eq_single (subtype.mk i _),
  { simp,
    refl },
  -- every other summand `b` vanishes
  { assume (b : {x // x ∈ finset.range d}) (hb : b ∈ (finset.range d).attach) (hbi : b ≠ ⟨i, _⟩),
    have hbi' : i ≠ b.val,
    { apply not.imp hbi,
      { assume h0 : i = b.val,
        apply subtype.eq,
        simp only [h0] },
      { exact finset.mem_range.2 hid } },
    simp [hbi']},
  -- the summand belonging to `i` is a member of the attached range
  { assume hid' : subtype.mk i _ ∉ finset.attach (finset.range d),
    exfalso,
    exact absurd (finset.mem_attach _ _) hid' }
end

/- CP rank -/

/-- `cprank_max1 x` means `x` has CP rank at most 1, that is,
  it is the tensor product of 1-dimensional holors.
  `nil` covers every holor with empty dimension list; `cons` tensors a holor of dimensions
  `[d]` onto a holor that already satisfies `cprank_max1`. -/
inductive cprank_max1 [has_mul α]: Π {ds}, holor α ds → Prop
| nil (x : holor α []) :
    cprank_max1 x
| cons {d} {ds} (x : holor α [d]) (y : holor α ds) :
    cprank_max1 y → cprank_max1 (x ⊗ y)

/-- `cprank_max N x` means `x` has CP rank at most `N`, that is,
  it can be written as the sum of N holors of rank at most 1.
  `zero` covers the zero holor with bound `0`; `succ` adds one holor satisfying `cprank_max1`
  and raises the bound by one. -/
inductive cprank_max [has_mul α] [add_monoid α] : ℕ → Π {ds}, holor α ds → Prop
| zero {ds} :
    cprank_max 0 (0 : holor α ds)
| succ n {ds} (x : holor α ds) (y : holor α ds) :
    cprank_max1 x → cprank_max n y → cprank_max (n+1) (x + y)

/-- Every holor with empty dimension list satisfies `cprank_max 1`. -/
lemma cprank_max_nil [monoid α] [add_monoid α] (x : holor α nil) : cprank_max 1 x :=
have h : _, from cprank_max.succ 0 x 0 (cprank_max1.nil x) (cprank_max.zero),
by rwa [add_zero x, zero_add] at h

/-- A holor satisfying `cprank_max1` also satisfies `cprank_max 1`. -/
lemma cprank_max_1 [monoid α] [add_monoid α] {x : holor α ds}
  (h : cprank_max1 x) :
  cprank_max 1 x :=
have h' : _, from cprank_max.succ 0 x 0 h cprank_max.zero,
by rwa [zero_add, add_zero] at h'

/-- The bounds of `cprank_max` add up under addition of holors. Proved by recursion on the
derivation of the first hypothesis. -/
lemma cprank_max_add [monoid α] [add_monoid α]:
  ∀ {m : ℕ} {n : ℕ} {x : holor α ds} {y : holor α ds},
  cprank_max m x → cprank_max n y → cprank_max (m + n) (x + y)
| 0 n x y (cprank_max.zero) hy := by simp [hy]
| (m+1) n _ y (cprank_max.succ k x₁ x₂ hx₁ hx₂) hy :=
begin
  simp only [add_comm, add_assoc],
  apply cprank_max.succ,
  { assumption },
  { exact cprank_max_add hx₂ hy }
end

/-- Tensoring with a holor of dimensions `[d]` on the left preserves the bound of
`cprank_max`. Proved by recursion on the derivation of the hypothesis. -/
lemma cprank_max_mul [ring α] :
  ∀ (n : ℕ) (x : holor α [d]) (y : holor α ds), cprank_max n y → cprank_max n (x ⊗ y)
| 0 x _ (cprank_max.zero) := by simp [mul_zero x, cprank_max.zero]
| (n+1) x _ (cprank_max.succ k y₁ y₂ hy₁ hy₂) :=
begin
  rw mul_left_distrib,
  rw nat.add_comm,
  apply cprank_max_add,
  { exact cprank_max_1 (cprank_max1.cons _ _ hy₁) },
  { exact cprank_max_mul k x y₂ hy₂ }
end

/-- If every summand `f x` with `x ∈ s` satisfies `cprank_max n`, then the sum over `s`
satisfies `cprank_max (s.card * n)`. -/
lemma cprank_max_sum [ring α] {β} {n : ℕ} (s : finset β) (f : β → holor α ds) :
  (∀ x ∈ s, cprank_max n (f x)) → cprank_max (s.card * n) (finset.sum s f) :=
by letI := classical.dec_eq β;
exact finset.induction_on s
  (by simp [cprank_max.zero])
  (begin
    assume x s (h_x_notin_s : x ∉ s) ih h_cprank,
    simp only [finset.sum_insert h_x_notin_s,finset.card_insert_of_not_mem h_x_notin_s],
    rw nat.right_distrib,
    simp only [nat.one_mul, nat.add_comm],
    -- the induction hypothesis applies because membership in `s` implies membership
    -- in the set with `x` inserted
    have ih' : cprank_max (finset.card s * n) (finset.sum s f),
    { apply ih,
      assume (x : β) (h_x_in_s: x ∈ s),
      simp only [h_cprank, finset.mem_insert_of_mem, h_x_in_s] },
    exact (cprank_max_add (h_cprank x (finset.mem_insert_self x s)) ih')
  end)

/-- Every holor of dimensions `ds` satisfies `cprank_max ds.prod`. Proved by recursion on
`ds`: a holor of dimensions `d :: ds` is rewritten with `sum_unit_vec_mul_slice` as a sum over
`finset.range d` of unit vectors tensored with slices, and each slice is bounded by the
recursive call. -/
lemma cprank_max_upper_bound [ring α] : Π {ds}, ∀ x : holor α ds, cprank_max ds.prod x
| [] x := cprank_max_nil x
| (d :: ds) x :=
-- each summand is bounded by `ds.prod`
have h_summands : Π (i : {x // x ∈ finset.range d}),
  cprank_max ds.prod (unit_vec d i.1 ⊗ slice x i.1 (mem_range.1 i.2)),
  from λ i, cprank_max_mul _ _ _ (cprank_max_upper_bound (slice x i.1 (mem_range.1 i.2))),
-- the product of `d :: ds`, expressed through the cardinality of `finset.range d`
have h_dds_prod : (list.cons d ds).prod = finset.card (finset.range d) * prod ds,
  by simp [finset.card_range],
-- bound for the whole sum, first in terms of the cardinality of the attached range
have cprank_max (finset.card (finset.attach (finset.range d)) * prod ds)
  (finset.sum (finset.attach (finset.range d))
  (λ (i : {x // x ∈ finset.range d}), unit_vec d (i.val)⊗slice x (i.val) (mem_range.1 i.2))),
  from cprank_max_sum (finset.range d).attach _ (λ i _, h_summands i),
-- the same bound after rewriting with `finset.card_attach`
have h_cprank_max_sum : cprank_max (finset.card (finset.range d) * prod ds)
  (finset.sum (finset.attach (finset.range d))
  (λ (i : {x // x ∈ finset.range d}), unit_vec d (i.val)⊗slice x (i.val) (mem_range.1 i.2))),
  by rwa [finset.card_attach] at this,
begin
  rw [←sum_unit_vec_mul_slice x],
  rw [h_dds_prod],
  exact h_cprank_max_sum,
end

/-- The CP rank of a holor `x`: the smallest N such that
  `x` can be written as the sum of N holors of rank at most 1.
  Defined with `nat.find`, using `cprank_max_upper_bound` as the witness that some bound
  exists and classical decidability of the predicate, hence `noncomputable`. -/
noncomputable def cprank [ring α] (x : holor α ds) : nat :=
@nat.find (λ n, cprank_max n x) (classical.dec_pred _) ⟨ds.prod, cprank_max_upper_bound x⟩

/-- The CP rank of a holor is at most the product of its dimensions. -/
lemma cprank_upper_bound [ring α] :
  Π {ds}, ∀ x : holor α ds, cprank x ≤ ds.prod :=
λ ds (x : holor α ds),
  by letI := classical.dec_pred (λ (n : ℕ), cprank_max n x);
  exact nat.find_min'
    ⟨ds.prod, show (λ n, cprank_max n x) ds.prod, from cprank_max_upper_bound x⟩
    (cprank_max_upper_bound x)

end holor