Long Short-Term Memory
Summary
There are several types of neural networks, each suited to a particular kind of input data. Standard feed-forward and convolutional neural networks work well on static data, such as a still image, where the entire input is analysed at once.
However, when data is sequential, such as video frames or stock market values, a variant called the Recurrent Neural Network (RNN) is employed. LSTM (Long Short-Term Memory) is a type of RNN.
As the name suggests, LSTM networks have ‘memory’ of previous states of the data. This memory is selectively tuned to remember only chosen parts of past data, even for a long time. In applications where predictions depend on previous values of data, LSTM finds great relevance.
Keras and TensorFlow implementations of LSTM are extensively used in sequential prediction applications such as auto-response suggestions in emails, stock value predictions and speech/writing recognition.
Discussion
-
How do LSTM networks differ from recurrent neural networks? A conventional feed-forward neural network treats all data values as equally important and keeps no record of the order in which they arrive, whether a value is a day old or months old. This works fine for tasks such as reading image data sets to tell a cat from a dog. But in time series data, where every value is time-stamped, the order of values matters and the relevance of a value usually diminishes with the passage of time.
An RNN can handle such dependencies, but only over short time intervals. It does so by feeding the hidden layer output from the previous time step, h(t−1), through a conceptual delay block back into the input of the hidden layer. However, RNNs aren't fully effective when dependencies span long intervals, because of the vanishing gradient problem.
LSTM networks are a type of RNN that include a 'memory cell' which can maintain information for long periods. A set of gates controls when information enters the memory, when it's output, and when it's forgotten. This architecture lets them learn longer-term dependencies. The cell state is updated additively, with the old value scaled by a Forget Gate, instead of being passed through a squashing function at every step. This greatly reduces the multiplicative effect of small gradients and so mitigates the vanishing gradient problem. Exploding gradients can still occur and are commonly handled by gradient clipping.
-
What is the typical node structure of LSTM? All recurrent neural networks have a chain of repeating nodes in the hidden layer, one per time step. A standard RNN node might have an input, an output and a simple tanh function in the middle. In LSTM, the hidden layer nodes have three interacting functions or ‘gates’. These gates protect and control the ‘memory’, that is, the data stored in the cell state.
- Cell-State: Works like a conveyor belt running through the network. Its value changes from one time step to the next as the gates add or remove information.
- Hidden-State: The actual output of the node for a given input. It's termed hidden because it's fed back as an input at the next time step.
- Input-Gate: At time step t, decides which values of the cell state will be updated and how much of each new candidate value will be added.
- Forget-Gate: Decides what information to throw away from the cell state. It looks at the current input and the previous hidden state, and gives a value between 0 and 1 for each number in the cell state: 1 means ‘completely retain’, 0 means ‘completely forget’.
- Output-Gate: Sends out a filtered version of the cell state as the hidden state.
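The gate descriptions above are commonly written as the equations below. This is the widely used textbook formulation rather than something stated in this article. The symbols are assumptions for illustration: x_t is the input at time step t, h_t the hidden state, c_t the cell state, W and b are learned weights and biases, σ is the sigmoid function, and ⊙ is point-wise multiplication.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) && \text{(forget gate)} \\
i_t &= \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) && \text{(input gate)} \\
\tilde{c}_t &= \tanh\left(W_c\,[h_{t-1}, x_t] + b_c\right) && \text{(candidate values)} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{(cell state update)} \\
o_t &= \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) && \text{(output gate)} \\
h_t &= o_t \odot \tanh(c_t) && \text{(hidden state)}
\end{aligned}
```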
-
How does an LSTM network selectively ‘forget’ and ‘remember’ from past data? In LSTM, the cell state is carried forward as a continuously updated value from one time step to the next. The three-gate structure ensures that the cell value is controlled and protected by optionally letting information through. Each gate is composed of a sigmoid neural net layer followed by a point-wise multiplication operation, while a separate tanh layer proposes the new candidate values.
Data is processed one time step at a time. At each new time step, the node receives the hidden state (the output) of the previous time step along with the current input. Even textual values are encoded and stored as numbers in the cell state for easy manipulation.
Every iteration follows these steps (at time step t):
- The hidden state value from t-1 and the input at t are sent into the node.
- The forget gate removes unwanted information from the cell state, based on these inputs.
- The input gate decides what new information to add to the cell state.
- The updated cell state is carried forward to time step t+1.
- The output gate produces a filtered version of the cell state. This becomes the hidden state at t, which is passed out of the hidden layer and also back into the recurrent loop.
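The steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not a production implementation: the helper names (`lstm_step`, `init_params`, `affine`), the layer sizes and the random weights are all hypothetical, and real frameworks use optimized tensor operations instead of lists.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """Return W @ v + b, with W a list of rows and b, v plain lists."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step; params maps a gate name to its (W, b) pair."""
    z = h_prev + x_t  # concatenate previous hidden state and current input
    f = [sigmoid(a) for a in affine(*params["forget"], z)]       # what to keep
    i = [sigmoid(a) for a in affine(*params["input"], z)]        # what to add
    g = [math.tanh(a) for a in affine(*params["candidate"], z)]  # candidate values
    o = [sigmoid(a) for a in affine(*params["output"], z)]       # what to expose
    # Forget part of the old cell state, then add the gated candidates.
    c_t = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c_prev, i, g)]
    # The hidden state is a filtered view of the updated cell state.
    h_t = [ok * math.tanh(ck) for ok, ck in zip(o, c_t)]
    return h_t, c_t

def init_params(n_in, n_hidden, seed=0):
    """Random illustrative weights; a trained network would learn these."""
    rng = random.Random(seed)
    def layer():
        W = [[rng.uniform(-0.5, 0.5) for _ in range(n_hidden + n_in)]
             for _ in range(n_hidden)]
        return W, [0.0] * n_hidden
    return {name: layer() for name in ("forget", "input", "candidate", "output")}

n_in, n_hidden = 2, 3
params = init_params(n_in, n_hidden)
h, c = [0.0] * n_hidden, [0.0] * n_hidden
sequence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for x_t in sequence:
    h, c = lstm_step(x_t, h, c, params)  # h and c carry over to the next step
print(h)
```

The same weights are reused at every time step; only the hidden state and cell state change as the sequence is consumed.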
-
Which part of the ‘memory’ is short term and which is long? Plain vanilla RNNs are also short-term memory networks. The output of a hidden layer node is sent back in as input after a time lag of one step. So every node retains its previous value in its ‘memory’ for one time step (short term). Theoretically, since the loop is recurrent, the information should persist even over long time periods. However, the error correction that happens through backpropagation keeps losing significance as it travels back to earlier time steps. As a result, the network fails to learn from inputs seen many steps earlier, and the memory of them is effectively lost.
In LSTM, the three gates keep the cell state safe. So the short-term memory of each node is extended, and the cell state can be retained across many time steps (long term). The 'forget gate' is the critical element that sits between successive time steps within the hidden layer and controls how much of the cell state is carried forward.
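The difference can be illustrated with simple arithmetic. Backpropagating through many time steps multiplies one factor per step. In the sketch below the factors are illustrative assumptions, not measured values: 0.5 stands for a plain RNN step (recurrent weight times tanh derivative), and 0.99 stands for a forget gate that is almost fully open along the direct cell-state path.

```python
def gradient_scale(per_step_factor, steps):
    """Multiply the same per-step factor once for every time step traversed."""
    scale = 1.0
    for _ in range(steps):
        scale *= per_step_factor
    return scale

steps = 50
rnn_scale = gradient_scale(0.5, steps)    # assumed plain RNN factor
lstm_scale = gradient_scale(0.99, steps)  # assumed forget gate value near 1

print(f"plain RNN after {steps} steps: {rnn_scale:.3e}")
print(f"LSTM cell state after {steps} steps: {lstm_scale:.3f}")
```

With these assumed factors the plain RNN signal shrinks to roughly 1e-15, while the signal along the cell state stays at about 0.6, which is why information from distant time steps can still influence learning.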
-
Could you explain LSTM with an example? Take a language model trying to predict the next word based on all the previous ones: “John has lived in France for several years. So he speaks very good French”. An important step in prediction is deciding what pronoun to use based on the subject's gender. To predict this, the network needs to recognise the subject ‘John’.
At each time step, the cell state keeps adding or updating relevant information about the subject collected from different sentences: name, location, gender, singular or plural, etc. The three gates ensure that only relevant information remains.
However, the next sentence may drop ‘John’ as the subject and introduce a new one: “Jenny has lived in Spain since her birth.” Now the forget gate ensures that the cell state drops the context of John and starts collecting information on Jenny. The correct pronoun to suggest for the next sentence is ‘she’.
It’s also possible that the next sentences don't change the subject and instead talk of other things, such as “France is known for its …”. In that case the LSTM network remembers the context of John for a long time and suggests ‘he’ even 2-3 sentences later.
-
What are the data preparation steps before feeding to an LSTM network? Before fitting an LSTM model to the dataset and making a forecast, some data transformations are performed on the dataset. Let’s extend the word sequence prediction example.
- Uni-variate input data (the word sequence in our example) is converted into an Input -> Output format from which the model can learn: “John has lived” -> “in”, “has lived in” -> “France”, and so on. This transforms the sequence into a supervised learning problem, in which observations at previous time steps become the input used to forecast the current time step. Numeric time series are often also made stationary, for example by differencing.
- LSTM models understand only numeric inputs. So encode the words and punctuation marks as numeric symbols, and build a dictionary of words <-> symbols.
- The output is a one-hot vector identifying the index of the predicted symbol in the dictionary. Word2Vec is a commonly used alternative that encodes symbols as dense vectors. For image classification, every image row can be treated as one step in a sequence of pixels.
- Transform observations to a specific scale, for example by rescaling data to values between -1 and 1 to match the output range of the hyperbolic tangent (tanh) activation function that the LSTM model uses by default.
These transforms are inverted on forecasts to return them to their original state.
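The preparation steps above can be sketched in a few lines of plain Python. The helper names (`make_windows`, `one_hot`, `fit_scaler`), the window length of 3 and the sample price list are assumptions made for illustration; libraries such as scikit-learn provide equivalent, more robust utilities.

```python
def make_windows(sequence, n_in):
    """Turn a sequence into (input window, next value) pairs."""
    return [(sequence[i:i + n_in], sequence[i + n_in])
            for i in range(len(sequence) - n_in)]

def one_hot(index, size):
    """Vector of zeros with a single 1 at the given index."""
    vec = [0] * size
    vec[index] = 1
    return vec

def fit_scaler(values, lo=-1.0, hi=1.0):
    """Return (scale, invert) functions mapping [min, max] onto [lo, hi].

    Assumes the values are not all identical (non-zero span).
    """
    v_min, v_max = min(values), max(values)
    span = v_max - v_min
    def scale(v):
        return lo + (v - v_min) * (hi - lo) / span
    def invert(s):
        return v_min + (s - lo) * span / (hi - lo)
    return scale, invert

# 1. Input -> Output pairs from the word sequence.
words = "John has lived in France for several years".split()
pairs = make_windows(words, 3)
print(pairs[0])  # (['John', 'has', 'lived'], 'in')

# 2. Dictionary of words <-> symbols, plus one-hot targets.
vocab = sorted(set(words))
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for w, i in word_to_id.items()}
encoded = [([word_to_id[w] for w in window],
            one_hot(word_to_id[target], len(vocab)))
           for window, target in pairs]

# 3. Scaling a numeric series to [-1, 1] and inverting it afterwards.
prices = [10.0, 12.0, 11.0, 15.0, 20.0]  # hypothetical values
scale, invert = fit_scaler(prices)
scaled = [scale(p) for p in prices]
restored = [invert(s) for s in scaled]
```

The `invert` function is what would be applied to a model's forecasts to bring them back to the original units.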
-
How to train an LSTM model? LSTM models are trained for sequence prediction, where the output is either a classification (assigning categorical labels) or a regression (predicting continuous real values).
Taking the word prediction example forward, the model must predict the next word in the sentence. This involves multiple iterations of training the model on past data, then evaluating its accuracy on test data. Finally, the trained model is applied to new production data.
The goal is to arrive at a final model that performs the best with respect to:
- Historical data available
- Training time spent on the model
- Data preparation steps and chosen algorithm configurations
- Desired prediction accuracy depending on the business case
First initialize the input vector, weights and biases. Define the number of nodes in the network (a 512-node LSTM network in our example). Feed inputs into the model.
The output is a vector of prediction probabilities for the next symbol, normalized by the softmax function. The index of the element with the highest probability is the predicted index of the symbol in the reverse dictionary.
Accuracy and loss are accumulated to monitor the progress of training. For a small example such as this one, around 50,000 iterations may be enough to achieve acceptable accuracy, though the number required depends on the data set and the model.
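The output step described above can be sketched as follows. The three-word reverse dictionary and the raw scores are hypothetical placeholders standing in for the output of a trained network; cross-entropy is used here as one common choice of loss for next-symbol prediction.

```python
import math

def softmax(logits):
    """Normalize raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, target_index):
    """Negative log-probability assigned to the correct symbol."""
    return -math.log(probs[target_index])

# Hypothetical reverse dictionary and raw network scores for one time step.
reverse_dictionary = {0: "France", 1: "French", 2: "Spain"}
logits = [1.0, 3.0, 0.5]
target_index = 1  # the correct next word is "French"

probs = softmax(logits)
# The index with the highest probability is the predicted symbol.
predicted_index = max(range(len(probs)), key=probs.__getitem__)
predicted_word = reverse_dictionary[predicted_index]

loss = cross_entropy(probs, target_index)
correct = predicted_index == target_index
print(predicted_word, round(loss, 3), correct)
```

During training, the loss and the count of correct predictions would be accumulated over iterations to track progress.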
-
What are the major applications of the LSTM network?
- Sequential prediction of stock price based on historical stock price values
- Handwriting Synthesis and recognition
- Speech recognition using acoustic signals (phonemes) as input
- Image Captioning and recognition of unnamed images in archives
- Synthesised music generation from previous sequence of notes played
- Language translation, after the entire source language input is given
- Flood forecasting with daily discharge and rainfall as input data
-
What support do TensorFlow and Keras frameworks offer to LSTM modeling? Keras is an open-source, Python-based deep learning framework. It has a user-friendly, modular API covering all common neural network building blocks such as layers, nodes, optimization functions and activation functions. Keras is a high-level API that runs on top of a backend such as TensorFlow, CNTK, or Theano.
Keras models can be deployed on smartphones, online applications and JVMs. It also supports distributed deployment of deep learning models on GPU/TPU clusters.
The Keras high-level API handles the way we build models, define layers, or set up models with multiple inputs and outputs. It supports other common utilities like dropout, batch normalization, and pooling.
TensorFlow is an end-to-end open source machine learning framework from Google. You install the TensorFlow module in Python and use its libraries for defining input and output data, allocating test and training data sets, building and fitting the model, tuning batch size and number of epochs, training the model and finally validating accuracy scores for the prediction output.
Combined use of Keras over TensorFlow is a popular option among developers, enabling fast experimentation followed by deployment. Both frameworks have good developer community support, with source code and examples available on GitHub.
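A minimal sketch of a next-word model using Keras over TensorFlow is shown below. It assumes TensorFlow 2.x and NumPy are installed. The vocabulary size, window length, layer sizes and the random training data are placeholders chosen for illustration, so the resulting accuracy is meaningless; only the workflow of building, compiling, fitting and predicting is of interest.

```python
import numpy as np
import tensorflow as tf

vocab_size = 50  # assumed dictionary size
window = 3       # assumed number of input words per sample

model = tf.keras.Sequential([
    # Map each word index to a dense vector.
    tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=16),
    # A single LSTM layer with 32 units (illustrative size).
    tf.keras.layers.LSTM(32),
    # Probability distribution over the next word.
    tf.keras.layers.Dense(vocab_size, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Random placeholder data standing in for encoded word windows and targets.
rng = np.random.default_rng(0)
x = rng.integers(0, vocab_size, size=(64, window))
y = rng.integers(0, vocab_size, size=(64,))

history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
probs = model.predict(x[:1], verbose=0)
print(probs.shape)  # one row of vocab_size probabilities
```

With real data, `x` would hold the encoded input windows and `y` the index of the word that follows each window.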
Further Reading
- Olah, Christopher. 2015. "Understanding LSTM Networks." Blog, August 27. Accessed 2019-09-17.
- Cheung, Brian. 2018. "Long Short Term Memory Networks." Accessed 2019-09-17.
- TensorFlow. 2019. "Text generation with an RNN." Accessed 2019-09-17.
- Hochreiter, Sepp and Jürgen Schmidhuber. 1997. "Long Short-Term Memory." Accessed 2019-09-17.
See Also
- Recurrent Neural Network
- Word Embedding
- Neural Networks for NLP
- Vanishing Gradient Problem
- Time Series Analysis
- Speech Recognition