Pengantar Tensorflow untuk Java

1. Ikhtisar

TensorFlow adalah pustaka sumber terbuka untuk pemrograman dataflow . Ini pada awalnya dikembangkan oleh Google dan tersedia untuk beragam platform. Meskipun TensorFlow dapat bekerja pada satu inti, TensorFlow dapat dengan mudah memanfaatkan beberapa CPU, GPU, atau TPU yang tersedia .

Dalam tutorial ini, kita akan membahas dasar-dasar TensorFlow dan cara menggunakannya di Java. Harap diperhatikan bahwa TensorFlow Java API adalah API eksperimental dan karenanya tidak tercakup dalam jaminan stabilitas apa pun. Kita akan membahasnya nanti di tutorial kemungkinan kasus penggunaan untuk menggunakan TensorFlow Java API.

2. Dasar-dasar

Komputasi TensorFlow pada dasarnya berputar di sekitar dua konsep dasar: Grafik dan Sesi . Mari kita bahas dengan cepat untuk mendapatkan latar belakang yang diperlukan untuk menjalani sisa tutorial.

2.1. Grafik TensorFlow

Untuk memulainya, mari pahami elemen dasar dari program TensorFlow. Perhitungan direpresentasikan sebagai grafik di TensorFlow . Grafik biasanya grafik asiklik terarah dari operasi dan data, misalnya:

Gambar di atas mewakili grafik komputasi untuk persamaan berikut:

f(x, y) = z = a*x + b*y

Grafik komputasi TensorFlow terdiri dari dua elemen:

  1. Tensor: Ini adalah unit inti data di TensorFlow. Mereka direpresentasikan sebagai tepi dalam grafik komputasi, yang menggambarkan aliran data melalui grafik. Tensor dapat memiliki bentuk dengan sejumlah dimensi. Banyaknya dimensi dalam tensor biasanya disebut dengan rank-nya. Jadi skalar adalah tensor peringkat 0, vektor adalah tensor peringkat 1, matriks adalah tensor peringkat 2, dan seterusnya.
  2. Operasi: Ini adalah node dalam grafik komputasi. Mereka mengacu pada berbagai macam komputasi yang dapat terjadi pada tensor yang dimasukkan ke dalam operasi. Mereka sering menghasilkan tensor juga yang berasal dari operasi dalam grafik komputasi.

2.2. Sesi TensorFlow

Sekarang, grafik TensorFlow hanyalah skema komputasi yang sebenarnya tidak memiliki nilai. Grafik seperti itu harus dijalankan di dalam apa yang disebut sesi TensorFlow agar tensor dalam grafik dievaluasi . Sesi ini dapat menggunakan banyak tensor untuk dievaluasi dari grafik sebagai parameter input. Kemudian berjalan mundur dalam grafik dan menjalankan semua node yang diperlukan untuk mengevaluasi tensor tersebut.

Dengan pengetahuan ini, kami sekarang siap untuk mengambil ini dan menerapkannya ke Java API!

3. Pengaturan Maven

Kami akan menyiapkan proyek Maven cepat untuk membuat dan menjalankan grafik TensorFlow di Java. Kami hanya membutuhkan ketergantungan tensorflow :

 org.tensorflow tensorflow 1.12.0 

4. Membuat Grafik

Sekarang mari kita coba membuat grafik yang telah kita diskusikan di bagian sebelumnya menggunakan TensorFlow Java API. Lebih tepatnya, untuk tutorial ini kita akan menggunakan TensorFlow Java API untuk menyelesaikan fungsi yang direpresentasikan oleh persamaan berikut:

z = 3*x + 2*y

Langkah pertama adalah mendeklarasikan dan menginisialisasi grafik:

Graph graph = new Graph()

Sekarang, kita harus mendefinisikan semua operasi yang diperlukan. Ingat, operasi di TensorFlow mengonsumsi dan menghasilkan nol atau lebih tensor . Selain itu, setiap node dalam grafik merupakan operasi termasuk konstanta dan placeholder. Ini mungkin tampak kontra-intuitif, tapi tahan sebentar!

Grafik kelas memiliki fungsi umum yang disebut opBuilder () untuk membuat jenis operasi apa pun di TensorFlow.

4.1. Mendefinisikan Konstanta

Untuk memulainya, mari kita definisikan operasi konstan pada grafik kita di atas. Perhatikan bahwa operasi konstan akan membutuhkan tensor untuk nilainya :

Operation a = graph.opBuilder("Const", "a") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(3.0, Double.class)) .build(); Operation b = graph.opBuilder("Const", "b") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(2.0, Double.class)) .build();

Di sini, kami telah mendefinisikan Operasi tipe konstan, memasukkan Tensor dengan nilai Ganda 2.0 dan 3.0. Ini mungkin tampak sedikit berlebihan untuk memulainya, tetapi hanya seperti itu di Java API untuk saat ini. Konstruksi ini jauh lebih ringkas dalam bahasa seperti Python.

4.2. Menentukan Placeholder

Meskipun kami perlu memberikan nilai ke konstanta kami, placeholder tidak memerlukan nilai pada waktu-definisi . Nilai placeholder perlu diberikan saat grafik dijalankan di dalam sesi. Kami akan membahas bagian itu nanti di tutorial.

Untuk saat ini, mari kita lihat bagaimana kita bisa mendefinisikan placeholder kita:

Operation x = graph.opBuilder("Placeholder", "x") .setAttr("dtype", DataType.fromClass(Double.class)) .build(); Operation y = graph.opBuilder("Placeholder", "y") .setAttr("dtype", DataType.fromClass(Double.class)) .build();

Perhatikan bahwa kami tidak harus memberikan nilai apa pun untuk placeholder kami. Nilai-nilai ini akan dimasukkan sebagai Tensor saat dijalankan.

4.3. Mendefinisikan Fungsi

Terakhir, kita perlu mendefinisikan operasi matematika dari persamaan kita, yaitu perkalian dan penjumlahan untuk mendapatkan hasilnya.

Ini sekali lagi tidak lain adalah Operasi di TensorFlow dan Graph.opBuilder () berguna sekali lagi:

Operation ax = graph.opBuilder("Mul", "ax") .addInput(a.output(0)) .addInput(x.output(0)) .build(); Operation by = graph.opBuilder("Mul", "by") .addInput(b.output(0)) .addInput(y.output(0)) .build(); Operation z = graph.opBuilder("Add", "z") .addInput(ax.output(0)) .addInput(by.output(0)) .build();

Di sini, kita telah mendefinisikan Operasi di sana , dua untuk mengalikan input kita dan yang terakhir untuk menjumlahkan hasil antara. Perhatikan bahwa operasi di sini menerima tensor yang tidak lain adalah keluaran dari operasi sebelumnya.

Harap dicatat bahwa kami mendapatkan Tensor keluaran dari Operasi menggunakan indeks '0'. Seperti yang kita bahas sebelumnya, sebuah Operasi dapat menghasilkan satu atau lebih Tensor dan karenanya saat mengambil pegangan untuk itu, kita perlu menyebutkan indeks. Karena kami tahu bahwa operasi kami hanya mengembalikan satu Tensor , '0' berfungsi dengan baik!

5. Memvisualisasikan Grafik

Sulit untuk mempertahankan tab pada grafik saat ukurannya bertambah. Ini membuatnya penting untuk memvisualisasikannya dengan cara tertentu . Kita selalu dapat membuat gambar tangan seperti grafik kecil yang kita buat sebelumnya, tetapi tidak praktis untuk grafik yang lebih besar. TensorFlow menyediakan utilitas yang disebut TensorBoard untuk memfasilitasi ini .

Sayangnya, Java API tidak memiliki kemampuan untuk menghasilkan file peristiwa yang digunakan oleh TensorBoard. Tetapi menggunakan API dengan Python kita dapat menghasilkan file acara seperti:

writer = tf.summary.FileWriter('.') ...... writer.add_graph(tf.get_default_graph()) writer.flush()

Tolong jangan repot-repot jika ini tidak masuk akal dalam konteks Java, ini telah ditambahkan di sini hanya demi kelengkapan dan tidak perlu untuk melanjutkan sisa tutorial.

Sekarang kita dapat memuat dan memvisualisasikan file event di TensorBoard seperti:

tensorboard --logdir .

TensorBoard hadir sebagai bagian dari penginstalan TensorFlow.

Perhatikan kesamaan antara ini dan grafik yang digambar secara manual sebelumnya!

6. Bekerja dengan Sesi

