Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This notebook demonstrates a model whose `evaluate` method dispatches to code that is not unit-aware (such as C code), so its inputs and outputs need to be in fixed units."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "import astropy.units as u\n",
- "from astropy.modeling import models as m\n",
- "from astropy.modeling import Model\n",
- "from astropy.utils.decorators import wraps"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
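- "Define a decorator that wraps a function's return value in an astropy `Quantity`, taking the unit from the function's return annotation or, failing that, from a `return_unit` argument passed to the decorator."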
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 86,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "class CreateReturnQuantity(object):\n",
- "    \"\"\"\n",
- "    Decorator that wraps a function's return value in a `~astropy.units.Quantity`.\n",
- "\n",
- "    The unit is taken from the decorated function's ``return`` annotation if\n",
- "    it has one, otherwise from the ``return_unit`` given to the decorator.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    return_unit : unit, optional\n",
- "        Fallback unit used when the decorated function has no return annotation.\n",
- "\n",
- "    Raises\n",
- "    ------\n",
- "    ValueError\n",
- "        When decorating a function that has neither a return annotation nor\n",
- "        a ``return_unit``.\n",
- "    \"\"\"\n",
- "    @classmethod\n",
- "    def as_decorator(cls, func=None, **kwargs):\n",
- "        \"\"\"\n",
- "        Allow both ``@create_return_quantity`` and\n",
- "        ``@create_return_quantity(return_unit=...)`` usage.\n",
- "\n",
- "        If ``func`` is given it is decorated immediately (with any ``kwargs``\n",
- "        applied); otherwise a configured decorator instance is returned.\n",
- "        \"\"\"\n",
- "        decorator = cls(**kwargs)\n",
- "        # Apply directly whenever a function was passed, even with kwargs;\n",
- "        # previously ``as_decorator(f, return_unit=...)`` returned the\n",
- "        # decorator itself instead of the wrapped function.\n",
- "        if func is not None:\n",
- "            return decorator(func)\n",
- "        return decorator\n",
- "\n",
- "    def __init__(self, return_unit=None):\n",
- "        self.return_unit = return_unit\n",
- "\n",
- "    def __call__(self, func):\n",
- "        # Every Python 3 function carries an ``__annotations__`` dict (possibly\n",
- "        # empty), so test for the 'return' key instead of using hasattr;\n",
- "        # otherwise unannotated functions raise KeyError and the\n",
- "        # ``return_unit`` fallback can never be reached.\n",
- "        annotations = getattr(func, '__annotations__', {})\n",
- "        if 'return' in annotations:\n",
- "            return_unit = annotations['return']\n",
- "        elif self.return_unit is not None:\n",
- "            return_unit = self.return_unit\n",
- "        else:\n",
- "            raise ValueError(\"The return unit must be specified either as a\"\n",
- "                             \" function annotation or as an argument to the decorator.\")\n",
- "\n",
- "        @wraps(func)\n",
- "        def wrapper(*args, **kwargs):\n",
- "            return u.Quantity(func(*args, **kwargs), unit=return_unit)\n",
- "\n",
- "        return wrapper\n",
- "\n",
- "create_return_quantity = CreateReturnQuantity.as_decorator"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A dummy function, with no knowledge of units, that `evaluate` will call."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 77,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def some_unitless_function(ham, eggs):\n",
- " \"\"\"\n",
- " This function does something in code that has no idea what a unit is.\n",
- " \n",
- " Parameters\n",
- " ----------\n",
- " ham : `float`\n",
- " ham in units of kg\n",
- " \n",
- " eggs : `float`\n",
- " eggs in units of m\n",
- " \n",
- " Returns\n",
- " -------\n",
- " Ni : `float`\n",
- " The result of ham combined with eggs. In units of J.\n",
- " \"\"\"\n",
- " return np.array(ham) + np.array(eggs)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
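- "A model whose `evaluate` converts its inputs to fixed units (kg and m) before calling the unit-unaware function; the decorator attaches the return unit (J) to the result:"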
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 91,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "class Spam(Model):\n",
- "    \"\"\"\n",
- "    Example model whose ``evaluate`` dispatches to a unit-unaware function.\n",
- "\n",
- "    Inputs are converted to fixed units (kg and m) and reduced to plain\n",
- "    numbers before the call; the result is returned as a Quantity in J.\n",
- "    \"\"\"\n",
- "    inputs = ('ham', 'eggs')\n",
- "    outputs = ('Ni',)\n",
- "\n",
- "    @property\n",
- "    def input_units(self):\n",
- "        # Units expected for ('ham', 'eggs'), in input order.\n",
- "        return u.kg, u.m\n",
- "\n",
- "    @classmethod\n",
- "    @create_return_quantity\n",
- "    def evaluate(cls, x, y) -> u.J:\n",
- "        # ``u.Quantity(value, unit)`` converts Quantity input to ``unit`` and\n",
- "        # treats plain numbers as already being in ``unit``, so non-Quantity\n",
- "        # input is supported too. ``.value`` strips the unit, because the\n",
- "        # external code must receive plain numbers, not Quantities.\n",
- "        x = u.Quantity(x, u.kg).value\n",
- "        y = u.Quantity(y, u.m).value\n",
- "        # Call the external, unit-unaware function.\n",
- "        return some_unitless_function(x, y)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 88,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "s = Spam()  # Spam declares no parameters, so no arguments are needed"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 89,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "data": {
- "text/latex": [
- "$0.11 \\; \\mathrm{J}$"
- ],
- "text/plain": [
- "<Quantity 0.11 J>"
- ]
- },
- "execution_count": 89,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "s(10 * u.g, 10 * u.cm)  # evaluate sees these as 0.01 kg and 0.1 m"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Environment (astropy-dev)",
- "language": "python",
- "name": "astropy-dev"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement