GitHub Repository: suyashi29/python-su
Path: blob/master/Generative AI for Intelligent Data Handling/Day 4 Understanding Recurrent Neural Networks (RNNs) and its example in Sequence Generation.ipynb
Kernel: Python 3 (ipykernel)

Recurrent Neural Networks (RNNs) are a type of neural network architecture that is particularly well-suited for tasks involving sequential data. Unlike feedforward neural networks, which process data in fixed-size chunks, RNNs can handle input sequences of arbitrary length.

Key features of RNNs:

  • Recurrent Connections: RNNs have recurrent connections that allow information to persist across different time steps in a sequence. This means that information from previous inputs is considered when processing the current input.

  • Shared Parameters: The same set of weights and biases is applied at each time step. This allows the network to reuse the same computation for every element of the sequence.

  • Time Dependency: RNNs are well-suited for tasks where the order or temporal dependency of data matters, such as time series prediction, language modeling, and speech recognition.


Applications of RNNs:

  • Language Modeling and Text Generation: RNNs can be used to model the probability distribution of sequences of words. This enables tasks like auto-completion, machine translation, and text generation.

  • Time Series Prediction: RNNs are effective for tasks like stock price prediction, weather forecasting, and any scenario where the current state depends on previous states.

  • Speech Recognition: RNNs can be used to convert spoken language into written text. This is crucial for applications like voice assistants (e.g., Siri, Alexa).

  • Handwriting Recognition: RNNs can recognize handwritten text, enabling applications like digit recognition and signature verification.

  • Image Captioning: RNNs can be combined with Convolutional Neural Networks (CNNs) to generate captions for images.

  • Video Analysis: RNNs can process sequences of images or video frames, making them useful for tasks like action recognition, video captioning, and video prediction.

  • Anomaly Detection: RNNs can be used to detect anomalies in sequences of data, making them valuable for tasks like fraud detection in finance or detecting defects in manufacturing.

  • Sentiment Analysis: RNNs can analyze sequences of text to determine the sentiment expressed.

Mathematical Implementation:

Terms:

  • x_t: Input at time step t

  • h_t: Hidden state at time step t

  • W_hx: Weight matrix for input-to-hidden connections

  • W_hh: Weight matrix for hidden-to-hidden connections

  • b_h: Bias term for the hidden layer

  • W_yh: Weight matrix for hidden-to-output connections

  • b_y: Bias term for the output layer

With these terms, the forward pass at each time step is:

  h_t = tanh(W_hx · x_t + W_hh · h_(t-1) + b_h)

  y_t = W_yh · h_t + b_y
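
To make the recurrence concrete, below is a minimal NumPy sketch of a single forward step. The sizes and random parameter values are arbitrary, for illustration only.

import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 4, 2

# Randomly initialized parameters (illustrative values)
W_hx = rng.standard_normal((hidden_size, input_size))
W_hh = rng.standard_normal((hidden_size, hidden_size))
W_yh = rng.standard_normal((output_size, hidden_size))
b_h = np.zeros((hidden_size, 1))
b_y = np.zeros((output_size, 1))

x_t = rng.standard_normal((input_size, 1))   # input at time step t
h_prev = np.zeros((hidden_size, 1))          # hidden state from step t-1

# One step of the recurrence defined above
h_t = np.tanh(W_hx @ x_t + W_hh @ h_prev + b_h)
y_t = W_yh @ h_t + b_y

print(h_t.shape, y_t.shape)  # (4, 1) (2, 1)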

Training:

During training, you would use backpropagation through time (BPTT) to compute gradients and update the weights and biases to minimize the loss function.

Prediction:

Once the network is trained, you can make predictions by passing a sequence of inputs through the network. This is a basic mathematical interpretation of a simple RNN. In practice, more sophisticated variants like LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are often used to address issues like vanishing gradients and better capture long-term dependencies.

Below is a basic implementation of a simple RNN using only the NumPy library. This code demonstrates how you can manually perform forward and backward passes through time.

import numpy as np

# Sigmoid activation function (defined for completeness)
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Hyperbolic tangent (tanh) activation function
def tanh(x):
    return np.tanh(x)

# Derivative of tanh, expressed in terms of the activated value h = tanh(x)
def tanh_derivative(h):
    return 1 - h ** 2

# Define the RNN class
class SimpleRNN:
    def __init__(self, input_size, hidden_size, output_size):
        # Initialize weights and biases
        self.W_hx = np.random.randn(hidden_size, input_size)
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.W_yh = np.random.randn(output_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((output_size, 1))

    def forward(self, x):
        # Initialize hidden state
        h = np.zeros((self.W_hx.shape[0], 1))
        # Lists to store intermediate values for backpropagation
        self.h_states = []
        self.x_inputs = []
        for t in range(len(x)):
            # Update hidden state: h_t = tanh(W_hx x_t + W_hh h_(t-1) + b_h)
            h = tanh(np.dot(self.W_hx, x[t]) + np.dot(self.W_hh, h) + self.b_h)
            self.h_states.append(h)
            self.x_inputs.append(x[t])
        # Calculate output from the final hidden state: y = W_yh h_T + b_y
        y = np.dot(self.W_yh, h) + self.b_y
        return y, h

    def backward(self, x, y_true, learning_rate):
        # Initialize gradients
        dW_hx = np.zeros_like(self.W_hx)
        dW_hh = np.zeros_like(self.W_hh)
        dW_yh = np.zeros_like(self.W_yh)
        db_h = np.zeros_like(self.b_h)
        db_y = np.zeros_like(self.b_y)

        # Gradient of the squared-error loss at the final output
        y_pred = np.dot(self.W_yh, self.h_states[-1]) + self.b_y
        dy = y_pred - y_true
        dW_yh += np.dot(dy, self.h_states[-1].T)
        db_y += dy
        dh_next = np.dot(self.W_yh.T, dy)

        # Backpropagation through time
        for t in reversed(range(len(x))):
            # Backprop through the tanh nonlinearity
            dh_raw = tanh_derivative(self.h_states[t]) * dh_next
            db_h += dh_raw
            dW_hx += np.dot(dh_raw, self.x_inputs[t].T)
            # Hidden state at t-1 (zeros at the first time step)
            h_prev = self.h_states[t - 1] if t > 0 else np.zeros_like(self.h_states[0])
            dW_hh += np.dot(dh_raw, h_prev.T)
            dh_next = np.dot(self.W_hh.T, dh_raw)

        # Clip gradients to avoid exploding gradients (optional)
        for gradient in [dW_hx, dW_hh, dW_yh, db_h, db_y]:
            np.clip(gradient, -5, 5, out=gradient)

        # Update weights and biases
        self.W_hx -= learning_rate * dW_hx
        self.W_hh -= learning_rate * dW_hh
        self.W_yh -= learning_rate * dW_yh
        self.b_h -= learning_rate * db_h
        self.b_y -= learning_rate * db_y

Explanation:

  • The code defines a basic RNN class (SimpleRNN) with methods for forward pass (forward) and backward pass (backward).

  • The activation functions (sigmoid and tanh) are defined, along with the derivative of tanh needed for backpropagation.

  • The forward method performs a forward pass through the RNN, storing intermediate values for backpropagation.

  • The backward method computes gradients and updates the weights and biases using backpropagation through time (BPTT).
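
A minimal usage sketch of this class follows; the sequence length, layer sizes, target value, and learning rate are arbitrary choices for illustration.

# Hypothetical usage of the SimpleRNN class defined above
np.random.seed(42)

rnn = SimpleRNN(input_size=1, hidden_size=8, output_size=1)

# A toy sequence of 5 one-dimensional inputs, each shaped (input_size, 1)
x = [np.array([[0.1 * t]]) for t in range(5)]
y_true = np.array([[0.5]])  # arbitrary scalar target for the final output

for epoch in range(100):
    y_pred, _ = rnn.forward(x)
    rnn.backward(x, y_true, learning_rate=0.01)

print("Prediction after training:", rnn.forward(x)[0])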

Let us use the Keras library to create and train a basic RNN for a toy example of sequence prediction. Each training sample is a window of five consecutive integers (for example, 0, 1, 2, 3, 4), and the target is the next number in the sequence.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense

# Generate some sample data: windows of 5 consecutive integers and their next value
X = np.array([[i + j for j in range(5)] for i in range(100)])
y = np.array([i + 5 for i in range(100)])
y
array([ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104])
X
array([[ 0, 1, 2, 3, 4], [ 1, 2, 3, 4, 5], [ 2, 3, 4, 5, 6], [ 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8], [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10], [ 7, 8, 9, 10, 11], [ 8, 9, 10, 11, 12], [ 9, 10, 11, 12, 13], [ 10, 11, 12, 13, 14], [ 11, 12, 13, 14, 15], [ 12, 13, 14, 15, 16], [ 13, 14, 15, 16, 17], [ 14, 15, 16, 17, 18], [ 15, 16, 17, 18, 19], [ 16, 17, 18, 19, 20], [ 17, 18, 19, 20, 21], [ 18, 19, 20, 21, 22], [ 19, 20, 21, 22, 23], [ 20, 21, 22, 23, 24], [ 21, 22, 23, 24, 25], [ 22, 23, 24, 25, 26], [ 23, 24, 25, 26, 27], [ 24, 25, 26, 27, 28], [ 25, 26, 27, 28, 29], [ 26, 27, 28, 29, 30], [ 27, 28, 29, 30, 31], [ 28, 29, 30, 31, 32], [ 29, 30, 31, 32, 33], [ 30, 31, 32, 33, 34], [ 31, 32, 33, 34, 35], [ 32, 33, 34, 35, 36], [ 33, 34, 35, 36, 37], [ 34, 35, 36, 37, 38], [ 35, 36, 37, 38, 39], [ 36, 37, 38, 39, 40], [ 37, 38, 39, 40, 41], [ 38, 39, 40, 41, 42], [ 39, 40, 41, 42, 43], [ 40, 41, 42, 43, 44], [ 41, 42, 43, 44, 45], [ 42, 43, 44, 45, 46], [ 43, 44, 45, 46, 47], [ 44, 45, 46, 47, 48], [ 45, 46, 47, 48, 49], [ 46, 47, 48, 49, 50], [ 47, 48, 49, 50, 51], [ 48, 49, 50, 51, 52], [ 49, 50, 51, 52, 53], [ 50, 51, 52, 53, 54], [ 51, 52, 53, 54, 55], [ 52, 53, 54, 55, 56], [ 53, 54, 55, 56, 57], [ 54, 55, 56, 57, 58], [ 55, 56, 57, 58, 59], [ 56, 57, 58, 59, 60], [ 57, 58, 59, 60, 61], [ 58, 59, 60, 61, 62], [ 59, 60, 61, 62, 63], [ 60, 61, 62, 63, 64], [ 61, 62, 63, 64, 65], [ 62, 63, 64, 65, 66], [ 63, 64, 65, 66, 67], [ 64, 65, 66, 67, 68], [ 65, 66, 67, 68, 69], [ 66, 67, 68, 69, 70], [ 67, 68, 69, 70, 71], [ 68, 69, 70, 71, 72], [ 69, 70, 71, 72, 73], [ 70, 71, 72, 73, 74], [ 71, 72, 73, 74, 75], [ 72, 73, 74, 75, 76], [ 73, 74, 75, 76, 77], [ 74, 75, 76, 77, 78], [ 75, 76, 77, 78, 79], [ 76, 77, 78, 79, 80], [ 77, 78, 79, 80, 81], [ 78, 79, 80, 81, 82], [ 79, 80, 81, 82, 83], [ 80, 81, 82, 83, 84], [ 81, 82, 83, 84, 85], [ 82, 83, 84, 85, 86], [ 83, 84, 85, 86, 87], [ 84, 85, 86, 87, 88], [ 85, 86, 87, 88, 89], [ 86, 87, 88, 89, 90], [ 87, 88, 89, 90, 91], [ 88, 89, 90, 91, 92], [ 89, 90, 91, 92, 93], [ 90, 91, 92, 93, 94], [ 91, 92, 93, 94, 95], [ 92, 93, 94, 95, 96], [ 93, 94, 95, 96, 97], [ 94, 95, 96, 97, 98], [ 95, 96, 97, 98, 99], [ 96, 97, 98, 99, 100], [ 97, 98, 99, 100, 101], [ 98, 99, 100, 101, 102], [ 99, 100, 101, 102, 103]])
# Reshape the data for RNN input (samples, time steps, features)
X = X.reshape((X.shape[0], X.shape[1], 1))
# Define the RNN model
model = Sequential([
    SimpleRNN(units=32, input_shape=(X.shape[1], X.shape[2]), activation='relu'),
    Dense(1)
])
# Compile the model
model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model
model.fit(X, y, epochs=10, batch_size=8)
Epoch 1/10
13/13 [==============================] - 2s 6ms/step - loss: 1814.8225
Epoch 2/10
13/13 [==============================] - 0s 8ms/step - loss: 211.9403
Epoch 3/10
13/13 [==============================] - 0s 7ms/step - loss: 59.4532
Epoch 4/10
13/13 [==============================] - 0s 7ms/step - loss: 10.0549
Epoch 5/10
13/13 [==============================] - 0s 6ms/step - loss: 4.6951
Epoch 6/10
13/13 [==============================] - 0s 8ms/step - loss: 3.3915
Epoch 7/10
13/13 [==============================] - 0s 6ms/step - loss: 2.2211
Epoch 8/10
13/13 [==============================] - 0s 10ms/step - loss: 2.0450
Epoch 9/10
13/13 [==============================] - 0s 11ms/step - loss: 1.8847
Epoch 10/10
13/13 [==============================] - 0s 12ms/step - loss: 1.9048
<keras.callbacks.History at 0x1e6f01ffb20>
# Test the model
test_input = np.array([[i + j for j in range(5)] for i in range(100, 115)])
test_input = test_input.reshape((test_input.shape[0], test_input.shape[1], 1))
predicted_output = model.predict(test_input)
1/1 [==============================] - 0s 349ms/step
test_input
array([[[100], [101], [102], [103], [104]], [[101], [102], [103], [104], [105]], [[102], [103], [104], [105], [106]], [[103], [104], [105], [106], [107]], [[104], [105], [106], [107], [108]], [[105], [106], [107], [108], [109]], [[106], [107], [108], [109], [110]], [[107], [108], [109], [110], [111]], [[108], [109], [110], [111], [112]], [[109], [110], [111], [112], [113]], [[110], [111], [112], [113], [114]], [[111], [112], [113], [114], [115]], [[112], [113], [114], [115], [116]], [[113], [114], [115], [116], [117]], [[114], [115], [116], [117], [118]]])
# Print the predicted output
print("Predicted Output:")
print(predicted_output.flatten())
Predicted Output: [106.44789 107.489365 108.530846 109.57231 110.61378 111.65525 112.696724 113.7382 114.77966 115.82114 116.8626 117.90408 118.945564 119.987015 121.028496]
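
For reference, the true next value for each test window follows the same i + 5 pattern used to build the training targets, so we can compare directly (a quick check, not in the original cells, built from the same data-generation pattern):

# True next number for each test window (windows start at i = 100, ..., 114)
expected = np.array([i + 5 for i in range(100, 115)])
print("Expected Output:")
print(expected)  # [105 106 107 ... 117 118 119]

The predictions (roughly 106.4 through 121.0) track the true values 105 through 119 but consistently overshoot, which is unsurprising for a lightly trained model extrapolating beyond the range it saw during training (windows starting at 0 through 99).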

Let's create another simple RNN using Keras with some sample data. In this example, each sample is a sequence of ten random numbers, and the network is trained to predict the sum of the sequence.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense
# Generate some sample data
np.random.seed(0)
sequence_length = 10
X = np.random.rand(100, sequence_length)
y = np.sum(X, axis=1)

# Reshape the data for RNN input (samples, time steps, features)
X = X.reshape((X.shape[0], X.shape[1], 1))
X
array([[[5.48813504e-01],
        [7.15189366e-01],
        [6.02763376e-01],
        [5.44883183e-01],
        [4.23654799e-01],
        [6.45894113e-01],
        [4.37587211e-01],
        [8.91773001e-01],
        [9.63662761e-01],
        [3.83441519e-01]],

       ...,

       [[1.21711569e-02],
        [3.22829538e-01],
        [2.29567445e-01],
        [5.06862958e-01],
        [7.36853162e-01],
        [9.76763674e-02],
        [5.14922202e-01],
        [9.38412022e-01],
        [2.28646551e-01],
        [6.77141144e-01]]])
y
array([6.15766283, 5.47343366, 5.80251336, 5.55439698, 3.91024909, 4.0865395 , 3.48130262, 4.62304009, 4.44328799, 3.74695784, 5.81024204, 6.53739686, 5.12791992, 4.95242597, 6.83477713, 4.62022506, 5.84404662, 3.82120161, 5.31708373, 3.94285671, 4.22376337, 5.03551278, 6.29287132, 3.34543105, 4.12535925, 6.18792818, 4.15987426, 7.58006614, 4.86551535, 5.14687668, 3.65726133, 4.56957122, 3.51939496, 4.77330199, 4.31174913, 4.59446004, 6.14973996, 5.62542232, 6.57445356, 3.50039149, 4.90610262, 2.11582497, 5.79202376, 4.7204139 , 6.41570102, 5.18882051, 4.01492958, 7.75099781, 4.46178554, 4.60571707, 4.82743427, 5.30872776, 3.89559851, 4.48690479, 5.17260685, 4.78322976, 7.29259289, 5.17293725, 3.35033172, 6.33248724, 3.73595213, 4.98486419, 6.33594974, 4.55005736, 5.99637124, 4.43103304, 4.84205917, 4.77009725, 4.61015796, 4.29541809, 2.33640341, 4.39677098, 6.67577919, 3.75992457, 4.96975665, 4.38531086, 7.0503337 , 5.08937562, 3.39750083, 4.80494559, 5.95758783, 4.32512134, 4.15360043, 5.35020057, 4.90320177, 5.11544772, 5.81646745, 5.14397602, 6.02868874, 5.37593892, 3.9703063 , 5.24585459, 5.69639056, 4.26775981, 4.44595624, 4.91527205, 4.21306597, 5.27424953, 7.11763265, 4.26508255])
# Define the RNN model
model = Sequential([
    SimpleRNN(units=32, input_shape=(X.shape[1], X.shape[2]), activation='relu'),
    Dense(1)
])

# Compile the model
model.compile(optimizer='adam', loss='mean_squared_error')
# Train the model
model.fit(X, y, epochs=10, batch_size=8)
Epoch 1/10
13/13 [==============================] - 2s 10ms/step - loss: 26.1947
Epoch 2/10
13/13 [==============================] - 0s 10ms/step - loss: 22.5297
Epoch 3/10
13/13 [==============================] - 0s 9ms/step - loss: 15.9407
Epoch 4/10
13/13 [==============================] - 0s 12ms/step - loss: 3.5103
Epoch 5/10
13/13 [==============================] - 0s 10ms/step - loss: 0.5985
Epoch 6/10
13/13 [==============================] - 0s 11ms/step - loss: 0.2409
Epoch 7/10
13/13 [==============================] - 0s 10ms/step - loss: 0.1323
Epoch 8/10
13/13 [==============================] - 0s 10ms/step - loss: 0.1077
Epoch 9/10
13/13 [==============================] - 0s 10ms/step - loss: 0.1047
Epoch 10/10
13/13 [==============================] - 0s 10ms/step - loss: 0.0942
# Test the model
test_input = np.random.rand(10).reshape((1, sequence_length, 1))
predicted_output = model.predict(test_input)
test_input
1/1 [==============================] - 0s 43ms/step
array([[[0.23583422], [0.6204999 ], [0.63962224], [0.9485403 ], [0.77827617], [0.84834527], [0.49041991], [0.18534859], [0.99581529], [0.12935576]]])
# Print the predicted output
print("Predicted Output:", predicted_output[0, 0])
Predicted Output: 6.0044985
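
Since the target was defined as the sum of the inputs, we can sanity-check this prediction against the true sum of the test sequence:

# Sanity check: the true target for this test input is the sum of its 10 values
print("True sum:", test_input.sum())

For the test input shown above the true sum is about 5.87, so a prediction of roughly 6.00 suggests the network has learned the summation reasonably well after only 10 epochs.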