We have now created a computational graph for our simple equation in TensorFlow Java API. But how do we run it? Before addressing that, let's see what is the state of Graph we have just created at this point. If we try to print the output of our final Operation “z”:

System.out.println(z.output(0));

This will result in something like:


    

This isn't what we expected! But if we recall what we discussed earlier, this actually makes sense. The Graph we have just defined has not been run yet, so the tensors therein do not actually hold any actual value. The output above just says that this will be a Tensor of type Double.

Let's now define a Session to run our Graph:

Session sess = new Session(graph)

Finally, we are now ready to run our Graph and get the output we have been expecting:

Tensor tensor = sess.runner().fetch("z") .feed("x", Tensor.create(3.0, Double.class)) .feed("y", Tensor.create(6.0, Double.class)) .run().get(0).expect(Double.class); System.out.println(tensor.doubleValue());

So what are we doing here? It should be fairly intuitive:

  • Get a Runner from the Session
  • Define the Operation to fetch by its name “z”
  • Feed in tensors for our placeholders “x” and “y”
  • Run the Graph in the Session

And now we see the scalar output:

21.0

This is what we expected, isn't it!

7. The Use Case for Java API

At this point, TensorFlow may sound like overkill for performing basic operations. But, of course, TensorFlow is meant to run graphs much much larger than this.

Additionally, the tensors it deals with in real-world models are much larger in size and rank. These are the actual machine learning models where TensorFlow finds its real use.

It's not difficult to see that working with the core API in TensorFlow can become very cumbersome as the size of the graph increases. To this end, TensorFlow provides high-level APIs like Keras to work with complex models. Unfortunately, there is little to no official support for Keras on Java just yet.

However, we can use Python to define and train complex models either directly in TensorFlow or using high-level APIs like Keras. Subsequently, we can export a trained model and use that in Java using the TensorFlow Java API.

Now, why would we want to do something like that? This is particularly useful for situations where we want to use machine learning enabled features in existing clients running on Java. For instance, recommending caption for user images on an Android device. Nevertheless, there are several instances where we are interested in the output of a machine learning model but do not necessarily want to create and train that model in Java.

This is where TensorFlow Java API finds the bulk of its use. We'll go through how this can be achieved in the next section.

8. Using Saved Models

We'll now understand how we can save a model in TensorFlow to the file system and load that back possibly in a completely different language and platform. TensorFlow provides APIs to generate model files in a language and platform neutral structure called Protocol Buffer.

8.1. Saving Models to the File System

We'll begin by defining the same graph we created earlier in Python and saving that to the file system.

Let's see we can do this in Python:

import tensorflow as tf graph = tf.Graph() builder = tf.saved_model.builder.SavedModelBuilder('./model') with graph.as_default(): a = tf.constant(2, name="a") b = tf.constant(3, name="b") x = tf.placeholder(tf.int32, name="x") y = tf.placeholder(tf.int32, name="y") z = tf.math.add(a*x, b*y, name="z") sess = tf.Session() sess.run(z, feed_dict = {x: 2, y: 3}) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING]) builder.save()

As the focus of this tutorial in Java, let's not pay much attention to the details of this code in Python, except for the fact that it generates a file called “saved_model.pb”. Do note in passing the brevity in defining a similar graph compared to Java!

8.2. Loading Models from the File System

We'll now load “saved_model.pb” into Java. Java TensorFlow API has SavedModelBundle to work with saved models:

SavedModelBundle model = SavedModelBundle.load("./model", "serve"); Tensor tensor = model.session().runner().fetch("z") .feed("x", Tensor.create(3, Integer.class)) .feed("y", Tensor.create(3, Integer.class)) .run().get(0).expect(Integer.class); System.out.println(tensor.intValue());

It should by now be fairly intuitive to understand what the above code is doing. It simply loads the model graph from the protocol buffer and makes available the session therein. From there onward, we can pretty much do anything with this graph as we would have done for a locally-defined graph.

9. Conclusion

To sum up, in this tutorial we went through the basic concepts related to the TensorFlow computational graph. We saw how to use the TensorFlow Java API to create and run such a graph. Then, we talked about the use cases for the Java API with respect to TensorFlow.

In the process, we also understood how to visualize the graph using TensorBoard, and save and reload a model using Protocol Buffer.

Seperti biasa, kode untuk contoh tersedia di GitHub.