Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
(defn my-count
  "Number of direct elements of a tensor node: 0 for a number,
  (count obj) for a vector, and nil for any other value."
  [obj]
  (if (number? obj)
    0
    (when (vector? obj)
      (count obj))))
(defn correct-tensor?
  "Returns true when t is a well-formed (rectangular) tensor: a number, or a
  vector whose elements are all well-formed tensors sharing one identical
  shape. The empty vector is accepted (it has no elements to disagree).
  Any other value returns false."
  [t]
  (letfn [(shape-or-nil [x]
            ;; Shape of x as a vector of dimension sizes, or nil when x is
            ;; not a well-formed tensor. Comparing full child shapes (not just
            ;; their counts) catches raggedness at any depth.
            (cond
              (number? x) []
              (vector? x) (let [shapes (map shape-or-nil x)]
                            (when (and (every? some? shapes)
                                       ;; <= 1 also covers the empty vector,
                                       ;; where (apply = ()) would throw.
                                       (<= (count (distinct shapes)) 1))
                              (into [(count x)] (first shapes))))
              :else nil))]
    (some? (shape-or-nil t))))
(defn correct-obj?
  "True when t is a number, or a vector whose elements are all (recursively)
  correct objects. Sibling elements are not required to share a shape."
  [t]
  (or (number? t)
      (and (vector? t)
           (every? correct-obj? t))))
(defn shape
  "Returns the shape of tensor obj as a vector of dimension sizes, outermost
  first: [] for a number, [n] for a vector of n numbers, and so on.

  Only the first element of each level is inspected. For that reason the
  1-arity form asserts up front that obj is a well-formed tensor. The
  2-arity form appends the sizes to cur without validation and returns nil
  for values that are neither numbers nor vectors. An empty vector
  contributes a size of 0 and no further dimensions."
  ([obj]
   ;; :pre must be a vector of conditions; a bare list would assert each
   ;; symbol of the call separately and never run the check.
   {:pre [(correct-tensor? obj)]}
   (shape obj []))
  ([obj cur]
   (cond
     (number? obj) cur
     ;; An empty vector has no first element to descend into.
     (and (vector? obj) (empty? obj)) (conj cur 0)
     (vector? obj) (shape (first obj) (conj cur (count obj))))))
(defn dim
  "Rank of obj: the number of entries in its shape."
  [obj]
  (-> obj shape count))
(defn correct-vector?
  "True when v has rank 1 according to dim."
  [v]
  (= 1 (dim v)))
(defn correct-matrix?
  "True when m has rank 2 according to dim."
  [m]
  (= 2 (dim m)))
(defn cond-operation
  "Wraps op in a variadic fn (one or more arguments) that first asserts
  (condition x) for every argument x, then applies op to all of them."
  [op condition]
  (fn [a & args]
    {:pre [(every? condition (cons a args))]}
    (apply op a args)))
(defn can-coordinate-wise
  "True when an element-wise operation can be applied to args. That holds
  when either all args are numbers, or all are vectors of equal length
  whose elements, taken position by position, satisfy the same condition
  recursively. Mixed numbers/vectors and non-tensor values yield false
  rather than an exception. With no arguments, returns true."
  [& args]
  (cond
    (every? number? args) true
    ;; Only recurse when every argument is a vector: mapping over a number
    ;; would throw, e.g. for (can-coordinate-wise 1 []).
    (every? vector? args) (and (apply = (map count args))
                               (every? true? (apply map can-coordinate-wise args)))
    :else false))
(defn postfix?
  "True when sequence a is a suffix (postfix) of sequence b, e.g.
  (postfix? [2 3] [4 2 3]) => true. Used to check that one shape can be
  broadcast to another. Unlike a subvec-based check, b may be any
  sequential collection (for example the seq produced by rest), not only
  a vector."
  [a b]
  (let [na (count a)
        nb (count b)]
    (and (<= na nb)
         ;; Compare a against the last na elements of b.
         (= (seq a) (seq (drop (- nb na) b))))))
