Teachable Machine

# ๐Ÿ“ฆ TFHub ๋ชจ๋ธ์„ Android์—์„œ TensorFlow Lite Task API๋กœ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•

TensorFlow Lite Task API๋Š” ๋งค์šฐ ํŽธ๋ฆฌํ•œ ๊ณ ์ˆ˜์ค€ Vision API์ง€๋งŒ, ์‚ฌ์ „ ์š”๊ตฌ์‚ฌํ•ญ์ด ๊นŒ๋‹ค๋กญ๋‹ค. ํŠนํžˆ **Teachable Machine**์ด๋‚˜ **TensorFlow Hub**์—์„œ ๋‹ค์šด๋กœ๋“œํ•œ `.pb` ๋ชจ๋ธ์€ ๋ฐ”๋กœ Android์—์„œ ์“ธ ์ˆ˜ ์—†๋‹ค.

์ด ํฌ์ŠคํŒ…์—์„œ๋Š” `.pb` ๋ชจ๋ธ์„ Android์—์„œ Task API๋กœ ์‚ฌ์šฉํ•˜๊ธฐ๊นŒ์ง€์˜ **์ „์ฒด ์ „ํ™˜ ํ๋ฆ„**์„ ์ •๋ฆฌํ•œ๋‹ค.

---

## ๐Ÿง  ์™œ ์•ˆ ๋˜๋Š”๊ฐ€?

### โŒ Android TFLite Task API์—์„œ ๋ฐ”๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์—†๋Š” ์ด์œ 
- ๋Œ€๋ถ€๋ถ„์˜ `.pb` ๋ชจ๋ธ์€ **TFLite ํฌ๋งท์ด ์•„๋‹˜**
- ๋˜๋Š” `.tflite`๋ผ๊ณ  ํ•ด๋„ **Metadata (NormalizationOptions ๋“ฑ)** ๊ฐ€ ์—†์–ด Task API๊ฐ€ ์‹คํŒจ
- ์—๋Ÿฌ ๋ฉ”์‹œ์ง€ ์˜ˆ์‹œ:

Input tensor has type kTfLiteFloat32: it requires specifying NormalizationOptions metadata…


---

## โœ… ํ•ด๊ฒฐ ์ˆœ์„œ ์š”์•ฝ

1. **(ํ•„์š”์‹œ) `tfhub_module.pb` โ†’ SavedModel ๋ณ€ํ™˜**
2. `SavedModel` โ†’ `.tflite` ๋ณ€ํ™˜
3. `.tflite` โ†’ Metadata ์ถ”๊ฐ€
4. Android Task API์—์„œ ์ •์ƒ ๋กœ๋“œ

---

## ๐Ÿ›  Step-by-Step

### 1. `tfhub_module.pb` โ†’ `SavedModel` ๋ณ€ํ™˜

```python
import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("tfhub_module.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name="")
  for op in graph.get_operations():
      print(op.name)  # input/output ์ด๋ฆ„ ํ™•์ธ

@tf.function(input_signature=[tf.TensorSpec([1, 224, 224, 3], tf.float32)])
def model_fn(x):
  return {"output": tf.import_graph_def(graph_def, input_map={"input": x}, return_elements=["output:0"])[0]}

concrete_func = model_fn.get_concrete_function()
tf.saved_model.save(concrete_func, "./saved_model")

2. SavedModel โ†’ TFLite ๋ณ€ํ™˜

converter = tf.lite.TFLiteConverter.from_saved_model("./saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float32]
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)

3. Metadata ์ถ”๊ฐ€ (NormalizationOptions ๋“ฑ)

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

MODEL_PATH = "model.tflite"
LABEL_FILE = "labels.txt"  # ๊ฐ ์ค„๋งˆ๋‹ค ํด๋ž˜์Šค๋ช…
SAVE_TO_PATH = "model_with_metadata.tflite"

writer = image_classifier.MetadataWriter.create_for_inference(
    MODEL_PATH, [LABEL_FILE], norm_mean=[127.5], norm_std=[127.5]
)
writer_utils.save_file(writer.populate(), SAVE_TO_PATH)
  • ์ด ๊ณผ์ •์„ ๊ฑฐ์ณ์•ผ GMS ๊ธฐ๋ฐ˜ Task API์—์„œ ์ •์ƒ ๋กœ๋“œ๋จ

4. Android์—์„œ ๋กœ๋“œ

val options = ObjectDetector.ObjectDetectorOptions.builder()
    .setMaxResults(3)
    .build()

val objectDetector = ObjectDetector.createFromFileAndOptions(
    context,
    "model_with_metadata.tflite",
    options
)

๐Ÿงช ์—๋Ÿฌ๊ฐ€ ๋‚  ๋•Œ ์ฒดํฌ ํฌ์ธํŠธ

์ฒดํฌ๋ฆฌ์ŠคํŠธ์„ค๋ช…
Tensor ์ด๋ฆ„input:0, output:0 ๋“ฑ ์ •ํ™•ํ•œ ์ด๋ฆ„ ํ•„์š”
Tensor ํƒ€์ž…float32์ผ ๊ฒฝ์šฐ normalization metadata ํ•„์š”
Label ํŒŒ์ผ.tflite์— ํฌํ•จ๋˜์–ด์•ผ Task API๊ฐ€ ํด๋ž˜์Šค๋ช… ๋งคํ•‘ ๊ฐ€๋Šฅ
์ž…๋ ฅ ์ŠคํŽ™์ด๋ฏธ์ง€ shape์€ [1, 224, 224, 3] ๋“ฑ์ด ์ผ๋ฐ˜์ 

โœ… ๋งˆ๋ฌด๋ฆฌ

Task API๋Š” ํŽธ๋ฆฌํ•˜์ง€๋งŒ, ๋ชจ๋ธ ์ค€๋น„๊ฐ€ ๊นŒ๋‹ค๋กญ๋‹ค.
์ด ๊ณผ์ •์„ ์ž๋™ํ™”ํ•˜๋ ค๋ฉด CLI ๋„๊ตฌ ๋Œ€์‹  Python ์Šคํฌ๋ฆฝํŠธ๋ฅผ ํ™œ์šฉํ•˜์ž.
ํŠนํžˆ TFHub/Teachable Machine ๋ชจ๋ธ์€ ๋ฐ˜๋“œ์‹œ metadata ์ถ”๊ฐ€ ํ›„ ์‚ฌ์šฉํ•ด์•ผ Android์—์„œ ์ œ๋Œ€๋กœ ๋™์ž‘ํ•œ๋‹ค.

์ตœ์ข… ์‚ฐ์ถœ๋ฌผ: model_with_metadata.tflite

์ด ๋ชจ๋ธ๋งŒ ์žˆ์œผ๋ฉด Android ์•ฑ์—์„œ ๋ฐ”๋กœ GMS Task API ๊ธฐ๋ฐ˜ ๊ฐ์ฒด ํƒ์ง€๊ฐ€ ๊ฐ€๋Šฅํ•˜๋‹ค!

์ฝ”๋ฉ˜ํŠธ

๋‹ต๊ธ€ ๋‚จ๊ธฐ๊ธฐ

์ด๋ฉ”์ผ ์ฃผ์†Œ๋Š” ๊ณต๊ฐœ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ํ•„์ˆ˜ ํ•„๋“œ๋Š” *๋กœ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค