snakenet.rb: You can choose between using float or integer weights. floats seem to give better results and the speed didn't improve a lot.

This commit is contained in:
Fabian Schlenz 2019-10-23 05:42:42 +02:00
parent b644006036
commit a902addf94

View File

@ -4,6 +4,7 @@ require 'pp'
require 'thread/pool' require 'thread/pool'
GAMES_PER_ROUND = 50 GAMES_PER_ROUND = 50
FLOAT = true
class Game class Game
WIDTH = 16 WIDTH = 16
@ -177,10 +178,10 @@ class AI
@rounds = 1 @rounds = 1
@id = rand(0xFFFFFF) @id = rand(0xFFFFFF)
if w==nil if w==nil
@weights = Array.new(network_size()) { rand() * 2.0 - 1.0 } @weights = Array.new(network_size()) { FLOAT ? rand() * 2.0 - 1.0 : rand(256) - 128 }
puts "Initialized with random values: #{@weights}" if @debug puts "Initialized with random values: #{@weights}" if @debug
else else
if w[0].is_a? Integer if w[0].is_a?(Integer) && FLOAT
@weights = w.map{|s| s.to_s(16).rjust(8, "0").split("").each_slice(2).to_a.map(&:join).map{|s| s.to_i(16).chr}.join.unpack("g")}.flatten @weights = w.map{|s| s.to_s(16).rjust(8, "0").split("").each_slice(2).to_a.map(&:join).map{|s| s.to_i(16).chr}.join.unpack("g")}.flatten
else else
@weights = w @weights = w
@ -220,7 +221,7 @@ class AI
(1...(NETWORK_LAYOUT.count)).each do |i| (1...(NETWORK_LAYOUT.count)).each do |i|
c_in = NETWORK_LAYOUT[i-1] c_in = NETWORK_LAYOUT[i-1]
c_out = NETWORK_LAYOUT[i] c_out = NETWORK_LAYOUT[i]
outputs = Array.new(c_out){0.0} outputs = Array.new(c_out){FLOAT ? 0.0 : 0}
(0...c_out).each do |o| (0...c_out).each do |o|
(0...c_in).each do |i| (0...c_in).each do |i|
outputs[o] += inputs[i] * @weights[x] outputs[o] += inputs[i] * @weights[x]
@ -253,33 +254,53 @@ class AI
# w[i2] = temp # w[i2] = temp
if action==0 #change single value if action==0 #change single value
i = rand(network_size()) i = rand(network_size())
diff = rand() * 0.2 - 0.1 diff = FLOAT ? rand() * 0.2 - 0.1 : rand(256) - 128
w2 = w.dup w2 = w.dup
w[i] += diff w[i] += diff
w[i] = 1.0 if w[i]>1.0 if FLOAT
w[i] = -1.0 if w[i]<-1.0 w[i] = 1.0 if w[i]>1.0
w[i] = -1.0 if w[i]<-1.0
else
w[i] = 127 if w[i]>127
w[i] = -128 if w[i]<-128
end
w2[i] -= diff w2[i] -= diff
w2[i] = 1.0 if w2[i]>1.0 if FLOAT
w2[i] = -1.0 if w2[i]<-1.0 w2[i] = 1.0 if w2[i]>1.0
w2[i] = -1.0 if w2[i]<-1.0
else
w2[i] = 127 if w2[i]>127
w2[i] = -128 if w2[i]<-128
end
return [AI.new(w), AI.new(w2)] return [AI.new(w), AI.new(w2)]
elsif action==1 #invert single value elsif action==1 #invert single value
i = rand(network_size()) i = rand(network_size())
w[i] *= -1.0 w[i] *= FLOAT ? -1.0 : -1
elsif action==2 elsif action==2
(0...network_size()).each do |i| (0...network_size()).each do |i|
w[i] = rand() * 2 - 1.0 if rand(5)==0 w[i] = (FLOAT ? rand() * 2 - 1.0 : rand(256) - 128) if rand(5)==0
end end
else #change multiple values else #change multiple values
w2 = w.dup w2 = w.dup
(0...network_size()).each do |i| (0...network_size()).each do |i|
if (rand(5)==0) if (rand(5)==0)
diff = rand() * 0.2 - 0.1 diff = FLOAT ? rand() * 0.2 - 0.1 : rand(256) - 128
w[i] += diff w[i] += diff
w[i] = 1.0 if w[i]>1.0 if FLOAT
w[i] = -1.0 if w[i]<-1.0 w[i] = 1.0 if w[i]>1.0
w[i] = -1.0 if w[i]<-1.0
else
w[i] = 127 if w[i]>127
w[i] = -128 if w[i]<-128
end
w2[i] -= diff w2[i] -= diff
w2[i] = 1.0 if w2[i]>1.0 if FLOAT
w2[i] = -1.0 if w2[i]<-1.0 w2[i] = 1.0 if w2[i]>1.0
w2[i] = -1.0 if w2[i]<-1.0
else
w2[i] = 127 if w2[i]>127
w2[i] = -128 if w2[i]<-128
end
end end
end end
return [AI.new(w), AI.new(w2)] return [AI.new(w), AI.new(w2)]
@ -303,13 +324,17 @@ class AI
w = @weights.dup w = @weights.dup
w2 = ai.weights w2 = ai.weights
(0...network_size()).each do |i| (0...network_size()).each do |i|
w[i] = (w[i] + w2[i]) / 2.0 w[i] = (w[i] + w2[i]) / (FLOAT ? 2.0 : 2)
end end
return AI.new(w) return AI.new(w)
end end
def dump def dump
puts "const uint32_t _weights[#{network_size()}] = {#{@weights.map{|x| "0x" + [x].pack('g').split("").map(&:ord).map{|i| i.to_s(16).rjust(2, '0')}.join}.join(", ")}};" if FLOAT
puts "const uint32_t _weights[#{network_size()}] = {#{@weights.map{|x| "0x" + [x].pack('g').split("").map(&:ord).map{|i| i.to_s(16).rjust(2, '0')}.join}.join(", ")}};"
else
puts "const int8_t _weights[#{network_size()}] = {#{@weights.join(", ")}};"
end
#puts "Simplified: #{simplified}" #puts "Simplified: #{simplified}"
end end
end end