seaweedfs/weed/mount/ml/plugins/tensorflow_plugin.go
chrislu 814e0bb233 Phase 4: Revolutionary Recipe-Based ML Optimization Engine
🚀 Transform SeaweedFS ML optimizations from hard-coded framework-specific code
to a flexible, configuration-driven system using YAML/JSON rules and templates.

## Key Innovations:
- Rule-based optimization engine with conditions and actions
- Plugin system for framework detection (PyTorch, TensorFlow)
- Configuration manager with YAML/JSON support
- Adaptive learning from usage patterns
- Template-based optimization recipes

## New Components:
- optimization_engine.go: Core rule evaluation and application
- config_manager.go: Configuration loading and validation
- plugins/pytorch_plugin.go: PyTorch-specific optimizations
- plugins/tensorflow_plugin.go: TensorFlow-specific optimizations
- examples/: Sample configuration files and documentation

## Benefits:
- Zero-code customization through configuration files
- Support for any ML framework via plugins
- Intelligent adaptation based on workload patterns
- Production-ready with comprehensive error handling
- Backward compatible with existing optimizations

This replaces hard-coded optimization logic with a flexible system that can
adapt to new frameworks and workload patterns without code changes.
2025-08-30 16:49:12 -07:00
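For illustration, here is how one of this plugin's rules (see `GetDefaultRules` below) might be expressed in the configuration format the commit describes. The field names mirror the `ml.OptimizationRule` struct in this file; the exact YAML schema accepted by `config_manager.go` is an assumption.

```yaml
# Hypothetical YAML rendering of the tfrecord_streaming_optimization rule;
# field names follow the Go structs, the concrete schema may differ.
rules:
  - id: tfrecord_streaming_optimization
    name: TFRecord Streaming Optimization
    priority: 90
    conditions:
      - type: file_pattern
        property: extension
        operator: equals
        value: ".tfrecord"
        weight: 1.0
    actions:
      - type: stream_optimization
        target: tfrecord
        parameters:
          parallel_reads: 8
          prefetch_buffer: autotune
```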


package plugins

import (
	"path/filepath"
	"strings"

	"github.com/seaweedfs/seaweedfs/weed/mount/ml"
)
// TensorFlowPlugin provides TensorFlow-specific optimizations.
type TensorFlowPlugin struct {
	name    string
	version string
}

// NewTensorFlowPlugin creates a new TensorFlow optimization plugin.
func NewTensorFlowPlugin() *TensorFlowPlugin {
	return &TensorFlowPlugin{
		name:    "tensorflow",
		version: "1.0.0",
	}
}

// GetFrameworkName returns the framework name.
func (p *TensorFlowPlugin) GetFrameworkName() string {
	return p.name
}
// DetectFramework returns a confidence score in [0, 1] that the given file
// belongs to the TensorFlow framework.
func (p *TensorFlowPlugin) DetectFramework(filePath string, content []byte) float64 {
	confidence := 0.0

	// File extension-based detection.
	ext := strings.ToLower(filepath.Ext(filePath))
	switch ext {
	case ".pb":
		confidence = 0.85 // Could be TensorFlow or another protobuf format
	case ".h5", ".hdf5":
		confidence = 0.80 // Common for Keras/TensorFlow models
	case ".ckpt":
		confidence = 0.75 // TensorFlow checkpoint format
	case ".tflite":
		confidence = 0.95 // TensorFlow Lite model
	case ".tfrecord", ".tfrecords": // Both spellings, matching isTFRecord below
		confidence = 0.95 // TensorFlow record format
	}

	// Content-based detection (if content is provided).
	if len(content) > 0 {
		contentStr := string(content[:minIntTF(len(content), 1024)]) // First 1KB
		if strings.Contains(contentStr, "tensorflow") ||
			strings.Contains(contentStr, "tf.") ||
			strings.Contains(contentStr, "keras") ||
			strings.Contains(contentStr, "SavedModel") {
			confidence = maxFloat64TF(confidence, 0.85)
		}
		// Check for TensorFlow protobuf signatures.
		if strings.Contains(contentStr, "\x08\x01\x12") || // TF SavedModel signature
			strings.Contains(contentStr, "saved_model") {
			confidence = maxFloat64TF(confidence, 0.90)
		}
	}

	// Path-based detection.
	lowerPath := strings.ToLower(filePath)
	if strings.Contains(lowerPath, "tensorflow") ||
		strings.Contains(lowerPath, "savedmodel") ||
		strings.Contains(lowerPath, "keras") ||
		strings.Contains(lowerPath, "tfhub") {
		confidence = maxFloat64TF(confidence, 0.7)
	}

	// Directory structure hints.
	if strings.Contains(lowerPath, "variables/variables") ||
		strings.Contains(lowerPath, "saved_model.pb") {
		confidence = 0.95
	}

	return confidence
}
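// Illustrative scores from the logic above (derived by tracing this file,
// not taken from the plugin's test suite):
//
//	p := NewTensorFlowPlugin()
//	p.DetectFramework("model.tflite", nil)           // 0.95: extension match
//	p.DetectFramework("export/saved_model.pb", nil)  // 0.95: directory-structure hint
//	p.DetectFramework("weights.h5", []byte("keras")) // 0.85: 0.80 from extension, raised by content match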
// GetOptimizationHints provides TensorFlow-specific optimization hints.
func (p *TensorFlowPlugin) GetOptimizationHints(context *ml.OptimizationContext) []ml.OptimizationHint {
	hints := make([]ml.OptimizationHint, 0)

	// SavedModel optimizations.
	if p.isTensorFlowSavedModel(context.FilePath) {
		hints = append(hints, ml.OptimizationHint{
			Type:        "savedmodel_optimization",
			Description: "TensorFlow SavedModel optimizations",
			Priority:    95,
			Parameters: map[string]interface{}{
				"preload_signatures": true,
				"cache_variables":    true,
				"parallel_load":      true,
				"memory_mapping":     context.FileSize > 100*1024*1024, // > 100MB
			},
		})
	}

	// TFRecord dataset optimizations.
	if p.isTFRecord(context.FilePath) {
		hints = append(hints, ml.OptimizationHint{
			Type:        "tfrecord_optimization",
			Description: "TFRecord dataset reading optimization",
			Priority:    85,
			Parameters: map[string]interface{}{
				"parallel_reads":      8,
				"buffer_size":         64 * 1024 * 1024, // 64MB
				"compression":         "auto_detect",
				"prefetch_buffer":     "auto",
				"interleave_datasets": true,
			},
		})
	}

	// Training optimizations.
	if context.WorkloadType == "training" {
		hints = append(hints, ml.OptimizationHint{
			Type:        "tf_training_optimization",
			Description: "TensorFlow training performance optimizations",
			Priority:    80,
			Parameters: map[string]interface{}{
				"mixed_precision":      true,
				"xla_compilation":      true,
				"dataset_prefetch":     "autotune",
				"gradient_compression": context.ModelSize > 500*1024*1024, // > 500MB
			},
		})
	}

	// Inference optimizations.
	if context.WorkloadType == "inference" {
		hints = append(hints, ml.OptimizationHint{
			Type:        "tf_inference_optimization",
			Description: "TensorFlow inference optimizations",
			Priority:    75,
			Parameters: map[string]interface{}{
				"optimize_for_inference": true,
				"use_trt":                len(context.AvailableGPUs) > 0, // TensorRT if a GPU is available
				"batch_inference":        context.BatchSize > 1,
				"model_pruning":          false, // Conservative default
			},
		})
	}

	return hints
}
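// Illustrative example, derived from the branches above: a context with
// FilePath "data/train.tfrecord" and WorkloadType "training" yields two hints,
// "tfrecord_optimization" (priority 85) and "tf_training_optimization"
// (priority 80).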
// GetDefaultRules returns TensorFlow-specific optimization rules.
func (p *TensorFlowPlugin) GetDefaultRules() []*ml.OptimizationRule {
	return []*ml.OptimizationRule{
		{
			ID:          "tensorflow_savedmodel_caching",
			Name:        "TensorFlow SavedModel Caching",
			Description: "Optimized caching for TensorFlow SavedModel files",
			Priority:    95,
			Conditions: []ml.RuleCondition{
				{
					Type:     "file_pattern",
					Property: "name_pattern",
					Operator: "matches",
					Value:    ".*(saved_model\\.pb|variables/).*",
					Weight:   1.0,
				},
				{
					Type:     "file_context",
					Property: "size",
					Operator: "greater_than",
					Value:    1024 * 1024, // > 1MB
					Weight:   0.8,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "cache",
					Target: "savedmodel",
					Parameters: map[string]interface{}{
						"strategy":         "tensorflow_savedmodel",
						"cache_type":       "memory",
						"preload_metadata": true,
						"parallel_loading": true,
						"variable_caching": true,
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "tensorflow",
				"category":  "savedmodel",
			},
		},
		{
			ID:          "tfrecord_streaming_optimization",
			Name:        "TFRecord Streaming Optimization",
			Description: "Optimized streaming for TFRecord datasets",
			Priority:    90,
			Conditions: []ml.RuleCondition{
				{
					Type:     "file_pattern",
					Property: "extension",
					Operator: "equals",
					Value:    ".tfrecord",
					Weight:   1.0,
				},
				{
					Type:     "access_pattern",
					Property: "pattern_type",
					Operator: "in",
					Value:    []string{"sequential", "batch"},
					Weight:   0.9,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "stream_optimization",
					Target: "tfrecord",
					Parameters: map[string]interface{}{
						"parallel_reads":    8,
						"buffer_size":       64 * 1024 * 1024, // 64MB
						"prefetch_buffer":   "autotune",
						"compression_aware": true,
						"record_batching":   true,
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "tensorflow",
				"category":  "dataset",
			},
		},
		{
			ID:          "tensorflow_checkpoint_optimization",
			Name:        "TensorFlow Checkpoint Optimization",
			Description: "Optimized handling for TensorFlow checkpoints",
			Priority:    85,
			Conditions: []ml.RuleCondition{
				{
					Type:     "file_pattern",
					Property: "extension",
					Operator: "equals",
					Value:    ".ckpt",
					Weight:   1.0,
				},
				{
					Type:     "workload_context",
					Property: "workload_type",
					Operator: "equals",
					Value:    "training",
					Weight:   0.9,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "checkpoint_optimization",
					Target: "tensorflow_checkpoint",
					Parameters: map[string]interface{}{
						"async_save":       true,
						"compression":      "gzip",
						"sharding":         true,
						"metadata_caching": true,
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "tensorflow",
				"category":  "checkpoint",
			},
		},
		{
			ID:          "keras_model_optimization",
			Name:        "Keras Model Optimization",
			Description: "Optimizations for Keras model files",
			Priority:    80,
			Conditions: []ml.RuleCondition{
				{
					Type:     "file_pattern",
					Property: "extension",
					Operator: "in",
					Value:    []string{".h5", ".hdf5"},
					Weight:   1.0,
				},
				{
					Type:     "workload_context",
					Property: "framework",
					Operator: "equals",
					Value:    "tensorflow",
					Weight:   0.8,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "model_optimization",
					Target: "keras_model",
					Parameters: map[string]interface{}{
						"lazy_loading":       true,
						"weight_compression": false,
						"architecture_cache": true,
						"parallel_loading":   true,
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "tensorflow",
				"category":  "keras_model",
			},
		},
	}
}
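// Illustrative example: the name_pattern of "tensorflow_savedmodel_caching",
// .*(saved_model\.pb|variables/).*, matches paths such as
// "models/resnet/1/saved_model.pb" and "models/resnet/1/variables/variables.index".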
// GetDefaultTemplates returns TensorFlow-specific optimization templates.
func (p *TensorFlowPlugin) GetDefaultTemplates() []*ml.OptimizationTemplate {
	return []*ml.OptimizationTemplate{
		{
			ID:          "tensorflow_training_template",
			Name:        "TensorFlow Training Optimization",
			Description: "Complete optimization template for TensorFlow training workloads",
			Category:    "training",
			Rules: []string{
				"tensorflow_savedmodel_caching",
				"tfrecord_streaming_optimization",
				"tensorflow_checkpoint_optimization",
				"keras_model_optimization",
				"sequential_prefetch",    // From base rules
				"dataset_batch_optimize", // From base rules
			},
			Parameters: map[string]interface{}{
				"framework":          "tensorflow",
				"training_phase":     "active",
				"optimization_level": "O2",
				"dataset_config": map[string]interface{}{
					"parallel_calls": "autotune",
					"buffer_size":    "autotune",
					"prefetch":       "autotune",
					"cache":          true,
				},
				"model_config": map[string]interface{}{
					"mixed_precision":   true,
					"xla_compilation":   true,
					"gradient_clipping": true,
				},
				"checkpoint_config": map[string]interface{}{
					"save_best_only": false,
					"save_frequency": "epoch",
					"async_save":     true,
				},
			},
		},
		{
			ID:          "tensorflow_inference_template",
			Name:        "TensorFlow Inference Optimization",
			Description: "Optimized template for TensorFlow inference workloads",
			Category:    "inference",
			Rules: []string{
				"tensorflow_savedmodel_caching",
				"keras_model_optimization",
			},
			Parameters: map[string]interface{}{
				"framework":        "tensorflow",
				"inference_mode":   true,
				"batch_processing": true,
				"model_config": map[string]interface{}{
					"optimize_for_inference": true,
					"use_tensorrt":           false, // Conservative default
					"precision":              "fp32",
					"max_batch_size":         32,
				},
				"serving_config": map[string]interface{}{
					"model_warmup":     true,
					"request_batching": true,
					"response_caching": false,
				},
			},
		},
		{
			ID:          "tensorflow_data_pipeline_template",
			Name:        "TensorFlow Data Pipeline Optimization",
			Description: "Optimized template for TensorFlow data processing pipelines",
			Category:    "data_processing",
			Rules: []string{
				"tfrecord_streaming_optimization",
				"dataset_batch_optimize",
			},
			Parameters: map[string]interface{}{
				"framework":        "tensorflow",
				"pipeline_focus":   "data",
				"performance_mode": "throughput",
				"data_config": map[string]interface{}{
					"parallel_interleave":       true,
					"deterministic":             false,
					"experimental_optimization": true,
					"autotune":                  true,
				},
				"io_config": map[string]interface{}{
					"num_parallel_reads": "autotune",
					"compression_type":   "auto",
					"buffer_size":        "autotune",
				},
			},
		},
		{
			ID:          "tensorflow_distributed_template",
			Name:        "TensorFlow Distributed Training",
			Description: "Optimization template for TensorFlow distributed training",
			Category:    "distributed_training",
			Rules: []string{
				"tensorflow_savedmodel_caching",
				"tensorflow_checkpoint_optimization",
				"tfrecord_streaming_optimization",
			},
			Parameters: map[string]interface{}{
				"framework":             "tensorflow",
				"distribution_strategy": "MultiWorkerMirroredStrategy",
				"distributed_config": map[string]interface{}{
					"all_reduce_alg":       "ring",
					"gradient_compression": true,
					"collective_ops":       true,
				},
				"communication_config": map[string]interface{}{
					"compression":     "auto",
					"timeout_seconds": 300,
					"retry_count":     3,
				},
			},
		},
	}
}
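// Note: the Rules entries above reference rules by ID. Most are defined in
// GetDefaultRules in this file; "sequential_prefetch" and
// "dataset_batch_optimize" are assumed to come from the engine's base rule
// set rather than from this plugin.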
// Helper methods

// isTensorFlowSavedModel reports whether the path looks like part of a
// TensorFlow SavedModel directory.
func (p *TensorFlowPlugin) isTensorFlowSavedModel(filePath string) bool {
	lowerPath := strings.ToLower(filePath)
	return strings.Contains(lowerPath, "saved_model.pb") ||
		strings.Contains(lowerPath, "variables/variables") ||
		strings.Contains(lowerPath, "savedmodel")
}

// isTFRecord reports whether the path has a TFRecord extension.
func (p *TensorFlowPlugin) isTFRecord(filePath string) bool {
	ext := strings.ToLower(filepath.Ext(filePath))
	return ext == ".tfrecord" || ext == ".tfrecords"
}

// isKerasModel reports whether the path has a Keras HDF5 model extension.
func (p *TensorFlowPlugin) isKerasModel(filePath string) bool {
	ext := strings.ToLower(filepath.Ext(filePath))
	return ext == ".h5" || ext == ".hdf5"
}
// Utility functions

// minIntTF returns the smaller of two ints.
func minIntTF(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// maxFloat64TF returns the larger of two float64 values.
func maxFloat64TF(a, b float64) float64 {
	if a > b {
		return a
	}
	return b
}