(defn cw-operation
  "Lifts a numeric function f to an element-wise operation on tensors.
  The returned fn takes one or more arguments, and its precondition asserts
  (via can-coordinate-wise) that they can be combined coordinate-wise.
  Vectors are combined position by position with mapv, and numbers are
  passed to f directly. Any other value yields nil."
  [f]
  (let [step (fn step [a & args]
               (cond
                 (vector? a) (apply mapv step a args)
                 (number? a) (apply f a args)))]
    ;; Validate once at the top level: if the whole argument list is
    ;; coordinate-wise compatible, every nested position is as well, so
    ;; re-checking on each recursive call was redundant work.
    (fn [a & args]
      {:pre [(apply can-coordinate-wise a args)]}
      (apply step a args))))
(defn transpose
  "Transposes matrix m: row i, column j moves to row j, column i.
  Asserts m is rank 2. Rows are zipped with map, so the result has as many
  rows as the shortest row of m."
  [m]
  {:pre [(correct-matrix? m)]}
  (vec (apply map vector m)))
;; Forward declaration: matrix-by-vector (below) calls scalar, which is
;; defined later in this file.
(declare scalar)
(defn matrix-by-vector
  "Multiplies matrix m by vector v. Each entry of the result is the scalar
  product of the corresponding row of m with v. Asserts that m is rank 2,
  that v is rank 1, and that m's column count equals v's length."
  [m v]
  {:pre [(= 2 (dim m))
         (= 1 (dim v))
         (= (first (shape v)) (last (shape m)))]}
  (mapv #(scalar % v) m))
(defn matrix-by-matrix
  "Matrix product m1 x m2. Each column of m2 is multiplied by m1, and the
  resulting columns are transposed back into rows. Asserts that both
  arguments are rank 2 and that m1's column count equals m2's row count."
  [m1 m2]
  {:pre [(= 2 (dim m1))
         (= 2 (dim m2))
         (= (first (shape m2)) (last (shape m1)))]}
  (->> (transpose m2)
       (mapv #(matrix-by-vector m1 %))
       transpose))
(defn make-multiarg
  "Lifts binary op to a fn of one or more arguments that folds op
  left-to-right over them. A single argument is returned unchanged."
  [op]
  #(reduce op %1 %&))
(defn broadcast
  "Broadcasts tensor t1 to res-shape by repeating it along the leading
  dimensions it lacks. Each added level is a vector.

  The precondition asserts that the shape of t1 is a suffix of res-shape;
  previously the check never ran and incompatible shapes recursed without
  end. When the shapes are already equal, t1 is returned unchanged."
  [t1 res-shape]
  ;; :pre must be a vector of conditions. res-shape is coerced to a vector
  ;; because postfix? may rely on subvec.
  {:pre [(postfix? (shape t1) (vec res-shape))]}
  (let [target (vec res-shape)
        ;; Number of leading dimensions that must be created by repetition.
        extra (- (count target) (count (shape t1)))]
    ;; Wrap innermost-first, so walk the missing dimensions in reverse.
    (reduce (fn [inner n] (vec (repeat n inner)))
            t1
            (rseq (subvec target 0 extra)))))
(defn broadcast-all
  "Broadcasts t and every element of args to the longest (highest-rank)
  shape among them. On ties, max-key picks the last such shape. Returns a
  lazy seq of the broadcast tensors in argument order. Asserts that every
  argument is a well-formed tensor."
  [t & args]
  ;; :pre must be a vector; the former bare list never ran the check.
  {:pre [(every? correct-tensor? (cons t args))]}
  (let [tensors (cons t args)
        max-shape (apply max-key count (map shape tensors))]
    (map #(broadcast % max-shape) tensors)))
(defn b-operation
  "Lifts numeric f to a broadcasting element-wise operation. All arguments
  are first broadcast to a common shape with broadcast-all, then combined
  with (cw-operation f). Asserts every argument is a well-formed tensor."
  [f]
  (let [coordinate-wise (cw-operation f)]
    (fn [t & args]
      {:pre [(every? correct-tensor? (cons t args))]}
      (->> (apply broadcast-all t args)
           (apply coordinate-wise)))))
;; Element-wise operations on coordinate-wise compatible tensors, plus
;; variants restricted to rank-1 arguments.
(def s+ "Element-wise addition of tensors." (cw-operation +))
(def v+ "Element-wise addition; asserts all arguments are rank 1."
  (cond-operation s+ correct-vector?))
(def s* "Element-wise multiplication of tensors." (cw-operation *))
(def v* "Element-wise multiplication; asserts all arguments are rank 1."
  (cond-operation s* correct-vector?))
(def s- "Element-wise subtraction of tensors." (cw-operation -))
(def v- "Element-wise subtraction; asserts all arguments are rank 1."
  (cond-operation s- correct-vector?))
(defn scalar
  "Scalar product: the sum of the element-wise product of v and all args.
  Asserts every argument is rank 1."
  [v & args]
  {:pre [(every? correct-vector? (cons v args))]}
  (->> (apply v* v args)
       (reduce +)))
(def vect
  "Cross product of 3-element vectors, folded left-to-right over all
  arguments. Each step asserts that both operands are rank 1 with length 3."
  (make-multiarg
    (fn [v1 v2]
      {:pre [(correct-vector? v1)
             (correct-vector? v2)
             (= 3 (first (shape v1)))
             (= 3 (first (shape v2)))]}
      (let [[x1 y1 z1] v1
            [x2 y2 z2] v2]
        [(- (* y1 z2) (* z1 y2))
         (- (* z1 x2) (* x1 z2))
         (- (* x1 y2) (* y1 x2))]))))
(defn v*s
  "Multiplies every entry of vector v by the product of the numeric args
  (by 1 when no args are given). Asserts v is rank 1 and all args are
  numbers."
  [v & args]
  {:pre [(correct-vector? v)
         (every? number? args)]}
  (let [factor (apply * args)]
    (mapv #(* factor %) v)))
;; Element-wise operations restricted to rank-2 arguments.
(def m+ "Element-wise addition; asserts all arguments are rank 2."
  (cond-operation s+ correct-matrix?))
(def m* "Element-wise multiplication; asserts all arguments are rank 2."
  (cond-operation s* correct-matrix?))
(def m- "Element-wise subtraction; asserts all arguments are rank 2."
  (cond-operation s- correct-matrix?))
(defn m*s
  "Multiplies every entry of matrix m by the product of the numeric args
  (by 1 when no args are given). Asserts that m is rank 2 and, consistent
  with v*s, that every arg is a number."
  [m & args]
  {:pre [(correct-matrix? m)
         (every? number? args)]}
  ;; The factor does not depend on the row, so compute it once.
  (let [factor (apply * args)]
    (mapv #(v*s % factor) m)))
(def m*v "Matrix-by-vector product, folded over arguments with make-multiarg."
  (make-multiarg matrix-by-vector))
(def m*m "Matrix product, folded left-to-right over all arguments."
  (make-multiarg matrix-by-matrix))
;; Broadcasting element-wise operations: arguments are broadcast to a
;; common shape before being combined.
(def b+ "Broadcasting element-wise addition." (b-operation +))
(def b* "Broadcasting element-wise multiplication." (b-operation *))
(def b- "Broadcasting element-wise subtraction." (b-operation -))
- ;(println (b+ [[1.1 2.1] [1.2 3.4]]))
- ;(println (b+ 1))
- ;(println (v+ [1.1 2.1] [1.2 3.4]))
- ;(println (v*s [1.1 2.1]))
- ;(println (m*v [[1 2] [3 4] [5 6]] [10 20]))
- ;(println (vect [1 2 3] [4 5 6]))
- ;(println (v+ [1 2 3] [4 5 6]))
- ;(println (s* [1 2] [3 4] [3 4]))
- ;(println (transpose [[1 2] [3 4]]))
- ;(println (transpose [[1 2 3 4] [5 6 7 8] [0 1 0 1] [1 0 1 0]]))
- ;(println (correct-tensor? [[1 2] [3 4]]))
- ;(println (m*m [[1 2] [3 4]] [[1 2] [3 4]]))
- ;(println (shape [[1 2 3] [3 4 3]]))
- ;(println (broadcast-all 1 [[[10 20 30] [40 50 60]] [[10 20 30] [40 50 60]]]))
- ;(println (b+ 1 [[10 20 30] [40 50 60]] [100 200 300]))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement