{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RelationalAI Simple Start Snowflake Notebook\n",
    "\n",
    "## Overview\n",
    "This notebook provides a minimalistic starting point to help you get up and running with RelationalAI in Snowflake Notebooks.\n",
    "\n",
    "To see examples that showcase RelationalAI's analytics capabilities, check out the other notebooks on the [docs site](https://relational.ai/docs/develop/example-notebooks/).\n",
    "\n",
    "## What you will learn\n",
    "\n",
    "- How to get RelationalAI's Python library running in your Snowflake account\n",
    "- How to run a simple graph algorithm over your data\n",
    "\n",
    "***\n",
    "\n",
    "## Let's get started!\n",
    "\n",
    "Complete the following steps before running the rest of this notebook:\n",
    "\n",
    "1. Choose *Notebook settings* from the vertical ellipsis (⋮) menu in the top-right corner of the Snowflake Notebooks window, then switch to the *External access* tab and turn on the `S3_RAI_INTERNAL_BUCKET_EGRESS_INTEGRATION` toggle. This allows your notebook to access data from the RelationalAI native app.\n",
    "2. Select `networkx` and `matplotlib` from the *Packages* dropdown in the top-right corner of the Snowflake Notebooks window. These packages are not dependencies of `relationalai`; this notebook uses them for graph visualization.\n",
    "3. Upload the RelationalAI Python library, packaged as a ZIP file, to your notebook's file system. You can download the ZIP file from the [RelationalAI website](https://relational.ai/relationalai.zip).\n",
    "4. Run the code cell below to make the packages in the ZIP file visible to the Python interpreter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
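    "import sys\n",
    "\n",
    "# Add the uploaded ZIP archive to the module search path so Python can import packages from it.\n",
    "sys.path.append(\"./relationalai.zip\")"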
   ]
  },
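  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above works because Python can import packages directly from a ZIP archive that appears on `sys.path` (the standard library's `zipimport` mechanism). The following is a minimal, self-contained sketch of that mechanism; the package name `demo_pkg` and its contents are hypothetical and unrelated to RelationalAI:\n",
    "\n",
    "```python\n",
    "import importlib\n",
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "import zipfile\n",
    "\n",
    "# Build a throwaway ZIP archive containing one tiny, hypothetical package.\n",
    "tmp_dir = tempfile.mkdtemp()\n",
    "zip_path = os.path.join(tmp_dir, 'demo_pkg.zip')\n",
    "with zipfile.ZipFile(zip_path, 'w') as zf:\n",
    "    zf.writestr('demo_pkg/__init__.py', 'ANSWER = 42')\n",
    "\n",
    "# Appending the archive to sys.path lets the built-in zipimport finder search it.\n",
    "sys.path.append(zip_path)\n",
    "demo_pkg = importlib.import_module('demo_pkg')\n",
    "\n",
    "print(demo_pkg.ANSWER)  # 42\n",
    "print(demo_pkg.__file__)  # a path that points inside the ZIP archive\n",
    "```"
   ]
  },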
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we're ready to import the RelationalAI Python library and start using it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import relationalai as rai\n",
    "from relationalai.std import aggregates\n",
    "from relationalai.std.graphs import Graph\n",
    "from relationalai.std import alias\n",
    "\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Setup\n",
    "\n",
    "Run the code in this section to create a small table in the schema `RAI_DEMO.SIMPLE_START` in your Snowflake account."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "provider = rai.Provider()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "provider.sql(\"\"\"\n",
    "begin\n",
    "    create database if not exists RAI_DEMO;\n",
    "    create schema if not exists RAI_DEMO.SIMPLE_START;\n",
    "\n",
    "    create or replace table RAI_DEMO.SIMPLE_START.CONNECTIONS (\n",
    "        station_1 int,\n",
    "        station_2 int\n",
    "    );\n",
    "\n",
    "    insert into RAI_DEMO.SIMPLE_START.CONNECTIONS (station_1, station_2) values\n",
    "    (1, 2),\n",
    "    (1, 3),\n",
    "    (3, 4),\n",
    "    (1, 4),\n",
    "    (4, 5),\n",
    "    (5, 7),\n",
    "    (6, 7),\n",
    "    (6, 8),\n",
    "    (7, 8);\n",
    "end;\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the data is in your Snowflake account, let's set up a data stream to make this table accessible to your RelationalAI Python program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "provider.create_streams(\n",
    "    [\"RAI_DEMO.SIMPLE_START.CONNECTIONS\"],\n",
    "    \"SimpleStart\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Model in RelationalAI\n",
    "Let's define our model object. **Models** represent collections of objects. **Objects**, like Python objects, have **types** and **properties**, which we will define in a bit."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note: If an engine doesn't already exist, invoking the `Model` constructor creates one for you, so this cell may take a few minutes to run.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = rai.Model(\"SimpleStart\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Referencing Snowflake Data\n",
    "Because RelationalAI is tightly integrated with Snowflake, we can refer to the data we are streaming to RelationalAI simply by naming its source table in the form `<my_database>.<my_schema>.<my_table>`.\n",
    "\n",
    "The data for this example consists of a single table called `CONNECTIONS` whose columns are called `station_1` and `station_2`. These station values represent IDs of power stations, and a row in the table represents a connection (via transmission lines and substations) between two power stations.\n",
    "\n",
    "Accordingly, we will introduce two *types* that represent the two kinds of objects in our model: `Station` and `Connection`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "Station = model.Type(\"Station\")\n",
    "\n",
    "Connection = model.Type(\n",
    "    \"Connection\",\n",
    "    source=\"RAI_DEMO.SIMPLE_START.CONNECTIONS\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple Queries\n",
    "\n",
    "We can run a query to count the number of connections as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>result</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<relationalai.dsl.Context at 0x14ace84d0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count number of rows in the connections table:\n",
    "with model.query() as select:\n",
    "    connection = Connection()\n",
    "    num_records = aggregates.count(connection)\n",
    "    res = select(num_records)\n",
    "\n",
    "res.results"
   ]
  },
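  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you think in SQL, the query above plays roughly the same role as `SELECT COUNT(*)` over the source table. The sketch below makes that comparison concrete using Python's built-in `sqlite3` module and a local, in-memory copy of the same nine rows; it does not touch Snowflake or RelationalAI, and the table name `connections` is only a local stand-in. Since this table has no duplicate rows, the two counts agree:\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "# Local, in-memory stand-in for RAI_DEMO.SIMPLE_START.CONNECTIONS.\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.execute('CREATE TABLE connections (station_1 INTEGER, station_2 INTEGER)')\n",
    "conn.executemany(\n",
    "    'INSERT INTO connections (station_1, station_2) VALUES (?, ?)',\n",
    "    [(1, 2), (1, 3), (3, 4), (1, 4), (4, 5), (5, 7), (6, 7), (6, 8), (7, 8)],\n",
    ")\n",
    "\n",
    "# Count the rows, analogous to aggregates.count(connection) in the RelationalAI query.\n",
    "(num_records,) = conn.execute('SELECT COUNT(*) FROM connections').fetchone()\n",
    "print(num_records)  # 9\n",
    "conn.close()\n",
    "```"
   ]
  },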
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
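    "We haven't yet said what a `Station` is. We can do that with a **rule**: for every connection, the rule below adds a `Station` object for each of its two station IDs. The same rule also defines the `is_connected` property, recording each connection in both directions:"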
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model.rule():\n",
    "    connection = Connection()\n",
    "    station_1 = Station.add(id=connection.station_1)\n",
    "    station_2 = Station.add(id=connection.station_2)\n",
    "    station_1.is_connected.extend([station_2])\n",
    "    station_2.is_connected.extend([station_1])"
   ]
  },
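  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is a plain-Python analogy of what the rule above produces: every endpoint of every connection becomes a station (judging from the query output below, which lists each ID once, repeated IDs map to a single station), and `is_connected` is recorded in both directions, so the relation is symmetric. This minimal sketch uses a set and a dictionary of sets; it is only an analogy and says nothing about how RelationalAI actually evaluates rules. The `CONNECTIONS` list copies the rows inserted earlier:\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Rows copied from the RAI_DEMO.SIMPLE_START.CONNECTIONS insert statement.\n",
    "CONNECTIONS = [\n",
    "    (1, 2), (1, 3), (3, 4), (1, 4), (4, 5),\n",
    "    (5, 7), (6, 7), (6, 8), (7, 8),\n",
    "]\n",
    "\n",
    "stations = set()                 # one entry per distinct station ID\n",
    "is_connected = defaultdict(set)  # station ID -> IDs of adjacent stations\n",
    "\n",
    "for station_1, station_2 in CONNECTIONS:\n",
    "    stations.update([station_1, station_2])\n",
    "    # Record the connection in both directions, as the rule does.\n",
    "    is_connected[station_1].add(station_2)\n",
    "    is_connected[station_2].add(station_1)\n",
    "\n",
    "print(sorted(stations))         # [1, 2, 3, 4, 5, 6, 7, 8]\n",
    "print(sorted(is_connected[4]))  # [1, 3, 5]\n",
    "```"
   ]
  },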
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can ask for a list of all the station IDs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<relationalai.dsl.Context at 0x14ac92390>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with model.query() as select:\n",
    "    station = Station()\n",
    "    res = select(station.id)\n",
    "\n",
    "res.results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The power stations and connections between them form a graph. We can model a graph in RelationalAI by wrapping the model in a `Graph` object and associating data with its `Node` and `Edge` properties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
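    "# Create an undirected graph over the model.\n",
    "graph = Graph(model, undirected=True)\n",
    "# Each Station becomes a node, labeled with its ID.\n",
    "graph.Node.extend(Station, label=Station.id)\n",
    "# Each is_connected pair of stations becomes an edge.\n",
    "graph.Edge.extend(Station.is_connected)"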
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
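    "def visualize_graph(graph_data):\n",
    "    # The RelationalAI graph is undirected, so use an undirected NetworkX graph (no arrowheads).\n",
    "    G = nx.Graph()\n",
    "\n",
    "    # Copy nodes and edges, along with their properties (such as the label), from the fetched data.\n",
    "    for node, props in graph_data['nodes'].items():\n",
    "        G.add_node(node, **props)\n",
    "\n",
    "    for (a, b), props in graph_data['edges'].items():\n",
    "        G.add_edge(a, b, **props)\n",
    "\n",
    "    # This network is planar, so a planar layout draws it without crossing edges.\n",
    "    pos = nx.planar_layout(G)\n",
    "    plt.figure(figsize=(4, 2))\n",
    "    labels = nx.get_node_attributes(G, 'label')\n",
    "    nx.draw(G, pos, labels=labels, with_labels=True, font_color=\"white\", edge_color=\"gray\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "# Fetch the graph's nodes and edges from RelationalAI, then draw them.\n",
    "graph_data = graph.fetch()\n",
    "visualize_graph(graph_data)"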
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see from the figure that Stations 4, 5, and 7 are especially critical to this network because they form the only path between its two larger clusters of stations. It's important to be able to quantify and compute this kind of information, because it would not be as visually apparent in a larger, real-world network.\n",
    "\n",
    "We can do that using a graph analytics metric called **betweenness centrality**. This metric and others are available under the `graph.compute` namespace:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>id</th>\n",
       "      <th>betweenness_centrality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<relationalai.dsl.Context at 0x14ad5a290>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with model.query() as select:\n",
    "    station = Station()\n",
    "    centrality = graph.compute.betweenness_centrality(station)\n",
    "    res = select(station.id, alias(centrality, \"betweenness_centrality\"))\n",
    "\n",
    "res.results"
   ]
  },
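  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the betweenness centrality of a station $v$ is $C_B(v) = \\sum_{s \\neq v \\neq t} \\sigma_{st}(v) / \\sigma_{st}$, where $\\sigma_{st}$ is the number of shortest paths between stations $s$ and $t$, and $\\sigma_{st}(v)$ is the number of those paths that pass through $v$. In this small network every pair of stations has a single shortest path, so each score is simply a count of station pairs. For example, every shortest path from one of stations 1, 2, 3 to one of stations 5, 6, 7, 8 passes through station 4, giving $3 \\times 4 = 12$, and every shortest path from any of stations 1 through 5 to station 6 or 8 passes through station 7, giving $5 \\times 2 = 10$. The scores in the table above therefore appear to count each unordered pair once, without normalization.\n",
    "\n",
    "As an independent cross-check, the minimal sketch below recomputes the scores locally with networkx (already selected as a package for this notebook) from the same edge list. It assumes `normalized=False` is the setting that matches the unnormalized values above, and it also lists the articulation points, that is, the stations whose removal would disconnect the network:\n",
    "\n",
    "```python\n",
    "import networkx as nx\n",
    "\n",
    "# Same edges as the CONNECTIONS table, as an undirected graph.\n",
    "edges = [(1, 2), (1, 3), (3, 4), (1, 4), (4, 5), (5, 7), (6, 7), (6, 8), (7, 8)]\n",
    "G = nx.Graph(edges)\n",
    "\n",
    "# Unnormalized betweenness: each unordered pair of stations is counted once.\n",
    "scores = nx.betweenness_centrality(G, normalized=False)\n",
    "for station in sorted(scores):\n",
    "    print(station, scores[station])\n",
    "\n",
    "# Articulation points are stations whose removal disconnects the network.\n",
    "print(sorted(nx.articulation_points(G)))  # [1, 4, 5, 7]\n",
    "```\n",
    "\n",
    "Station 1 also appears as an articulation point because station 2 reaches the rest of the network only through it, which is consistent with its nonzero score of 6.0."
   ]
  },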
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As predicted, Stations 4, 5, and 7 have the highest betweenness centrality scores!\n",
    "\n",
    "### Conclusion\n",
    "\n",
    "The functionality demonstrated in this notebook barely scratches the surface of what's possible with RelationalAI. The [Example Notebooks page on the docs site](https://relational.ai/docs/develop/example-notebooks/) contains a variety of example notebooks, each of which explores a scenario and a set of analytics capabilities in greater depth."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brought to you by [RelationalAI](https://relational.ai/) & Snowflake Native Applications!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
